"""Short-video generation backends, long-video extension and video enhancement helpers."""
# General
import os
from os.path import join as opj
import datetime
import torch
from einops import rearrange, repeat

# Utilities
# NOTE: import order is kept as in the original file; the wildcard import
# below may define names that the later imports intentionally override.
from t2v_enhanced.inference_utils import *

from modelscope.outputs import OutputKeys
import imageio
from PIL import Image
import numpy as np

import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers.utils import load_image

# Converts a PIL image to a tensor without rescaling its values.
transform = transforms.Compose([
    transforms.PILToTensor()
])


def ms_short_gen(prompt, ms_model, inference_generator, t: int = 50,
                 device: str = "cuda") -> torch.Tensor:
    """Generate a short clip with ``ms_model`` and return it as a float tensor.

    Args:
        prompt: Text prompt forwarded to ``ms_model``.
        ms_model: Callable pipeline whose result exposes ``.frames``.
        inference_generator: Random generator forwarded as ``generator``.
        t: Number of inference steps.
        device: Device the frames are moved to.

    Returns:
        The first entry of the stacked frames, as float32 on ``device``, with
        its last axis moved to position 1 ("F W H C -> F C W H").
    """
    frames = ms_model(
        prompt,
        num_inference_steps=t,
        generator=inference_generator,
        eta=1.0,
        height=256,
        width=256,
        latents=None,
    ).frames
    frames = torch.stack([torch.from_numpy(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    return rearrange(frames[0], "F W H C -> F C W H")


def ad_short_gen(prompt, ad_model, inference_generator, t: int = 25,
                 device: str = "cuda") -> torch.Tensor:
    """Generate a 16-frame clip with ``ad_model`` and return it as a float tensor.

    Args:
        prompt: Text prompt forwarded to ``ad_model``.
        ad_model: Callable pipeline whose result exposes ``.frames``.
        inference_generator: Random generator forwarded as ``generator``.
        t: Number of inference steps.
        device: Device the frames are moved to.

    Returns:
        Float32 tensor on ``device``: the frames converted with
        ``PILToTensor``, stacked, interpolated to size 256 and divided by 255.
    """
    frames = ad_model(
        prompt,
        negative_prompt="bad quality, worse quality",
        num_frames=16,
        num_inference_steps=t,
        generator=inference_generator,
        guidance_scale=7.5,
    ).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    frames = F.interpolate(frames, size=256)
    frames = frames/255.0
    return frames


def sdxl_image_gen(prompt, sdxl_model):
    """Return the first image produced by ``sdxl_model`` for ``prompt``."""
    image = sdxl_model(prompt=prompt).images[0]
    return image


def svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator,
                  t: int = 25, device: str = "cuda") -> torch.Tensor:
    """Generate a short clip with ``svd_model`` from a conditioning image.

    Args:
        image: Conditioning image. ``None`` generates one from ``prompt`` with
            ``sdxl_model``; a ``str`` is loaded with ``load_image``; anything
            else is converted with ``Image.fromarray(np.uint8(image))``.
        prompt: Text prompt, used only when ``image`` is ``None``.
        svd_model: Callable pipeline whose result exposes ``.frames``.
        sdxl_model: Image pipeline used when ``image`` is ``None``.
        inference_generator: Random generator forwarded as ``generator``.
        t: Unused; kept for interface compatibility.
        device: Device the frames are moved to.

    Returns:
        Float32 tensor on ``device``: at most the first 16 frames, with the
        side margin cropped from the last axis, interpolated to size 256 and
        divided by 255.
    """
    # Same value is passed to add_margin and later cropped from the last axis.
    margin = 224

    if image is None:
        image = sdxl_image_gen(prompt, sdxl_model)
        image = image.resize((576, 576))
    else:
        if isinstance(image, str):
            image = load_image(image)
        else:
            image = Image.fromarray(np.uint8(image))
        image = resize_and_keep(image)
        image = center_crop(image)
    # presumably pads left/right with black so the frame matches the model's
    # expected aspect ratio -- confirm against add_margin's argument order
    image = add_margin(image, 0, margin, 0, margin, (0, 0, 0))

    frames = svd_model(image, decode_chunk_size=4,
                       generator=inference_generator).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    # Keep the first 16 frames and drop the margin added above.
    frames = frames[:16, :, :, margin:-margin]
    frames = F.interpolate(frames, size=256)
    frames = frames/255.0
    return frames


def stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t,
                    image_guidance, result_file_stem, stream_cli,
                    stream_model) -> None:
    """Run ``stream_model`` prediction through the CLI's trainer.

    The prediction settings are attached to the trainer as ``predict_cfg``
    and the trainer is limited to a single predict batch before
    ``trainer.predict`` is called.

    Args:
        prompt: Text prompt stored under ``"prompt"``.
        short_video: Input video stored under ``"video"``.
        n_autoreg_gen: Stored under ``'n_autoregressive_generations'``.
        seed: Stored under ``"seed"``.
        t: Stored under ``"num_inference_steps"``.
        image_guidance: Stored under ``"guidance_scale"``.
        result_file_stem: Stored under ``"result_file_stem"``.
        stream_cli: Object providing ``trainer``, ``config`` and ``datamodule``.
        stream_model: Model passed to ``trainer.predict``.
    """
    trainer = stream_cli.trainer
    trainer.limit_predict_batches = 1
    trainer.predict_cfg = {
        "predict_dir": stream_cli.config["result_fol"].as_posix(),
        "result_file_stem": result_file_stem,
        "prompt": prompt,
        "video": short_video,
        "seed": seed,
        "num_inference_steps": t,
        "guidance_scale": image_guidance,
        'n_autoregressive_generations': n_autoreg_gen,
    }

    trainer.predict(model=stream_model, datamodule=stream_cli.datamodule)


def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v,
                square=True) -> str:
    """Enhance ``video`` with ``model_v2v`` and return the output file path.

    The input is read, downscaled, resized (and optionally padded) to the
    configured upscale size, written to a temporary mp4 in ``where_to_log``,
    and passed to ``model_v2v``. The enhanced file is then re-read, cropped
    horizontally and written back in place.

    Args:
        prompt: Text prompt; also used (first 100 characters) in the file name.
        video: Path or URI readable by ``imageio.mimread``.
        where_to_log: Directory receiving the temporary and enhanced files.
        cfg_v2v: Mapping providing ``'downscale'``, ``'upscale_size'`` and
            ``'pad'``.
        model_v2v: Callable taking the input dict and ``output_video``.
        square: Unused; kept for interface compatibility.

    Returns:
        Path of the enhanced mp4 file.
    """
    downscale = cfg_v2v['downscale']
    upscale_size = cfg_v2v['upscale_size']
    pad = cfg_v2v['pad']

    # Time-of-day stamp keeps file names from different calls apart.
    now = datetime.datetime.now()
    now = str(now.time()).replace(":", "_").replace(".", "_")
    name = prompt[:100].replace(" ", "_") + "_" + now
    enhanced_video_mp4 = opj(where_to_log, name+"_enhanced.mp4")
    temp_video_mp4 = opj(where_to_log, 'temp_'+now+'.mp4')

    video_frames = imageio.mimread(video)
    h, w, _ = video_frames[0].shape

    # Downscale video, then resize to fit the upscale size
    video = [Image.fromarray(frame).resize((w//downscale, h//downscale))
             for frame in video_frames]
    video = [resize_to_fit(frame, upscale_size) for frame in video]

    if pad:
        video = [pad_to_fit(frame, upscale_size) for frame in video]

    imageio.mimsave(temp_video_mp4, video, fps=8)

    p_input = {
        'video_path': temp_video_mp4,
        'text': prompt
    }
    # The model writes its result to enhanced_video_mp4, which is read below.
    model_v2v(p_input, output_video=enhanced_video_mp4)

    # Remove padding
    # NOTE(review): the 280-pixel crop is hard-coded and applied even when
    # cfg_v2v['pad'] is falsy; presumably it matches the padding of a square
    # clip inside the upscaled frame -- confirm before reusing other sizes.
    video_frames_square = [frame[:, 280:-280, :]
                           for frame in imageio.mimread(enhanced_video_mp4)]
    imageio.mimsave(enhanced_video_mp4, video_frames_square)

    return enhanced_video_mp4