guoyww committed
Commit b7349f6
Parent(s): e3e1d80

Update app.py

Files changed (1): app.py (+524, -524)
app.py CHANGED
@@ -1,633 +1,633 @@
1
 
2
 
3
- import os
4
- import json
5
- import torch
6
- import random
7
- import copy
8
 
9
- import gradio as gr
10
- from glob import glob
11
- from omegaconf import OmegaConf
12
- from datetime import datetime
13
- from safetensors import safe_open
14
 
15
- from diffusers import AutoencoderKL
16
- from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
17
- from diffusers.utils.import_utils import is_xformers_available
18
- from transformers import CLIPTextModel, CLIPTokenizer
19
 
20
- from animatediff.models.unet import UNet3DConditionModel
21
- from animatediff.pipelines.pipeline_animation import AnimationPipeline
22
- from animatediff.utils.util import save_videos_grid
23
- from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
24
- from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
25
 
26
 
27
- sample_idx = 0
28
- scheduler_dict = {
29
- "Euler": EulerDiscreteScheduler,
30
- "PNDM": PNDMScheduler,
31
- "DDIM": DDIMScheduler,
32
- }
33
 
34
- css = """
35
- .toolbutton {
36
- margin-buttom: 0em 0em 0em 0em;
37
- max-width: 2.5em;
38
- min-width: 2.5em !important;
39
- height: 2.5em;
40
- }
41
- """
42
 
43
- class AnimateController:
44
- def __init__(self):
45
 
46
- # config dirs
47
- self.basedir = os.getcwd()
48
- self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
49
- self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
50
- self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
51
- self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
52
- self.savedir_sample = os.path.join(self.savedir, "sample")
53
- os.makedirs(self.savedir, exist_ok=True)
54
 
55
- self.stable_diffusion_list = []
56
- self.motion_module_list = []
57
- self.personalized_model_list = []
58
 
59
- self.refresh_stable_diffusion()
60
- self.refresh_motion_module()
61
- self.refresh_personalized_model()
62
 
63
- # config models
64
- self.tokenizer = None
65
- self.text_encoder = None
66
- self.vae = None
67
- self.unet = None
68
- self.pipeline = None
69
- self.lora_model_state_dict = {}
70
 
71
- self.inference_config = OmegaConf.load("configs/inference/inference.yaml")
72
 
73
- def refresh_stable_diffusion(self):
74
- self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))
75
 
76
- def refresh_motion_module(self):
77
- motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
78
- self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
79
 
80
- def refresh_personalized_model(self):
81
- personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
82
- self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
83
-
84
- def update_stable_diffusion(self, stable_diffusion_dropdown):
85
- self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
86
- self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
87
- self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
88
- self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
89
- return gr.Dropdown.update()
90
 
91
- def update_motion_module(self, motion_module_dropdown):
92
- if self.unet is None:
93
- gr.Info(f"Please select a pretrained model path.")
94
- return gr.Dropdown.update(value=None)
95
- else:
96
- motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
97
- motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
98
- missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
99
- assert len(unexpected) == 0
100
- return gr.Dropdown.update()
101
 
102
- def update_base_model(self, base_model_dropdown):
103
- if self.unet is None:
104
- gr.Info(f"Please select a pretrained model path.")
105
- return gr.Dropdown.update(value=None)
106
- else:
107
- base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
108
- base_model_state_dict = {}
109
- with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
110
- for key in f.keys():
111
- base_model_state_dict[key] = f.get_tensor(key)
112
 
113
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
114
- self.vae.load_state_dict(converted_vae_checkpoint)
115
-
116
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
117
- self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
118
-
119
- self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
120
- return gr.Dropdown.update()
121
-
122
- def update_lora_model(self, lora_model_dropdown):
123
- lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
124
- self.lora_model_state_dict = {}
125
- if lora_model_dropdown == "none": pass
126
- else:
127
- with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
128
- for key in f.keys():
129
- self.lora_model_state_dict[key] = f.get_tensor(key)
130
- return gr.Dropdown.update()
131
 
132
- def animate(
133
- self,
134
- stable_diffusion_dropdown,
135
- motion_module_dropdown,
136
- base_model_dropdown,
137
- lora_alpha_slider,
138
- prompt_textbox,
139
- negative_prompt_textbox,
140
- sampler_dropdown,
141
- sample_step_slider,
142
- width_slider,
143
- length_slider,
144
- height_slider,
145
- cfg_scale_slider,
146
- seed_textbox
147
- ):
148
- if self.unet is None:
149
- raise gr.Error(f"Please select a pretrained model path.")
150
- if motion_module_dropdown == "":
151
- raise gr.Error(f"Please select a motion module.")
152
- # if base_model_dropdown == "":
153
- # raise gr.Error(f"Please select a base DreamBooth model.")
154
 
155
- if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
156
 
157
- pipeline = AnimationPipeline(
158
- vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
159
- scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
160
- ).to("cuda")
161
 
162
- if self.lora_model_state_dict != {}:
163
- print(f"Lora alpha: {lora_alpha_slider}")
164
- pipeline = convert_lora(copy.deepcopy(pipeline), self.lora_model_state_dict, alpha=lora_alpha_slider)
165
- pipeline.to("cuda")
166
 
167
- torch.cuda.empty_cache()
168
 
169
- seed_textbox = int(seed_textbox)
170
- if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(seed_textbox)
171
- else: torch.seed()
172
- seed = torch.initial_seed()
173
 
174
- sample = pipeline(
175
- prompt_textbox,
176
- negative_prompt = negative_prompt_textbox,
177
- num_inference_steps = sample_step_slider,
178
- guidance_scale = cfg_scale_slider,
179
- width = width_slider,
180
- height = height_slider,
181
- video_length = length_slider,
182
- ).videos
183
 
184
- save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
185
- save_videos_grid(sample, save_sample_path)
186
 
187
- sample_config = {
188
- "prompt": prompt_textbox,
189
- "n_prompt": negative_prompt_textbox,
190
- "sampler": sampler_dropdown,
191
- "num_inference_steps": sample_step_slider,
192
- "guidance_scale": cfg_scale_slider,
193
- "width": width_slider,
194
- "height": height_slider,
195
- "video_length": length_slider,
196
- "seed": seed
197
- }
198
- json_str = json.dumps(sample_config, indent=4)
199
- with open(os.path.join(self.savedir, "logs.json"), "a") as f:
200
- f.write(json_str)
201
- f.write("\n\n")
202
 
203
- del pipeline
204
- torch.cuda.empty_cache()
205
 
206
- return gr.Video.update(value=save_sample_path)
207
 
208
 
209
- controller = AnimateController()
210
 
211
 
212
- def ui():
213
- with gr.Blocks(css=css) as demo:
214
- gr.Markdown(
215
- """
216
- # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
217
- Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
218
- [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
219
- """
220
- )
221
- with gr.Column(variant="panel"):
222
- gr.Markdown(
223
- """
224
- ### 1. Model checkpoints (select pretrained model path first).
225
- """
226
- )
227
- with gr.Row():
228
- stable_diffusion_dropdown = gr.Dropdown(
229
- label="Pretrained Model Path",
230
- choices=controller.stable_diffusion_list,
231
- interactive=True,
232
- )
233
- stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
234
 
235
- stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
236
- def update_stable_diffusion():
237
- controller.refresh_stable_diffusion()
238
- return gr.Dropdown.update(choices=controller.stable_diffusion_list)
239
- stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])
240
-
241
- with gr.Row():
242
- motion_module_dropdown = gr.Dropdown(
243
- label="Select motion module",
244
- choices=controller.motion_module_list,
245
- interactive=True,
246
- )
247
- motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
248
 
249
- motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
250
- def update_motion_module():
251
- controller.refresh_motion_module()
252
- return gr.Dropdown.update(choices=controller.motion_module_list)
253
- motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
254
 
255
- base_model_dropdown = gr.Dropdown(
256
- label="Select base Dreambooth model (required)",
257
- choices=controller.personalized_model_list,
258
- interactive=True,
259
- )
260
- base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
261
 
262
- lora_model_dropdown = gr.Dropdown(
263
- label="Select LoRA model (optional)",
264
- choices=["none"] + controller.personalized_model_list,
265
- value="none",
266
- interactive=True,
267
- )
268
- lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
269
 
270
- lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.7, minimum=0, maximum=2, interactive=True)
271
 
272
- personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
273
- def update_personalized_model():
274
- controller.refresh_personalized_model()
275
- return [
276
- gr.Dropdown.update(choices=controller.personalized_model_list),
277
- gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
278
- ]
279
- personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
280
-
281
- with gr.Column(variant="panel"):
282
- gr.Markdown(
283
- """
284
- ### 2. Configs for AnimateDiff.
285
- """
286
- )
287
 
288
- prompt_textbox = gr.Textbox(label="Prompt", lines=2)
289
- negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
290
 
291
- with gr.Row().style(equal_height=False):
292
- with gr.Column():
293
- with gr.Row():
294
- sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
295
- sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
296
 
297
- width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
298
- height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
299
- length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
300
- cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
301
 
302
- with gr.Row():
303
- seed_textbox = gr.Textbox(label="Seed", value=-1)
304
- seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
305
- seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
306
 
307
- generate_button = gr.Button(value="Generate", variant='primary')
308
 
309
- result_video = gr.Video(label="Generated Animation", interactive=False)
310
-
311
- generate_button.click(
312
- fn=controller.animate,
313
- inputs=[
314
- stable_diffusion_dropdown,
315
- motion_module_dropdown,
316
- base_model_dropdown,
317
- lora_alpha_slider,
318
- prompt_textbox,
319
- negative_prompt_textbox,
320
- sampler_dropdown,
321
- sample_step_slider,
322
- width_slider,
323
- length_slider,
324
- height_slider,
325
- cfg_scale_slider,
326
- seed_textbox,
327
- ],
328
- outputs=[result_video]
329
- )
330
 
331
- return demo
332
 
333
 
334
- if __name__ == "__main__":
335
- demo = ui()
336
- demo.queue(max_size=20)
337
- demo.launch()
338
 
339
 
340
- # import os
341
- # import torch
342
- # import random
343
 
344
- # import gradio as gr
345
- # from glob import glob
346
- # from omegaconf import OmegaConf
347
- # from safetensors import safe_open
348
 
349
- # from diffusers import AutoencoderKL
350
- # from diffusers import EulerDiscreteScheduler, DDIMScheduler
351
- # from diffusers.utils.import_utils import is_xformers_available
352
- # from transformers import CLIPTextModel, CLIPTokenizer
353
 
354
- # from animatediff.models.unet import UNet3DConditionModel
355
- # from animatediff.pipelines.pipeline_animation import AnimationPipeline
356
- # from animatediff.utils.util import save_videos_grid
357
- # from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
358
 
359
 
360
- # pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5"
361
- # inference_config_path = "configs/inference/inference.yaml"
362
 
363
- # css = """
364
- # .toolbutton {
365
- # margin-buttom: 0em 0em 0em 0em;
366
- # max-width: 2.5em;
367
- # min-width: 2.5em !important;
368
- # height: 2.5em;
369
- # }
370
- # """
371
 
372
- # examples = [
373
- # # 1-ToonYou
374
- # [
375
- # "toonyou_beta3.safetensors",
376
- # "mm_sd_v14.ckpt",
377
- # "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes",
378
- # "worst quality, low quality, nsfw, logo",
379
- # 512, 512, "13204175718326964000"
380
- # ],
381
- # # 2-Lyriel
382
- # [
383
- # "lyriel_v16.safetensors",
384
- # "mm_sd_v15.ckpt",
385
- # "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal",
386
- # "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular",
387
- # 512, 512, "6681501646976930000"
388
- # ],
389
- # # 3-RCNZ
390
- # [
391
- # "rcnzCartoon3d_v10.safetensors",
392
- # "mm_sd_v14.ckpt",
393
- # "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded",
394
- # "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
395
- # 512, 512, "2416282124261060"
396
- # ],
397
- # # 4-MajicMix
398
- # [
399
- # "majicmixRealistic_v5Preview.safetensors",
400
- # "mm_sd_v14.ckpt",
401
- # "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic",
402
- # "bad hand, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles",
403
- # 512, 512, "7132772652786303"
404
- # ],
405
- # # 5-RealisticVision
406
- # [
407
- # "realisticVisionV20_v20.safetensors",
408
- # "mm_sd_v15.ckpt",
409
- # "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
410
- # "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation",
411
- # 512, 512, "1490157606650685400"
412
- # ]
 
 
413
  # ]
414
 
415
- # # clean unrelated ckpts
416
- # # ckpts = [
417
- # # "realisticVisionV40_v20Novae.safetensors",
418
- # # "majicmixRealistic_v5Preview.safetensors",
419
- # # "rcnzCartoon3d_v10.safetensors",
420
- # # "lyriel_v16.safetensors",
421
- # # "toonyou_beta3.safetensors"
422
- # # ]
423
-
424
- # # for path in glob(os.path.join("models", "DreamBooth_LoRA", "*.safetensors")):
425
- # # for ckpt in ckpts:
426
- # # if path.endswith(ckpt): break
427
- # # else:
428
- # # print(f"### Cleaning {path} ...")
429
- # # os.system(f"rm -rf {path}")
430
 
431
- # # os.system(f"rm -rf {os.path.join('models', 'DreamBooth_LoRA', '*.safetensors')}")
432
 
433
- # # os.system(f"bash download_bashscripts/1-ToonYou.sh")
434
- # # os.system(f"bash download_bashscripts/2-Lyriel.sh")
435
- # # os.system(f"bash download_bashscripts/3-RcnzCartoon.sh")
436
- # # os.system(f"bash download_bashscripts/4-MajicMix.sh")
437
- # # os.system(f"bash download_bashscripts/5-RealisticVision.sh")
438
 
439
- # # clean Grdio cache
440
- # print(f"### Cleaning cached examples ...")
441
- # os.system(f"rm -rf gradio_cached_examples/")
442
 
443
 
444
- # class AnimateController:
445
- # def __init__(self):
446
 
447
- # # config dirs
448
- # self.basedir = os.getcwd()
449
- # self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
450
- # self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
451
- # self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
452
- # self.savedir = os.path.join(self.basedir, "samples")
453
- # os.makedirs(self.savedir, exist_ok=True)
454
 
455
- # self.base_model_list = []
456
- # self.motion_module_list = []
457
 
458
- # self.selected_base_model = None
459
- # self.selected_motion_module = None
460
 
461
- # self.refresh_motion_module()
462
- # self.refresh_personalized_model()
463
 
464
- # # config models
465
- # self.inference_config = OmegaConf.load(inference_config_path)
466
 
467
- # self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
468
- # self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda()
469
- # self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
470
- # self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
471
 
472
- # self.update_base_model(self.base_model_list[0])
473
- # self.update_motion_module(self.motion_module_list[0])
474
 
475
 
476
- # def refresh_motion_module(self):
477
- # motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
478
- # self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
479
 
480
- # def refresh_personalized_model(self):
481
- # base_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
482
- # self.base_model_list = [os.path.basename(p) for p in base_model_list]
483
 
484
 
485
- # def update_base_model(self, base_model_dropdown):
486
- # self.selected_base_model = base_model_dropdown
487
 
488
- # base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
489
- # base_model_state_dict = {}
490
- # with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
491
- # for key in f.keys(): base_model_state_dict[key] = f.get_tensor(key)
492
 
493
- # converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
494
- # self.vae.load_state_dict(converted_vae_checkpoint)
495
 
496
- # converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
497
- # self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
498
 
499
- # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
500
- # return gr.Dropdown.update()
501
 
502
- # def update_motion_module(self, motion_module_dropdown):
503
- # self.selected_motion_module = motion_module_dropdown
504
 
505
- # motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
506
- # motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
507
- # _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
508
- # assert len(unexpected) == 0
509
- # return gr.Dropdown.update()
510
 
511
 
512
- # def animate(
513
- # self,
514
- # base_model_dropdown,
515
- # motion_module_dropdown,
516
- # prompt_textbox,
517
- # negative_prompt_textbox,
518
- # width_slider,
519
- # height_slider,
520
- # seed_textbox,
521
- # ):
522
- # if self.selected_base_model != base_model_dropdown: self.update_base_model(base_model_dropdown)
523
- # if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
524
 
525
- # if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
526
 
527
- # pipeline = AnimationPipeline(
528
- # vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
529
- # scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
530
- # ).to("cuda")
531
 
532
- # if int(seed_textbox) > 0: seed = int(seed_textbox)
533
- # else: seed = random.randint(1, 1e16)
534
- # torch.manual_seed(int(seed))
535
 
536
- # assert seed == torch.initial_seed()
537
- # print(f"### seed: {seed}")
538
 
539
- # generator = torch.Generator(device="cuda")
540
- # generator.manual_seed(seed)
541
 
542
- # sample = pipeline(
543
- # prompt_textbox,
544
- # negative_prompt = negative_prompt_textbox,
545
- # num_inference_steps = 25,
546
- # guidance_scale = 8.,
547
- # width = width_slider,
548
- # height = height_slider,
549
- # video_length = 16,
550
- # generator = generator,
551
- # ).videos
552
 
553
- # save_sample_path = os.path.join(self.savedir, f"sample.mp4")
554
- # save_videos_grid(sample, save_sample_path)
555
 
556
- # json_config = {
557
- # "prompt": prompt_textbox,
558
- # "n_prompt": negative_prompt_textbox,
559
- # "width": width_slider,
560
- # "height": height_slider,
561
- # "seed": seed,
562
- # "base_model": base_model_dropdown,
563
- # "motion_module": motion_module_dropdown,
564
- # }
565
- # return gr.Video.update(value=save_sample_path), gr.Json.update(value=json_config)
566
 
567
 
568
- # controller = AnimateController()
569
 
570
 
571
- # def ui():
572
- # with gr.Blocks(css=css) as demo:
573
- # gr.Markdown(
574
- # """
575
- # # AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning
576
- # Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
577
- # [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
578
- # """
579
- # )
580
- # gr.Markdown(
581
- # """
582
- # ### Quick Start
583
- # 1. Select desired `Base DreamBooth Model`.
584
- # 2. Select `Motion Module` from `mm_sd_v14.ckpt` and `mm_sd_v15.ckpt`. We recommend trying both of them for the best results.
585
- # 3. Provide `Prompt` and `Negative Prompt` for each model. You are encouraged to refer to each model's webpage on CivitAI to learn how to write prompts for them. Below are the DreamBooth models in this demo. Click to visit their homepage.
586
- # - [`toonyou_beta3.safetensors`](https://civitai.com/models/30240?modelVersionId=78775)
587
- # - [`lyriel_v16.safetensors`](https://civitai.com/models/22922/lyriel)
588
- # - [`rcnzCartoon3d_v10.safetensors`](https://civitai.com/models/66347?modelVersionId=71009)
589
- # - [`majicmixRealistic_v5Preview.safetensors`](https://civitai.com/models/43331?modelVersionId=79068)
590
- # - [`realisticVisionV20_v20.safetensors`](https://civitai.com/models/4201?modelVersionId=29460)
591
- # 4. Click `Generate`, wait for ~1 min, and enjoy.
592
- # """
593
- # )
594
- # with gr.Row():
595
- # with gr.Column():
596
- # base_model_dropdown = gr.Dropdown( label="Base DreamBooth Model", choices=controller.base_model_list, value=controller.base_model_list[0], interactive=True )
597
- # motion_module_dropdown = gr.Dropdown( label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True )
598
 
599
- # base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
600
- # motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
601
 
602
- # prompt_textbox = gr.Textbox( label="Prompt", lines=3 )
603
- # negative_prompt_textbox = gr.Textbox( label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo")
604
 
605
- # with gr.Accordion("Advance", open=False):
606
- # with gr.Row():
607
- # width_slider = gr.Slider( label="Width", value=512, minimum=256, maximum=1024, step=64 )
608
- # height_slider = gr.Slider( label="Height", value=512, minimum=256, maximum=1024, step=64 )
609
- # with gr.Row():
610
- # seed_textbox = gr.Textbox( label="Seed", value=-1)
611
- # seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
612
- # seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e16)), inputs=[], outputs=[seed_textbox])
613
 
614
- # generate_button = gr.Button( value="Generate", variant='primary' )
615
 
616
- # with gr.Column():
617
- # result_video = gr.Video( label="Generated Animation", interactive=False )
618
- # json_config = gr.Json( label="Config", value=None )
619
 
620
- # inputs = [base_model_dropdown, motion_module_dropdown, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox]
621
- # outputs = [result_video, json_config]
622
 
623
- # generate_button.click( fn=controller.animate, inputs=inputs, outputs=outputs )
624
 
625
- # gr.Examples( fn=controller.animate, examples=examples, inputs=inputs, outputs=outputs, cache_examples=True )
626
 
627
- # return demo
628
 
629
 
630
- # if __name__ == "__main__":
631
- # demo = ui()
632
- # demo.queue(max_size=20)
633
- # demo.launch()
 
1
 
2
 
3
+ # import os
4
+ # import json
5
+ # import torch
6
+ # import random
7
+ # import copy
8
 
9
+ # import gradio as gr
10
+ # from glob import glob
11
+ # from omegaconf import OmegaConf
12
+ # from datetime import datetime
13
+ # from safetensors import safe_open
14
 
15
+ # from diffusers import AutoencoderKL
16
+ # from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
17
+ # from diffusers.utils.import_utils import is_xformers_available
18
+ # from transformers import CLIPTextModel, CLIPTokenizer
19
 
20
+ # from animatediff.models.unet import UNet3DConditionModel
21
+ # from animatediff.pipelines.pipeline_animation import AnimationPipeline
22
+ # from animatediff.utils.util import save_videos_grid
23
+ # from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
24
+ # from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
25
 
26
 
27
+ # sample_idx = 0
28
+ # scheduler_dict = {
29
+ # "Euler": EulerDiscreteScheduler,
30
+ # "PNDM": PNDMScheduler,
31
+ # "DDIM": DDIMScheduler,
32
+ # }
33
 
34
+ # css = """
35
+ # .toolbutton {
36
+ # margin-buttom: 0em 0em 0em 0em;
37
+ # max-width: 2.5em;
38
+ # min-width: 2.5em !important;
39
+ # height: 2.5em;
40
+ # }
41
+ # """
42
 
43
+ # class AnimateController:
44
+ # def __init__(self):
45
 
46
+ # # config dirs
47
+ # self.basedir = os.getcwd()
48
+ # self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
49
+ # self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
50
+ # self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
51
+ # self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
52
+ # self.savedir_sample = os.path.join(self.savedir, "sample")
53
+ # os.makedirs(self.savedir, exist_ok=True)
54
 
55
+ # self.stable_diffusion_list = []
56
+ # self.motion_module_list = []
57
+ # self.personalized_model_list = []
58
 
59
+ # self.refresh_stable_diffusion()
60
+ # self.refresh_motion_module()
61
+ # self.refresh_personalized_model()
62
 
63
+ # # config models
64
+ # self.tokenizer = None
65
+ # self.text_encoder = None
66
+ # self.vae = None
67
+ # self.unet = None
68
+ # self.pipeline = None
69
+ # self.lora_model_state_dict = {}
70
 
71
+ # self.inference_config = OmegaConf.load("configs/inference/inference.yaml")
72
 
73
+ # def refresh_stable_diffusion(self):
74
+ # self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))
75
 
76
+ # def refresh_motion_module(self):
77
+ # motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
78
+ # self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
79
 
80
+ # def refresh_personalized_model(self):
81
+ # personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
82
+ # self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
83
+
84
+ # def update_stable_diffusion(self, stable_diffusion_dropdown):
85
+ # self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
86
+ # self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
87
+ # self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
88
+ # self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
89
+ # return gr.Dropdown.update()
90
 
91
+ # def update_motion_module(self, motion_module_dropdown):
92
+ # if self.unet is None:
93
+ # gr.Info(f"Please select a pretrained model path.")
94
+ # return gr.Dropdown.update(value=None)
95
+ # else:
96
+ # motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
97
+ # motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
98
+ # missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
99
+ # assert len(unexpected) == 0
100
+ # return gr.Dropdown.update()
101
 
102
+ # def update_base_model(self, base_model_dropdown):
103
+ # if self.unet is None:
104
+ # gr.Info(f"Please select a pretrained model path.")
105
+ # return gr.Dropdown.update(value=None)
106
+ # else:
107
+ # base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
108
+ # base_model_state_dict = {}
109
+ # with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
110
+ # for key in f.keys():
111
+ # base_model_state_dict[key] = f.get_tensor(key)
112
 
113
+ # converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
114
+ # self.vae.load_state_dict(converted_vae_checkpoint)
115
+
116
+ # converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
117
+ # self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
118
+
119
+ # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
120
+ # return gr.Dropdown.update()
121
+
122
+ # def update_lora_model(self, lora_model_dropdown):
123
+ # lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
124
+ # self.lora_model_state_dict = {}
125
+ # if lora_model_dropdown == "none": pass
126
+ # else:
127
+ # with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
128
+ # for key in f.keys():
129
+ # self.lora_model_state_dict[key] = f.get_tensor(key)
130
+ # return gr.Dropdown.update()
131
 
132
+ # def animate(
133
+ # self,
134
+ # stable_diffusion_dropdown,
135
+ # motion_module_dropdown,
136
+ # base_model_dropdown,
137
+ # lora_alpha_slider,
138
+ # prompt_textbox,
139
+ # negative_prompt_textbox,
140
+ # sampler_dropdown,
141
+ # sample_step_slider,
142
+ # width_slider,
143
+ # length_slider,
144
+ # height_slider,
145
+ # cfg_scale_slider,
146
+ # seed_textbox
147
+ # ):
148
+ # if self.unet is None:
149
+ # raise gr.Error(f"Please select a pretrained model path.")
150
+ # if motion_module_dropdown == "":
151
+ # raise gr.Error(f"Please select a motion module.")
152
+ # # if base_model_dropdown == "":
153
+ # # raise gr.Error(f"Please select a base DreamBooth model.")
154
 
155
+ # if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
156
 
157
+ # pipeline = AnimationPipeline(
158
+ # vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
159
+ # scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
160
+ # ).to("cuda")
161
 
162
+ # if self.lora_model_state_dict != {}:
163
+ # print(f"Lora alpha: {lora_alpha_slider}")
164
+ # pipeline = convert_lora(copy.deepcopy(pipeline), self.lora_model_state_dict, alpha=lora_alpha_slider)
165
+ # pipeline.to("cuda")
166
 
167
+ # torch.cuda.empty_cache()
168
 
169
+ # seed_textbox = int(seed_textbox)
170
+ # if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(seed_textbox)
171
+ # else: torch.seed()
172
+ # seed = torch.initial_seed()
173
 
174
+ # sample = pipeline(
175
+ # prompt_textbox,
176
+ # negative_prompt = negative_prompt_textbox,
177
+ # num_inference_steps = sample_step_slider,
178
+ # guidance_scale = cfg_scale_slider,
179
+ # width = width_slider,
180
+ # height = height_slider,
181
+ # video_length = length_slider,
182
+ # ).videos
183
 
184
+ # save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
185
+ # save_videos_grid(sample, save_sample_path)
186
 
187
+ # sample_config = {
188
+ # "prompt": prompt_textbox,
189
+ # "n_prompt": negative_prompt_textbox,
190
+ # "sampler": sampler_dropdown,
191
+ # "num_inference_steps": sample_step_slider,
192
+ # "guidance_scale": cfg_scale_slider,
193
+ # "width": width_slider,
194
+ # "height": height_slider,
195
+ # "video_length": length_slider,
196
+ # "seed": seed
197
+ # }
198
+ # json_str = json.dumps(sample_config, indent=4)
199
+ # with open(os.path.join(self.savedir, "logs.json"), "a") as f:
200
+ # f.write(json_str)
201
+ # f.write("\n\n")
202
 
203
+ # del pipeline
204
+ # torch.cuda.empty_cache()
205
 
206
+ # return gr.Video.update(value=save_sample_path)
207
 
208
 
209
+ # controller = AnimateController()
210
 
211
 
212
+ # def ui():
213
+ # with gr.Blocks(css=css) as demo:
214
+ # gr.Markdown(
215
+ # """
216
+ # # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
217
+ # Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
218
+ # [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
219
+ # """
220
+ # )
221
+ # with gr.Column(variant="panel"):
222
+ # gr.Markdown(
223
+ # """
224
+ # ### 1. Model checkpoints (select pretrained model path first).
225
+ # """
226
+ # )
227
+ # with gr.Row():
228
+ # stable_diffusion_dropdown = gr.Dropdown(
229
+ # label="Pretrained Model Path",
230
+ # choices=controller.stable_diffusion_list,
231
+ # interactive=True,
232
+ # )
233
+ # stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
234
 
235
+ # stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
236
+ # def update_stable_diffusion():
237
+ # controller.refresh_stable_diffusion()
238
+ # return gr.Dropdown.update(choices=controller.stable_diffusion_list)
239
+ # stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])
240
+
241
+ # with gr.Row():
242
+ # motion_module_dropdown = gr.Dropdown(
243
+ # label="Select motion module",
244
+ # choices=controller.motion_module_list,
245
+ # interactive=True,
246
+ # )
247
+ # motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
248
 
249
+ # motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
250
+ # def update_motion_module():
251
+ # controller.refresh_motion_module()
252
+ # return gr.Dropdown.update(choices=controller.motion_module_list)
253
+ # motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
254
 
255
+ # base_model_dropdown = gr.Dropdown(
256
+ # label="Select base Dreambooth model (required)",
257
+ # choices=controller.personalized_model_list,
258
+ # interactive=True,
259
+ # )
260
+ # base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
261
 
262
+ # lora_model_dropdown = gr.Dropdown(
263
+ # label="Select LoRA model (optional)",
264
+ # choices=["none"] + controller.personalized_model_list,
265
+ # value="none",
266
+ # interactive=True,
267
+ # )
268
+ # lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
269
 
270
+ # lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.7, minimum=0, maximum=2, interactive=True)
271
 
272
+ # personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
273
+ # def update_personalized_model():
274
+ # controller.refresh_personalized_model()
275
+ # return [
276
+ # gr.Dropdown.update(choices=controller.personalized_model_list),
277
+ # gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
278
+ # ]
279
+ # personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
280
+
281
+ # with gr.Column(variant="panel"):
282
+ # gr.Markdown(
283
+ # """
284
+ # ### 2. Configs for AnimateDiff.
285
+ # """
286
+ # )
287
 
288
+ # prompt_textbox = gr.Textbox(label="Prompt", lines=2)
289
+ # negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
290
 
291
+ # with gr.Row().style(equal_height=False):
292
+ # with gr.Column():
293
+ # with gr.Row():
294
+ # sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
295
+ # sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
296
 
297
+ # width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
298
+ # height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
299
+ # length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
300
+ # cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
301
 
302
+ # with gr.Row():
303
+ # seed_textbox = gr.Textbox(label="Seed", value=-1)
304
+ # seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
305
+ # seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
306
 
307
+ # generate_button = gr.Button(value="Generate", variant='primary')
308
 
309
+ # result_video = gr.Video(label="Generated Animation", interactive=False)
310
+
311
+ # generate_button.click(
312
+ # fn=controller.animate,
313
+ # inputs=[
314
+ # stable_diffusion_dropdown,
315
+ # motion_module_dropdown,
316
+ # base_model_dropdown,
317
+ # lora_alpha_slider,
318
+ # prompt_textbox,
319
+ # negative_prompt_textbox,
320
+ # sampler_dropdown,
321
+ # sample_step_slider,
322
+ # width_slider,
323
+ # length_slider,
324
+ # height_slider,
325
+ # cfg_scale_slider,
326
+ # seed_textbox,
327
+ # ],
328
+ # outputs=[result_video]
329
+ # )
330
 
331
+ # return demo
332
 
333
 
334
+ # if __name__ == "__main__":
335
+ # demo = ui()
336
+ # demo.queue(max_size=20)
337
+ # demo.launch()
338
 
339
 
340
+ import os
341
+ import torch
342
+ import random
343
 
344
+ import gradio as gr
345
+ from glob import glob
346
+ from omegaconf import OmegaConf
347
+ from safetensors import safe_open
348
 
349
+ from diffusers import AutoencoderKL
350
+ from diffusers import EulerDiscreteScheduler, DDIMScheduler
351
+ from diffusers.utils.import_utils import is_xformers_available
352
+ from transformers import CLIPTextModel, CLIPTokenizer
353
 
354
+ from animatediff.models.unet import UNet3DConditionModel
355
+ from animatediff.pipelines.pipeline_animation import AnimationPipeline
356
+ from animatediff.utils.util import save_videos_grid
357
+ from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
358
 
359
 
360
+ pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5"
361
+ inference_config_path = "configs/inference/inference.yaml"
362
 
363
+ css = """
364
+ .toolbutton {
365
+ margin-buttom: 0em 0em 0em 0em;
366
+ max-width: 2.5em;
367
+ min-width: 2.5em !important;
368
+ height: 2.5em;
369
+ }
370
+ """
371
 
372
+ examples = [
373
+ # 1-ToonYou
374
+ [
375
+ "toonyou_beta3.safetensors",
376
+ "mm_sd_v14.ckpt",
377
+ "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes",
378
+ "worst quality, low quality, nsfw, logo",
379
+ 512, 512, "13204175718326964000"
380
+ ],
381
+ # 2-Lyriel
382
+ [
383
+ "lyriel_v16.safetensors",
384
+ "mm_sd_v15.ckpt",
385
+ "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal",
386
+ "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular",
387
+ 512, 512, "6681501646976930000"
388
+ ],
389
+ # 3-RCNZ
390
+ [
391
+ "rcnzCartoon3d_v10.safetensors",
392
+ "mm_sd_v14.ckpt",
393
+ "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded",
394
+ "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
395
+ 512, 512, "2416282124261060"
396
+ ],
397
+ # 4-MajicMix
398
+ [
399
+ "majicmixRealistic_v5Preview.safetensors",
400
+ "mm_sd_v14.ckpt",
401
+ "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic",
402
+ "bad hand, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles",
403
+ 512, 512, "7132772652786303"
404
+ ],
405
+ # 5-RealisticVision
406
+ [
407
+ "realisticVisionV20_v20.safetensors",
408
+ "mm_sd_v15.ckpt",
409
+ "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
410
+ "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation",
411
+ 512, 512, "1490157606650685400"
412
+ ]
413
+ ]
414
+
415
+ # clean unrelated ckpts
416
+ # ckpts = [
417
+ # "realisticVisionV40_v20Novae.safetensors",
418
+ # "majicmixRealistic_v5Preview.safetensors",
419
+ # "rcnzCartoon3d_v10.safetensors",
420
+ # "lyriel_v16.safetensors",
421
+ # "toonyou_beta3.safetensors"
422
  # ]
423
 
424
+ # for path in glob(os.path.join("models", "DreamBooth_LoRA", "*.safetensors")):
425
+ # for ckpt in ckpts:
426
+ # if path.endswith(ckpt): break
427
+ # else:
428
+ # print(f"### Cleaning {path} ...")
429
+ # os.system(f"rm -rf {path}")
 
 
430
 
431
+ # os.system(f"rm -rf {os.path.join('models', 'DreamBooth_LoRA', '*.safetensors')}")
432
 
433
+ # os.system(f"bash download_bashscripts/1-ToonYou.sh")
434
+ # os.system(f"bash download_bashscripts/2-Lyriel.sh")
435
+ # os.system(f"bash download_bashscripts/3-RcnzCartoon.sh")
436
+ # os.system(f"bash download_bashscripts/4-MajicMix.sh")
437
+ # os.system(f"bash download_bashscripts/5-RealisticVision.sh")
438
 
439
+ # clean Grdio cache
440
+ print(f"### Cleaning cached examples ...")
441
+ os.system(f"rm -rf gradio_cached_examples/")
442
 
443
 
444
+ class AnimateController:
445
+ def __init__(self):
446
 
447
+ # config dirs
448
+ self.basedir = os.getcwd()
449
+ self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
450
+ self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
451
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
452
+ self.savedir = os.path.join(self.basedir, "samples")
453
+ os.makedirs(self.savedir, exist_ok=True)
454
 
455
+ self.base_model_list = []
456
+ self.motion_module_list = []
457
 
458
+ self.selected_base_model = None
459
+ self.selected_motion_module = None
460
 
461
+ self.refresh_motion_module()
462
+ self.refresh_personalized_model()
463
 
464
+ # config models
465
+ self.inference_config = OmegaConf.load(inference_config_path)
466
 
467
+ self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
468
+ self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda()
469
+ self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
470
+ self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
471
 
472
+ self.update_base_model(self.base_model_list[0])
473
+ self.update_motion_module(self.motion_module_list[0])
474
 
475
 
476
+ def refresh_motion_module(self):
477
+ motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
478
+ self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
479
 
480
+ def refresh_personalized_model(self):
481
+ base_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
482
+ self.base_model_list = [os.path.basename(p) for p in base_model_list]
483
 
484
 
485
+ def update_base_model(self, base_model_dropdown):
486
+ self.selected_base_model = base_model_dropdown
487
 
488
+ base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
489
+ base_model_state_dict = {}
490
+ with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
491
+ for key in f.keys(): base_model_state_dict[key] = f.get_tensor(key)
492
 
493
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
494
+ self.vae.load_state_dict(converted_vae_checkpoint)
495
 
496
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
497
+ self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
498
 
499
+ self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
500
+ return gr.Dropdown.update()
501
 
502
+ def update_motion_module(self, motion_module_dropdown):
503
+ self.selected_motion_module = motion_module_dropdown
504
 
505
+ motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
506
+ motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
507
+ _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
508
+ assert len(unexpected) == 0
509
+ return gr.Dropdown.update()
510
 
511
 
512
+ def animate(
513
+ self,
514
+ base_model_dropdown,
515
+ motion_module_dropdown,
516
+ prompt_textbox,
517
+ negative_prompt_textbox,
518
+ width_slider,
519
+ height_slider,
520
+ seed_textbox,
521
+ ):
522
+ if self.selected_base_model != base_model_dropdown: self.update_base_model(base_model_dropdown)
523
+ if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
524
 
525
+ if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
526
 
527
+ pipeline = AnimationPipeline(
528
+ vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
529
+ scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
530
+ ).to("cuda")
531
 
532
+ if int(seed_textbox) > 0: seed = int(seed_textbox)
533
+ else: seed = random.randint(1, 1e16)
534
+ torch.manual_seed(int(seed))
535
 
536
+ assert seed == torch.initial_seed()
537
+ print(f"### seed: {seed}")
538
 
539
+ generator = torch.Generator(device="cuda")
540
+ generator.manual_seed(seed)
541
 
542
+ sample = pipeline(
543
+ prompt_textbox,
544
+ negative_prompt = negative_prompt_textbox,
545
+ num_inference_steps = 25,
546
+ guidance_scale = 8.,
547
+ width = width_slider,
548
+ height = height_slider,
549
+ video_length = 16,
550
+ generator = generator,
551
+ ).videos
552
 
553
+ save_sample_path = os.path.join(self.savedir, f"sample.mp4")
554
+ save_videos_grid(sample, save_sample_path)
555
 
556
+ json_config = {
557
+ "prompt": prompt_textbox,
558
+ "n_prompt": negative_prompt_textbox,
559
+ "width": width_slider,
560
+ "height": height_slider,
561
+ "seed": seed,
562
+ "base_model": base_model_dropdown,
563
+ "motion_module": motion_module_dropdown,
564
+ }
565
+ return gr.Video.update(value=save_sample_path), gr.Json.update(value=json_config)
566
 
567
 
568
+ controller = AnimateController()
569
 
570
 
571
+ def ui():
572
+ with gr.Blocks(css=css) as demo:
573
+ gr.Markdown(
574
+ """
575
+ # AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning
576
+ Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
577
+ [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
578
+ """
579
+ )
580
+ gr.Markdown(
581
+ """
582
+ ### Quick Start
583
+ 1. Select desired `Base DreamBooth Model`.
584
+ 2. Select `Motion Module` from `mm_sd_v14.ckpt` and `mm_sd_v15.ckpt`. We recommend trying both of them for the best results.
585
+ 3. Provide `Prompt` and `Negative Prompt` for each model. You are encouraged to refer to each model's webpage on CivitAI to learn how to write prompts for them. Below are the DreamBooth models in this demo. Click to visit their homepage.
586
+ - [`toonyou_beta3.safetensors`](https://civitai.com/models/30240?modelVersionId=78775)
587
+ - [`lyriel_v16.safetensors`](https://civitai.com/models/22922/lyriel)
588
+ - [`rcnzCartoon3d_v10.safetensors`](https://civitai.com/models/66347?modelVersionId=71009)
589
+ - [`majicmixRealistic_v5Preview.safetensors`](https://civitai.com/models/43331?modelVersionId=79068)
590
+ - [`realisticVisionV20_v20.safetensors`](https://civitai.com/models/4201?modelVersionId=29460)
591
+ 4. Click `Generate`, wait for ~1 min, and enjoy.
592
+ """
593
+ )
594
+ with gr.Row():
595
+ with gr.Column():
596
+ base_model_dropdown = gr.Dropdown( label="Base DreamBooth Model", choices=controller.base_model_list, value=controller.base_model_list[0], interactive=True )
597
+ motion_module_dropdown = gr.Dropdown( label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True )
598
 
599
+ base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
600
+ motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
601
 
602
+ prompt_textbox = gr.Textbox( label="Prompt", lines=3 )
603
+ negative_prompt_textbox = gr.Textbox( label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo")
604
 
605
+ with gr.Accordion("Advance", open=False):
606
+ with gr.Row():
607
+ width_slider = gr.Slider( label="Width", value=512, minimum=256, maximum=1024, step=64 )
608
+ height_slider = gr.Slider( label="Height", value=512, minimum=256, maximum=1024, step=64 )
609
+ with gr.Row():
610
+ seed_textbox = gr.Textbox( label="Seed", value=-1)
611
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
612
+ seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e16)), inputs=[], outputs=[seed_textbox])
613
 
614
+ generate_button = gr.Button( value="Generate", variant='primary' )
615
 
616
+ with gr.Column():
617
+ result_video = gr.Video( label="Generated Animation", interactive=False )
618
+ json_config = gr.Json( label="Config", value=None )
619
 
620
+ inputs = [base_model_dropdown, motion_module_dropdown, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox]
621
+ outputs = [result_video, json_config]
622
 
623
+ generate_button.click( fn=controller.animate, inputs=inputs, outputs=outputs )
624
 
625
+ gr.Examples( fn=controller.animate, examples=examples, inputs=inputs, outputs=outputs, cache_examples=True )
626
 
627
+ return demo
628
 
629
 
630
+ if __name__ == "__main__":
631
+ demo = ui()
632
+ demo.queue(max_size=20)
633
+ demo.launch()