import os
import math
import argparse
from typing import List, Union

import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from einops import rearrange
from omegaconf import ListConfig
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from tqdm import tqdm

from sat import mpu
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint

from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine


def read_from_cli():
    """Yield (prompt, index) pairs typed interactively on stdin."""
    cnt = 0
    try:
        while True:
            x = input("Please input English text (Ctrl-D quit): ")
            yield x.strip(), cnt
            cnt += 1
    except EOFError:
        pass


def read_from_file(p, rank=0, world_size=1):
    """Yield (prompt, index) pairs from a text file, sharded across data-parallel ranks."""
    with open(p, "r") as fin:
        cnt = -1
        for line in fin:
            cnt += 1
            if cnt % world_size != rank:
                continue
            yield line.strip(), cnt


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list({x.input_key for x in conditioner.embedders})


def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    """Build conditional and unconditional batches for the conditioner."""
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            # Repeat the (negative) prompt once per sample.
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    # Any tensor not explicitly overridden is shared between cond and uncond.
    for key in batch:
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
    """Write each video in the batch to save_path/<index>.mp4."""
    os.makedirs(save_path, exist_ok=True)

    for i, vid in enumerate(video_batch):
        frames = []
        for frame in vid:
            frame = rearrange(frame, "c h w -> h w c")
            frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
            frames.append(frame)
        now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
        with imageio.get_writer(now_save_path, fps=fps) as writer:
            for frame in frames:
                writer.append_data(frame)


def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    """Resize so the target rectangle fits inside the frame, then crop to it exactly."""
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        # Input is wider than the target aspect ratio: match height, crop width.
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        # Input is taller than the target aspect ratio: match width, crop height.
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr


def sampling_main(args, model_cls):
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    rank, world_size = 0, 1  # defaults for CLI input; overwritten for file input
    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    image_size = [480, 720]

    sample_func = model.sample
    # T latent frames, H x W output pixels, C latent channels, F = spatial VAE downsampling factor.
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device
    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            # Move the model back onto the GPU (it is offloaded after each sample).
            model.to(device)
            print("rank:", rank, "start to process", text, cnt)
            # TODO: broadcast image2video
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(item) for item in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            # Trim every non-crossattn conditioning tensor to the sample count and move it to the GPU.
            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            for index in range(args.batch_size):
                model.to(device)
                # Sample latents of shape (T, C, H/F, W/F).
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                )
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                # Unload the diffusion model from the GPU to free memory for the VAE decoder.
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model.to(device)

                latent = 1.0 / model.scale_factor * samples_z

                # Decode the latents serially, two latent frames per chunk, to save GPU memory;
                # the first chunk covers three frames, and the decoder cache is cleared on the last chunk.
                recons = []
                loop_num = (T - 1) // 2
                for i in range(loop_num):
                    if i == 0:
                        start_frame, end_frame = 0, 3
                    else:
                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
                    clear_fake_cp_cache = i == loop_num - 1
                    with torch.no_grad():
                        recon = first_stage_model.decode(
                            latent[:, :, start_frame:end_frame].contiguous(),
                            clear_fake_cp_cache=clear_fake_cp_cache,
                        )
                    recons.append(recon)

                recon = torch.cat(recons, dim=2).to(torch.float32)
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                # Map pixel values from [-1, 1] to [0, 1] for video writing.
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                save_path = os.path.join(
                    args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
                )
                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)


if __name__ == "__main__":
    # Map OpenMPI environment variables to the ones torch.distributed expects.
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]

    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config

    # Inference runs without context/model parallelism or activation checkpointing.
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False

    sampling_main(args, model_cls=SATVideoDiffusionEngine)
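

# Example launch (a sketch: the script name, torchrun flags, config file names,
# and YAML keys below are assumptions about the surrounding SAT/CogVideoX-style
# checkout, not definitive). Options are read by arguments.get_args from the
# YAML files passed via --base:
#
#   torchrun --standalone --nproc_per_node=1 sample_video.py \
#       --base configs/model.yaml configs/inference.yaml \
#       --seed 42
#
# with the inference config supplying the fields this script reads, e.g.
# (hypothetical values):
#
#   input_type: txt
#   input_file: prompts.txt
#   output_dir: outputs/
#   sampling_num_frames: 13
#   sampling_fps: 8
#   batch_size: 1
#   latent_channels: 16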