ychenhq committed
Commit 1b15ca6
1 Parent(s): ec9f430

Upload 8 files

scripts/evaluation/__pycache__/funcs.cpython-310.pyc ADDED
Binary file (6.48 kB).
 
scripts/evaluation/ddp_wrapper.py ADDED
@@ -0,0 +1,46 @@
+ import datetime
+ import argparse, importlib
+ from pytorch_lightning import seed_everything
+
+ import torch
+ import torch.distributed as dist
+
+ def setup_dist(local_rank):
+     if dist.is_initialized():
+         return
+     torch.cuda.set_device(local_rank)
+     torch.distributed.init_process_group('nccl', init_method='env://')
+
+
+ def get_dist_info():
+     if dist.is_available():
+         initialized = dist.is_initialized()
+     else:
+         initialized = False
+     if initialized:
+         rank = dist.get_rank()
+         world_size = dist.get_world_size()
+     else:
+         rank = 0
+         world_size = 1
+     return rank, world_size
+
+
+ if __name__ == '__main__':
+     now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--module", type=str, help="module name", default="inference")
+     parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0)
+     args, unknown = parser.parse_known_args()
+     inference_api = importlib.import_module(args.module, package=None)
+
+     inference_parser = inference_api.get_parser()
+     inference_args, unknown = inference_parser.parse_known_args()
+
+     seed_everything(inference_args.seed)
+     setup_dist(args.local_rank)
+     torch.backends.cudnn.benchmark = True
+     rank, gpu_num = get_dist_info()
+
+     print("@CoLVDM Inference [rank%d]: %s"%(rank, now))
+     inference_api.run_inference(inference_args, gpu_num, rank)
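
Note: ddp_wrapper.py only assumes that the module named by --module (default: inference) exposes two entry points: get_parser(), returning an argparse parser that accepts at least --seed, and run_inference(args, gpu_num, rank). It appears intended to be started with one process per GPU, e.g. via torch.distributed.launch, which supplies --local_rank and the env:// rendezvous variables. A minimal, hypothetical stand-in module satisfying that contract (illustrative only, not part of this commit):

# my_module.py -- illustrative stand-in for the --module target of ddp_wrapper.py
import argparse

def get_parser():
    # ddp_wrapper.py calls seed_everything(inference_args.seed), so --seed must exist
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=123)
    return parser

def run_inference(args, gpu_num, rank):
    # gpu_num is the world size, rank is this process's index; shard the work here
    print(f"[rank {rank}/{gpu_num}] processing its slice of the workload")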
scripts/evaluation/funcs.py ADDED
@@ -0,0 +1,194 @@
+ import os, sys, glob
+ import numpy as np
+ from collections import OrderedDict
+ from decord import VideoReader, cpu
+ import cv2
+
+ import torch
+ import torchvision
+ sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
+ from lvdm.models.samplers.ddim import DDIMSampler
+
+
+ def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\
+                         cfg_scale=1.0, temporal_cfg_scale=None, **kwargs):
+     ddim_sampler = DDIMSampler(model)
+     uncond_type = model.uncond_type
+     batch_size = noise_shape[0]
+
+     ## construct unconditional guidance
+     if cfg_scale != 1.0:
+         if uncond_type == "empty_seq":
+             prompts = batch_size * [""]
+             #prompts = N * T * [""] ## if is_imgbatch=True
+             uc_emb = model.get_learned_conditioning(prompts)
+         elif uncond_type == "zero_embed":
+             c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
+             uc_emb = torch.zeros_like(c_emb)
+
+         ## process image embedding token
+         if hasattr(model, 'embedder'):
+             uc_img = torch.zeros(noise_shape[0],3,224,224).to(model.device)
+             ## img: b c h w >> b l c
+             uc_img = model.get_image_embeds(uc_img)
+             uc_emb = torch.cat([uc_emb, uc_img], dim=1)
+
+         if isinstance(cond, dict):
+             uc = {key:cond[key] for key in cond.keys()}
+             uc.update({'c_crossattn': [uc_emb]})
+         else:
+             uc = uc_emb
+     else:
+         uc = None
+
+     x_T = None
+     batch_variants = []
+     #batch_variants1, batch_variants2 = [], []
+     for _ in range(n_samples):
+         if ddim_sampler is not None:
+             kwargs.update({"clean_cond": True})
+             samples, _ = ddim_sampler.sample(S=ddim_steps,
+                                              conditioning=cond,
+                                              batch_size=noise_shape[0],
+                                              shape=noise_shape[1:],
+                                              verbose=False,
+                                              unconditional_guidance_scale=cfg_scale,
+                                              unconditional_conditioning=uc,
+                                              eta=ddim_eta,
+                                              temporal_length=noise_shape[2],
+                                              conditional_guidance_scale_temporal=temporal_cfg_scale,
+                                              x_T=x_T,
+                                              **kwargs
+                                              )
+         ## reconstruct from latent to pixel space
+         batch_images = model.decode_first_stage_2DAE(samples)
+         batch_variants.append(batch_images)
+     ## batch, <samples>, c, t, h, w
+     batch_variants = torch.stack(batch_variants, dim=1)
+     return batch_variants
+
+
+ def get_filelist(data_dir, ext='*'):
+     file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext))
+     file_list.sort()
+     return file_list
+
+ def get_dirlist(path):
+     list = []
+     if (os.path.exists(path)):
+         files = os.listdir(path)
+         for file in files:
+             m = os.path.join(path,file)
+             if (os.path.isdir(m)):
+                 list.append(m)
+     list.sort()
+     return list
+
+
+ def load_model_checkpoint(model, ckpt):
+     def load_checkpoint(model, ckpt, full_strict):
+         state_dict = torch.load(ckpt, map_location="cpu")
+         try:
+             ## deepspeed
+             new_pl_sd = OrderedDict()
+             for key in state_dict['module'].keys():
+                 new_pl_sd[key[16:]]=state_dict['module'][key]
+             model.load_state_dict(new_pl_sd, strict=full_strict)
+         except:
+             if "state_dict" in list(state_dict.keys()):
+                 state_dict = state_dict["state_dict"]
+             model.load_state_dict(state_dict, strict=full_strict)
+         return model
+     load_checkpoint(model, ckpt, full_strict=True)
+     print('>>> model checkpoint loaded.')
+     return model
+
+
+ def load_prompts(prompt_file):
+     f = open(prompt_file, 'r')
+     prompt_list = []
+     for idx, line in enumerate(f.readlines()):
+         l = line.strip()
+         if len(l) != 0:
+             prompt_list.append(l)
+     f.close()
+     return prompt_list
+
+
+ def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16):
+     '''
+     Notice about some special cases:
+     1. video_frames=-1 means to take all the frames (with fs=1)
+     2. when the total number of video frames is less than required, a padding strategy is used (repeat the last frame)
+     '''
+     fps_list = []
+     batch_tensor = []
+     assert frame_stride > 0, "valid frame stride should be a positive integer!"
+     for filepath in filepath_list:
+         padding_num = 0
+         vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
+         fps = vidreader.get_avg_fps()
+         total_frames = len(vidreader)
+         max_valid_frames = (total_frames-1) // frame_stride + 1
+         if video_frames < 0:
+             ## all frames are collected: fs=1 is a must
+             required_frames = total_frames
+             frame_stride = 1
+         else:
+             required_frames = video_frames
+         query_frames = min(required_frames, max_valid_frames)
+         frame_indices = [frame_stride*i for i in range(query_frames)]
+
+         ## [t,h,w,c] -> [c,t,h,w]
+         frames = vidreader.get_batch(frame_indices)
+         frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
+         frame_tensor = (frame_tensor / 255. - 0.5) * 2
+         if max_valid_frames < required_frames:
+             padding_num = required_frames - max_valid_frames
+             frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:,-1:,:,:]]*padding_num)], dim=1)
+             print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.')
+         batch_tensor.append(frame_tensor)
+         sample_fps = int(fps/frame_stride)
+         fps_list.append(sample_fps)
+
+     return torch.stack(batch_tensor, dim=0)
+
+ from PIL import Image
+ def load_image_batch(filepath_list, image_size=(256,256)):
+     batch_tensor = []
+     for filepath in filepath_list:
+         _, filename = os.path.split(filepath)
+         _, ext = os.path.splitext(filename)
+         if ext == '.mp4':
+             vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0])
+             frame = vidreader.get_batch([0])
+             img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float()
+         elif ext == '.png' or ext == '.jpg':
+             img = Image.open(filepath).convert("RGB")
+             rgb_img = np.array(img, np.float32)
+             #bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR)
+             #bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
+             rgb_img = cv2.resize(rgb_img, (image_size[1],image_size[0]), interpolation=cv2.INTER_LINEAR)
+             img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float()
+         else:
+             print(f'ERROR: <{ext}> image loading only supports formats: [mp4], [png], [jpg]')
+             raise NotImplementedError
+         img_tensor = (img_tensor / 255. - 0.5) * 2
+         batch_tensor.append(img_tensor)
+     return torch.stack(batch_tensor, dim=0)
+
+
+ def save_videos(batch_tensors, savedir, filenames, fps=10):
+     # b,samples,c,t,h,w
+     n_samples = batch_tensors.shape[1]
+     for idx, vid_tensor in enumerate(batch_tensors):
+         video = vid_tensor.detach().cpu()
+         video = torch.clamp(video.float(), -1., 1.)
+         video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
+         frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w]
+         grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w]
+         grid = (grid + 1.0) / 2.0
+         grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+         savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
+         torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
+
+
scripts/evaluation/inference.py ADDED
@@ -0,0 +1,137 @@
+ import argparse, os, sys, glob, yaml, math, random
+ import datetime, time
+ import numpy as np
+ from omegaconf import OmegaConf
+ from collections import OrderedDict
+ from tqdm import trange, tqdm
+ from einops import repeat
+ from einops import rearrange, repeat
+ from functools import partial
+ import torch
+ from pytorch_lightning import seed_everything
+
+ from funcs import load_model_checkpoint, load_prompts, load_image_batch, get_filelist, save_videos
+ from funcs import batch_ddim_sampling
+ from utils.utils import instantiate_from_config
+
+
+ def get_parser():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--seed", type=int, default=20230211, help="seed for seed_everything")
+     parser.add_argument("--mode", default="base", type=str, help="which kind of inference mode: {'base', 'i2v'}")
+     parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
+     parser.add_argument("--config", type=str, help="config (yaml) path")
+     parser.add_argument("--prompt_file", type=str, default=None, help="a text file containing many prompts")
+     parser.add_argument("--savedir", type=str, default=None, help="results saving path")
+     parser.add_argument("--savefps", type=int, default=10, help="video fps to generate")
+     parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",)
+     parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",)
+     parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",)
+     parser.add_argument("--bs", type=int, default=1, help="batch size for inference")
+     parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
+     parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
+     parser.add_argument("--frames", type=int, default=-1, help="frames num to inference")
+     parser.add_argument("--fps", type=int, default=24)
+     parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
+     parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None, help="temporal consistency guidance")
+     ## for conditional i2v only
+     parser.add_argument("--cond_input", type=str, default=None, help="data dir of conditional input")
+     return parser
+
+
+ def run_inference(args, gpu_num, gpu_no, **kwargs):
+     ## step 1: model config
+     ## -----------------------------------------------------------------
+     config = OmegaConf.load(args.config)
+     #data_config = config.pop("data", OmegaConf.create())
+     model_config = config.pop("model", OmegaConf.create())
+     model = instantiate_from_config(model_config)
+     model = model.cuda(gpu_no)
+     assert os.path.exists(args.ckpt_path), f"Error: checkpoint [{args.ckpt_path}] Not Found!"
+     model = load_model_checkpoint(model, args.ckpt_path)
+     model.eval()
+
+     ## sample shape
+     assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
+     ## latent noise shape
+     h, w = args.height // 8, args.width // 8
+     frames = model.temporal_length if args.frames < 0 else args.frames
+     channels = model.channels
+
+     ## saving folders
+     os.makedirs(args.savedir, exist_ok=True)
+
+     ## step 2: load data
+     ## -----------------------------------------------------------------
+     assert os.path.exists(args.prompt_file), "Error: prompt file NOT Found!"
+     prompt_list = load_prompts(args.prompt_file)
+     num_samples = len(prompt_list)
+     filename_list = [f"{id+1:04d}" for id in range(num_samples)]
+
+     samples_split = num_samples // gpu_num
+     residual_tail = num_samples % gpu_num
+     print(f'[rank:{gpu_no}] {samples_split}/{num_samples} samples loaded.')
+     indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1)))
+     if gpu_no == 0 and residual_tail != 0:
+         indices = indices + list(range(num_samples-residual_tail, num_samples))
+     prompt_list_rank = [prompt_list[i] for i in indices]
+
+     ## conditional input
+     if args.mode == "i2v":
+         ## each video or frames dir per prompt
+         cond_inputs = get_filelist(args.cond_input, ext='[mpj][pn][4gj]') # '[mpj][pn][4gj]'
+         assert len(cond_inputs) == num_samples, f"Error: conditional input ({len(cond_inputs)}) NOT match prompt ({num_samples})!"
+         filename_list = [f"{os.path.split(cond_inputs[id])[-1][:-4]}" for id in range(num_samples)]
+         cond_inputs_rank = [cond_inputs[i] for i in indices]
+
+     filename_list_rank = [filename_list[i] for i in indices]
+
+     ## step 3: run over samples
+     ## -----------------------------------------------------------------
+     start = time.time()
+     n_rounds = len(prompt_list_rank) // args.bs
+     n_rounds = n_rounds+1 if len(prompt_list_rank) % args.bs != 0 else n_rounds
+     for idx in range(0, n_rounds):
+         print(f'[rank:{gpu_no}] batch-{idx+1} ({args.bs})x{args.n_samples} ...')
+         idx_s = idx*args.bs
+         idx_e = min(idx_s+args.bs, len(prompt_list_rank))
+         batch_size = idx_e - idx_s
+         filenames = filename_list_rank[idx_s:idx_e]
+         noise_shape = [batch_size, channels, frames, h, w]
+         fps = torch.tensor([args.fps]*batch_size).to(model.device).long()
+
+         prompts = prompt_list_rank[idx_s:idx_e]
+         if isinstance(prompts, str):
+             prompts = [prompts]
+         #prompts = batch_size * [""]
+         text_emb = model.get_learned_conditioning(prompts)
+
+         if args.mode == 'base':
+             cond = {"c_crossattn": [text_emb], "fps": fps}
+         elif args.mode == 'i2v':
+             #cond_images = torch.zeros(noise_shape[0],3,224,224).to(model.device)
+             cond_images = load_image_batch(cond_inputs_rank[idx_s:idx_e], (args.height, args.width))
+             cond_images = cond_images.to(model.device)
+             img_emb = model.get_image_embeds(cond_images)
+             imtext_cond = torch.cat([text_emb, img_emb], dim=1)
+             cond = {"c_crossattn": [imtext_cond], "fps": fps}
+         else:
+             raise NotImplementedError
+
+         ## inference
+         batch_samples = batch_ddim_sampling(model, cond, noise_shape, args.n_samples, \
+                                             args.ddim_steps, args.ddim_eta, args.unconditional_guidance_scale, **kwargs)
+         ## b,samples,c,t,h,w
+         save_videos(batch_samples, args.savedir, filenames, fps=args.savefps)
+
+     print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")
+
+
+ if __name__ == '__main__':
+     now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+     print("@CoLVDM Inference: %s"%now)
+     parser = get_parser()
+     args = parser.parse_args()
+     seed_everything(args.seed)
+     rank, gpu_num = 0, 1
+     run_inference(args, gpu_num, rank)
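
Note: sampling happens in latent space, so run_inference derives the noise shape by dividing the pixel-space height and width by 8 and reading the channel count from model.channels. A worked example for the 320x512, 20-frame, batch-size-1 settings used in run_text2video.sh below (4 latent channels is an assumption; the real value comes from the loaded model):

# worked example of the noise_shape computation in run_inference
height, width, frames, batch_size = 320, 512, 20, 1
channels = 4                              # assumption; the script reads model.channels
h, w = height // 8, width // 8            # 40, 64
noise_shape = [batch_size, channels, frames, h, w]
print(noise_shape)                        # [1, 4, 20, 40, 64]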
scripts/gradio/__pycache__/t2v_test.cpython-310.pyc ADDED
Binary file (3 kB).
 
scripts/gradio/i2v_test.py ADDED
@@ -0,0 +1,83 @@
+ import os
+ import time
+ from omegaconf import OmegaConf
+ import torch
+ from scripts.evaluation.funcs import load_model_checkpoint, load_image_batch, save_videos, batch_ddim_sampling
+ from utils.utils import instantiate_from_config
+ from huggingface_hub import hf_hub_download
+
+ class Image2Video():
+     def __init__(self,result_dir='./tmp/',gpu_num=1) -> None:
+         self.download_model()
+         self.result_dir = result_dir
+         if not os.path.exists(self.result_dir):
+             os.mkdir(self.result_dir)
+         ckpt_path='checkpoints/i2v_512_v1/model.ckpt'
+         config_file='configs/inference_i2v_512_v1.0.yaml'
+         config = OmegaConf.load(config_file)
+         model_config = config.pop("model", OmegaConf.create())
+         model_config['params']['unet_config']['params']['use_checkpoint']=False
+         model_list = []
+         for gpu_id in range(gpu_num):
+             model = instantiate_from_config(model_config)
+             # model = model.cuda(gpu_id)
+             assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
+             model = load_model_checkpoint(model, ckpt_path)
+             model.eval()
+             model_list.append(model)
+         self.model_list = model_list
+         self.save_fps = 8
+
+     def get_image(self, image, prompt, steps=50, cfg_scale=12.0, eta=1.0, fps=16):
+         torch.cuda.empty_cache()
+         print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
+         start = time.time()
+         gpu_id=0
+         if steps > 60:
+             steps = 60
+         model = self.model_list[gpu_id]
+         model = model.cuda()
+         batch_size=1
+         channels = model.model.diffusion_model.in_channels
+         frames = model.temporal_length
+         h, w = 320 // 8, 512 // 8
+         noise_shape = [batch_size, channels, frames, h, w]
+
+         # text cond
+         text_emb = model.get_learned_conditioning([prompt])
+
+         # img cond
+         img_tensor = torch.from_numpy(image).permute(2, 0, 1).float()
+         img_tensor = (img_tensor / 255. - 0.5) * 2
+         img_tensor = img_tensor.unsqueeze(0)
+         cond_images = img_tensor.to(model.device)
+         img_emb = model.get_image_embeds(cond_images)
+         imtext_cond = torch.cat([text_emb, img_emb], dim=1)
+         cond = {"c_crossattn": [imtext_cond], "fps": fps}
+
+         ## inference
+         batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
+         ## b,samples,c,t,h,w
+         prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
+         prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
+         prompt_str=prompt_str[:30]
+
+         save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
+         print(f"Saved in {self.result_dir}. Time used: {(time.time() - start):.2f} seconds")
+         model = model.cpu()
+         return os.path.join(self.result_dir, f"{prompt_str}.mp4")
+
+     def download_model(self):
+         REPO_ID = 'VideoCrafter/Image2Video-512'
+         filename_list = ['model.ckpt']
+         if not os.path.exists('./checkpoints/i2v_512_v1/'):
+             os.makedirs('./checkpoints/i2v_512_v1/')
+         for filename in filename_list:
+             local_file = os.path.join('./checkpoints/i2v_512_v1/', filename)
+             if not os.path.exists(local_file):
+                 hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/i2v_512_v1/', local_dir_use_symlinks=False)
+
+ if __name__ == '__main__':
+     i2v = Image2Video()
+     video_path = i2v.get_image('prompts/i2v_prompts/horse.png','horses are walking on the grassland')
+     print('done', video_path)
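
Note: i2v_test.py looks like the backend of a Gradio demo: get_image takes the input image as a NumPy RGB array (it calls torch.from_numpy on it directly, so the file path used in the __main__ smoke test above would first need to be read into an array) and returns the path of the generated mp4. A minimal, hypothetical Gradio wiring is sketched below; it assumes the repository root is on the import path, and the actual app file is not part of this commit:

# hypothetical Gradio wiring sketch; widget ranges are illustrative
import gradio as gr
from scripts.gradio.i2v_test import Image2Video

i2v = Image2Video(result_dir='./tmp/')

demo = gr.Interface(
    fn=i2v.get_image,                  # (image, prompt, steps, cfg_scale, eta, fps) -> mp4 path
    inputs=[
        gr.Image(type="numpy"),        # RGB array, matching torch.from_numpy(image) in get_image
        gr.Textbox(label="prompt"),
        gr.Slider(1, 60, value=50, step=1, label="ddim steps"),
        gr.Slider(1.0, 15.0, value=12.0, label="cfg scale"),
        gr.Slider(0.0, 1.0, value=1.0, label="ddim eta"),
        gr.Slider(8, 32, value=16, step=1, label="fps"),
    ],
    outputs=gr.Video(),
)

if __name__ == '__main__':
    demo.launch()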
scripts/gradio/t2v_test.py ADDED
@@ -0,0 +1,77 @@
+ import os
+ import time
+ from omegaconf import OmegaConf
+ import torch
+ from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling
+ from utils.utils import instantiate_from_config
+ from huggingface_hub import hf_hub_download
+
+ class Text2Video():
+     def __init__(self,result_dir='./tmp/',gpu_num=1) -> None:
+         self.download_model()
+         self.result_dir = result_dir
+         if not os.path.exists(self.result_dir):
+             os.mkdir(self.result_dir)
+         ckpt_path='checkpoints/base_512_v2/model.ckpt'
+         config_file='configs/inference_t2v_512_v2.0.yaml'
+         config = OmegaConf.load(config_file)
+         model_config = config.pop("model", OmegaConf.create())
+         model_config['params']['unet_config']['params']['use_checkpoint']=False
+         model_list = []
+         for gpu_id in range(gpu_num):
+             model = instantiate_from_config(model_config)
+             # model = model.cuda(gpu_id)
+             assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
+             model = load_model_checkpoint(model, ckpt_path)
+             model.eval()
+             model_list.append(model)
+         self.model_list = model_list
+         self.save_fps = 8
+
+     def get_prompt(self, prompt, steps=50, cfg_scale=12.0, eta=1.0, fps=16):
+         torch.cuda.empty_cache()
+         print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
+         start = time.time()
+         gpu_id=0
+         if steps > 60:
+             steps = 60
+         model = self.model_list[gpu_id]
+         model = model.cuda()
+         batch_size=1
+         channels = model.model.diffusion_model.in_channels
+         frames = model.temporal_length
+         h, w = 320 // 8, 512 // 8
+         noise_shape = [batch_size, channels, frames, h, w]
+
+         # text cond
+         text_emb = model.get_learned_conditioning([prompt])
+         cond = {"c_crossattn": [text_emb], "fps": fps}
+
+         ## inference
+         batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
+         ## b,samples,c,t,h,w
+         prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
+         prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
+         prompt_str=prompt_str[:30]
+
+         save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
+         print(f"Saved in {self.result_dir}. Time used: {(time.time() - start):.2f} seconds")
+         model=model.cpu()
+         return os.path.join(self.result_dir, f"{prompt_str}.mp4")
+
+     def download_model(self):
+         REPO_ID = 'VideoCrafter/VideoCrafter2'
+         filename_list = ['model.ckpt']
+         if not os.path.exists('./checkpoints/base_512_v2/'):
+             os.makedirs('./checkpoints/base_512_v2/')
+         for filename in filename_list:
+             local_file = os.path.join('./checkpoints/base_512_v2/', filename)
+
+             if not os.path.exists(local_file):
+                 hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/base_512_v2/', local_dir_use_symlinks=False)
+
+
+ if __name__ == '__main__':
+     t2v = Text2Video()
+     video_path = t2v.get_prompt('a black swan swims on the pond')
+     print('done', video_path)
scripts/run_text2video.sh ADDED
@@ -0,0 +1,23 @@
+ name="7videos"
+
+ ckpt='checkpoints/base_512_v2/model.ckpt'
+ config='configs/inference_t2v_512_v2.0.yaml'
+
+ prompt_file="prompts/test_prompts.txt"
+ res_dir="results"
+
+ python3 scripts/evaluation/inference.py \
+ --seed 123 \
+ --mode 'base' \
+ --ckpt_path $ckpt \
+ --config $config \
+ --savedir $res_dir/$name \
+ --n_samples 1 \
+ --bs 1 --height 320 --width 512 \
+ --unconditional_guidance_scale 5.0 \
+ --unconditional_guidance_scale_temporal 5.0 \
+ --ddim_steps 50 \
+ --ddim_eta 1.0 \
+ --prompt_file $prompt_file \
+ --frames 20 \
+ --savefps 4