from typing import Any, Dict, List, Optional, Tuple

import clip
from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf
import torch
from torchtyping import TensorType
from torch.utils.data import DataLoader
import torch.nn.functional as F

from src.diffuser import Diffuser
from src.datasets.multimodal_dataset import MultimodalDataset

# ------------------------------------------------------------------------------------- #

# torchtyping dimension name placeholders.
batch_size, context_length = None, None
# Default collate function, grabbed from an empty DataLoader and reused to batch single samples.
collate_fn = DataLoader([]).collate_fn

# ------------------------------------------------------------------------------------- #
def to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    # Move every tensor entry of the batch dict to the target device in place.
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
    return batch
def load_clip_model(version: str, device: str) -> clip.model.CLIP:
    # Load a frozen CLIP model used only as a text encoder.
    model, _ = clip.load(version, device=device, jit=False)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    return model
def encode_text(
    caption_raws: List[str],  # batch_size
    clip_model: clip.model.CLIP,
    max_token_length: Optional[int],
    device: str,
) -> Tuple[List[TensorType], TensorType]:
    if max_token_length is not None:
        default_context_length = 77
        context_length = max_token_length + 2  # start_token + max_token_length + end_token
        assert context_length < default_context_length
        # [bs, context_length] # if n_tokens > context_length -> will truncate
        texts = clip.tokenize(
            caption_raws, context_length=context_length, truncate=True
        )
        # Zero-pad back to CLIP's default context length of 77 tokens.
        zero_pad = torch.zeros(
            [texts.shape[0], default_context_length - context_length],
            dtype=texts.dtype,
            device=texts.device,
        )
        texts = torch.cat([texts, zero_pad], dim=1)
    else:
        # [bs, context_length] # if n_tokens > 77 -> will truncate
        texts = clip.tokenize(caption_raws, truncate=True)

    # [batch_size, n_ctx, d_model]
    x = clip_model.token_embedding(texts.to(device)).type(clip_model.dtype)
    x = x + clip_model.positional_embedding.type(clip_model.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = clip_model.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = clip_model.ln_final(x).type(clip_model.dtype)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # Take features from the eot embedding (the eot token is the highest number in each sequence).
    x_tokens = x[torch.arange(x.shape[0]), texts.argmax(dim=-1)].float()
    x_seq = [x[k, : (m + 1)].float() for k, m in enumerate(texts.argmax(dim=-1))]

    return x_seq, x_tokens
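
# Example (sketch, not part of the original script): encoding a single prompt with the
# frozen CLIP text encoder; the prompt string is illustrative.
#
#   clip_model = load_clip_model("ViT-B/32", "cpu")
#   x_seq, x_tokens = encode_text(["a camera slowly pans to the left"], clip_model, None, "cpu")
#   # x_tokens: [1, 512] pooled EOT features; x_seq[0]: [n_tokens, 512] per-token features
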
def get_batch(
    prompt: str,
    sample_id: str,
    clip_model: clip.model.CLIP,
    dataset: MultimodalDataset,
    seq_feat: bool,
    device: torch.device,
) -> Dict[str, Any]:
    # Get base batch
    sample_index = dataset.root_filenames.index(sample_id)
    raw_batch = dataset[sample_index]
    batch = collate_fn([to_device(raw_batch, device)])

    # Encode text
    caption_seq, caption_tokens = encode_text([prompt], clip_model, None, device)

    if seq_feat:
        # Per-token features: pad to CLIP's 77-token context, then move to [1, d_model, n_ctx].
        caption_feat = caption_seq[0]
        caption_feat = F.pad(caption_feat, (0, 0, 0, 77 - caption_feat.shape[0]))
        caption_feat = caption_feat.unsqueeze(0).permute(0, 2, 1)
    else:
        # Pooled EOT feature: [1, d_model].
        caption_feat = caption_tokens

    # Update batch
    batch["caption_raw"] = [prompt]
    batch["caption_feat"] = caption_feat

    return batch
def init(
    config_name: str,
) -> Tuple[Diffuser, clip.model.CLIP, MultimodalDataset, torch.device]:
    with initialize(version_base="1.3", config_path="../configs"):
        config = compose(config_name=config_name)
    OmegaConf.register_new_resolver("eval", eval)

    # Initialize model
    device = torch.device(config.compnode.device)
    diffuser = instantiate(config.diffuser)
    state_dict = torch.load(config.checkpoint_path, map_location=device)["state_dict"]
    # Replace the EMA bookkeeping entries with the current model's buffers before loading.
    state_dict["ema.initted"] = diffuser.ema.initted
    state_dict["ema.step"] = diffuser.ema.step
    diffuser.load_state_dict(state_dict, strict=False)
    diffuser.to(device).eval()

    # Initialize CLIP model
    clip_model = load_clip_model("ViT-B/32", device)

    # Initialize dataset
    config.dataset.char.load_vertices = True
    config.batch_size = 1
    dataset = instantiate(config.dataset)
    dataset.set_split("demo")
    diffuser.modalities = list(dataset.modality_datasets.keys())
    diffuser.get_matrix = dataset.get_matrix
    diffuser.v_get_matrix = dataset.get_matrix

    return diffuser, clip_model, dataset, device
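

# Minimal usage sketch (not part of the original script): the config name and prompt below
# are placeholders, and the resulting batch would then be passed to the diffuser's sampling
# routine, which is defined outside this file.
if __name__ == "__main__":
    diffuser, clip_model, dataset, device = init("config")  # hypothetical config name
    batch = get_batch(
        prompt="a camera slowly pans to the left",  # illustrative prompt
        sample_id=dataset.root_filenames[0],  # first available demo sample
        clip_model=clip_model,
        dataset=dataset,
        seq_feat=True,
        device=device,
    )
    print({key: getattr(value, "shape", type(value)) for key, value in batch.items()})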