# allegro-text2video / single_inference.py
# (File-viewer header captured with this source: uploaded by fffiloni,
#  "Upload 15 files", commit cdcfdd8 (verified), 3.33 kB.)
import torch
import imageio
import os
import argparse
from diffusers.schedulers import EulerAncestralDiscreteScheduler
from transformers import T5EncoderModel, T5Tokenizer
from allegro.pipelines.pipeline_allegro import AllegroPipeline
from allegro.models.vae.vae_allegro import AllegroAutoencoderKL3D
from allegro.models.transformers.transformer_3d_allegro import AllegroTransformer3DModel
def single_inference(args):
    """Generate one text-to-video clip with Allegro and write it to ``args.save_path``.

    Args:
        args: parsed CLI namespace providing ``vae``, ``dit``, ``text_encoder`` and
            ``tokenizer`` (passed to the respective ``from_pretrained`` loaders),
            ``user_prompt``, ``num_sampling_steps``, ``guidance_scale``, ``seed``,
            ``save_path`` and ``enable_cpu_offload``.

    The clip is generated with 88 frames at 1280x720 and saved with
    ``imageio.mimwrite`` at 15 fps, quality 8.
    """
    dtype = torch.bfloat16
    # The VAE performs better in float32, so it is not loaded in bfloat16.
    vae = AllegroAutoencoderKL3D.from_pretrained(args.vae, torch_dtype=torch.float32)
    vae.eval()
    text_encoder = T5EncoderModel.from_pretrained(
        args.text_encoder,
        torch_dtype=dtype
    )
    text_encoder.eval()
    tokenizer = T5Tokenizer.from_pretrained(
        args.tokenizer,
    )
    scheduler = EulerAncestralDiscreteScheduler()
    transformer = AllegroTransformer3DModel.from_pretrained(
        args.dit,
        torch_dtype=dtype
    )
    transformer.eval()
    allegro_pipeline = AllegroPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer
    )
    if args.enable_cpu_offload:
        # Sequential offload moves weights to the GPU on demand, so the modules must
        # still be on the CPU when it is enabled. Moving everything to CUDA first
        # (as before) loads the full models onto the GPU and can OOM on the very
        # low-memory setups this flag exists for.
        allegro_pipeline.enable_sequential_cpu_offload()
        print("cpu offload enabled")
    else:
        # nn.Module.cuda() moves parameters in place, so the pipeline sees the moved
        # modules; the explicit moves mirror the original per-module .cuda() calls.
        vae.cuda()
        transformer.cuda()
        allegro_pipeline = allegro_pipeline.to("cuda:0")
    # "{}" is replaced by the normalized (lower-cased, stripped) user prompt.
    positive_prompt = """
(masterpiece), (best quality), (ultra-detailed), (unwatermarked),
{}
emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo,
sharp focus, high budget, cinemascope, moody, epic, gorgeous
"""
    negative_prompt = """
nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality,
low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.
"""
    user_prompt = positive_prompt.format(args.user_prompt.lower().strip())
    out_video = allegro_pipeline(
        user_prompt,
        negative_prompt=negative_prompt,
        num_frames=88,
        height=720,
        width=1280,
        num_inference_steps=args.num_sampling_steps,
        guidance_scale=args.guidance_scale,
        max_sequence_length=512,
        # Seeded generator for reproducible sampling.
        generator=torch.Generator(device="cuda:0").manual_seed(args.seed)
    ).video[0]
    imageio.mimwrite(args.save_path, out_video, fps=15, quality=8)  # highest quality is 10, lowest is 0
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run a single Allegro text-to-video inference and save the clip."
    )
    parser.add_argument("--user_prompt", type=str, default='',
                        help="Prompt text; lower-cased, stripped and inserted into a fixed template.")
    # The model locations have no usable default: an empty string only fails later,
    # inside from_pretrained, possibly after other models have already been loaded.
    parser.add_argument("--vae", type=str, required=True,
                        help="VAE checkpoint passed to AllegroAutoencoderKL3D.from_pretrained.")
    parser.add_argument("--dit", type=str, required=True,
                        help="Transformer checkpoint passed to AllegroTransformer3DModel.from_pretrained.")
    parser.add_argument("--text_encoder", type=str, required=True,
                        help="Text encoder checkpoint passed to T5EncoderModel.from_pretrained.")
    parser.add_argument("--tokenizer", type=str, required=True,
                        help="Tokenizer passed to T5Tokenizer.from_pretrained.")
    parser.add_argument("--save_path", type=str, default="./output_videos/test_video.mp4",
                        help="Output video path; missing parent directories are created.")
    parser.add_argument("--guidance_scale", type=float, default=7.5,
                        help="Guidance scale passed to the pipeline (default: %(default)s).")
    parser.add_argument("--num_sampling_steps", type=int, default=100,
                        help="Number of inference steps passed to the pipeline (default: %(default)s).")
    parser.add_argument("--seed", type=int, default=42,
                        help="Seed for the sampling generator (default: %(default)s).")
    parser.add_argument("--enable_cpu_offload", action='store_true',
                        help="Enable the pipeline's sequential CPU offload instead of moving models to CUDA.")
    args = parser.parse_args()
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        # exist_ok=True replaces the racy exists()-then-makedirs() pair.
        os.makedirs(save_dir, exist_ok=True)
    single_inference(args)