Spaces:
Runtime error
Runtime error
import torch | |
import torchvision | |
from vqgan import VQModel | |
from torch.utils.data import Dataset, DataLoader | |
from transformers import T5EncoderModel, AutoTokenizer | |
# Image preprocessing pipeline applied in this order:
#   1. ToTensor   - convert the loaded image to a tensor,
#   2. Resize     - resize with size=256,
#   3. RandomCrop - take a random 256x256 crop.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize(256),
        torchvision.transforms.RandomCrop(256),
    ]
)
class YOUR_DATASET(Dataset):
    """Template map-style dataset; replace the stubs with your own data logic.

    The class only records where the data lives. ``__len__`` and
    ``__getitem__`` must be implemented before the dataset is iterated;
    until then they fail with an explicit message instead of an opaque
    error raised from inside the ``DataLoader`` machinery.

    Args:
        dataset_path: Location of the dataset, stored as ``self.dataset_path``.
    """

    def __init__(self, dataset_path):
        super().__init__()
        # Keep the path so a concrete implementation can read from it.
        self.dataset_path = dataset_path

    def __len__(self):
        """Return the number of samples (must be implemented by the user)."""
        raise NotImplementedError(
            "YOUR_DATASET is a template: implement __len__ for your data."
        )

    def __getitem__(self, index):
        """Return the sample at ``index`` (must be implemented by the user)."""
        raise NotImplementedError(
            "YOUR_DATASET is a template: implement __getitem__ for your data."
        )
def get_dataloader(dataset_path, batch_size, num_workers=8, pin_memory=True):
    """Build a ``DataLoader`` over ``YOUR_DATASET`` for the given path.

    Args:
        dataset_path: Forwarded to the ``YOUR_DATASET`` constructor.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes. Defaults to 8, the
            value that was previously hard-coded; pass 0 to load in the
            main process (e.g. for debugging or on small machines).
        pin_memory: Forwarded to ``DataLoader``. Defaults to True, the
            value that was previously hard-coded.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the dataset.
    """
    dataset = YOUR_DATASET(dataset_path)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
def load_conditional_models(byt5_model_name, vqgan_path, device):
    """Load the frozen VQGAN and the frozen ByT5 text encoder with its tokenizer.

    Both models are moved to ``device``, switched to eval mode and have
    gradients disabled.

    Args:
        byt5_model_name: Name or path passed to ``from_pretrained`` for both
            the T5 encoder and its tokenizer.
        vqgan_path: Checkpoint file; its ``'state_dict'`` entry is loaded
            into a freshly constructed ``VQModel``.
        device: Device the models (and the checkpoint tensors) are mapped to.

    Returns:
        ``(vqgan, (byt5_tokenizer, byt5))``.
    """
    # --- VQGAN ---------------------------------------------------------
    vqgan = VQModel().to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(vqgan_path, map_location=device)
    vqgan.load_state_dict(checkpoint['state_dict'])
    vqgan.eval()
    vqgan.requires_grad_(False)

    # --- ByT5 text encoder + tokenizer -----------------------------------
    byt5 = T5EncoderModel.from_pretrained(byt5_model_name)
    byt5 = byt5.to(device).eval().requires_grad_(False)
    byt5_tokenizer = AutoTokenizer.from_pretrained(byt5_model_name)

    text_encoder = (byt5_tokenizer, byt5)
    return vqgan, text_encoder
def sample(model, model_inputs, latent_shape, unconditional_inputs=None, steps=12, renoise_steps=11, temperature=(1.0, 0.2), cfg=8.0, t_start=1.0, t_end=0.0, device="cuda"):
    """Iteratively sample discrete latent tokens from ``model``.

    Starting from uniformly random token indices, each step predicts logits,
    optionally mixes them with unconditional logits (classifier-free
    guidance), draws new tokens from the tempered softmax, and - for the
    first ``renoise_steps`` steps - re-noises the result via
    ``model.add_noise`` using the initial random tokens.

    Args:
        model: Callable as ``model(tokens, t, **inputs)`` returning logits
            with the class dimension at index 1 and two trailing spatial
            dimensions. Must expose ``num_labels`` and
            ``add_noise(tokens, t, random_x=...)`` whose first return
            element is the re-noised tokens.
        model_inputs: Keyword arguments for the conditional forward pass.
        latent_shape: Shape of the token tensor; ``latent_shape[0]`` is used
            as the batch size for the timestep vector.
        unconditional_inputs: Keyword arguments for the unconditional forward
            pass. Required whenever ``cfg`` is truthy.
        steps: Number of sampling iterations.
        renoise_steps: Re-noise after step ``i`` only while
            ``i < renoise_steps``.
        temperature: ``(start, end)`` softmax temperatures, linearly
            interpolated over ``steps``.
        cfg: Guidance weight ``w`` in ``w * cond + (1 - w) * uncond``. A falsy
            value (``None``, ``0``) disables guidance.
        t_start: First timestep value.
        t_end: Final timestep value; timesteps are ``steps + 1`` evenly
            spaced points from ``t_start`` to ``t_end``.
        device: Device on which tokens and timestep vectors are created.

    Returns:
        Tensor of sampled token indices.

    Raises:
        ValueError: If guidance is enabled (``cfg`` truthy) and at least one
            step will run, but ``unconditional_inputs`` is ``None``.
    """
    # Fail fast with a clear message: previously the default combination
    # cfg=8.0 / unconditional_inputs=None crashed on `**None` with a
    # confusing TypeError after already running one forward pass.
    if cfg and steps > 0 and unconditional_inputs is None:
        raise ValueError(
            "unconditional_inputs must be provided when cfg is enabled; "
            "pass cfg=None (or 0) to sample without classifier-free guidance."
        )

    batch_size = latent_shape[0]
    with torch.inference_mode():
        sampled = torch.randint(0, model.num_labels, size=latent_shape, device=device)
        # Kept unchanged so every re-noise step reuses the same random tokens.
        init_noise = sampled.clone()
        t_list = torch.linspace(t_start, t_end, steps + 1)
        temperatures = torch.linspace(temperature[0], temperature[1], steps)
        for i, t in enumerate(t_list[:steps]):
            t = torch.ones(batch_size, device=device) * t
            logits = model(sampled, t, **model_inputs)
            if cfg:
                uncond_logits = model(sampled, t, **unconditional_inputs)
                logits = logits * cfg + uncond_logits * (1 - cfg)
            probs = logits.div(temperatures[i]).softmax(dim=1)
            # multinomial needs (rows, classes): move classes last and flatten.
            flat_probs = probs.permute(0, 2, 3, 1).reshape(-1, logits.size(1))
            drawn = torch.multinomial(flat_probs, 1)[:, 0]
            sampled = drawn.view(logits.size(0), *logits.shape[2:])
            if i < renoise_steps:
                # Re-noise to the level of the next timestep.
                t_next = torch.ones(batch_size, device=device) * t_list[i + 1]
                sampled = model.add_noise(sampled, t_next, random_x=init_noise)[0]
    return sampled