# Copyright 2023 Natural Synthetics Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys | |
sys.path.append("/") | |
import os | |
import argparse | |
import torch | |
from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline | |
from hotshot_xl.pipelines.hotshot_xl_controlnet_pipeline import HotshotXLControlNetPipeline | |
from hotshot_xl.models.unet import UNet3DConditionModel | |
import torchvision.transforms as transforms | |
from einops import rearrange | |
from hotshot_xl.utils import save_as_gif, save_as_mp4, extract_gif_frames_from_midpoint, scale_aspect_fill | |
from torch import autocast | |
from diffusers import ControlNetModel | |
from contextlib import contextmanager | |
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler | |
from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler | |
# Map from the --scheduler CLI name to its diffusers scheduler class.
# 'default' maps to None, which main() interprets as "keep the scheduler
# the pretrained pipeline ships with".
SCHEDULERS = {
    'EulerAncestralDiscreteScheduler': EulerAncestralDiscreteScheduler,
    'EulerDiscreteScheduler': EulerDiscreteScheduler,
    'default': None,
    # add more here
}
def parse_args(argv=None):
    """Parse command-line arguments for Hotshot-XL inference.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse reads ``sys.argv[1:]`` — so existing callers that
            invoke ``parse_args()`` with no arguments behave exactly as before.

    Returns:
        argparse.Namespace holding all inference settings.
    """
    parser = argparse.ArgumentParser(description="Hotshot-XL inference")
    parser.add_argument("--pretrained_path", type=str, default="hotshotco/Hotshot-XL")
    parser.add_argument("--xformers", action="store_true")
    parser.add_argument("--spatial_unet_base", type=str)
    parser.add_argument("--lora", type=str)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--steps", type=int, default=30)
    parser.add_argument("--prompt", type=str,
                        default="a bulldog in the captains chair of a spaceship, hd, high quality")
    parser.add_argument("--negative_prompt", type=str, default="blurry")
    parser.add_argument("--seed", type=int, default=455)
    # Generation resolution; Hotshot-XL's native size is 672x384.
    parser.add_argument("--width", type=int, default=672)
    parser.add_argument("--height", type=int, default=384)
    # SDXL micro-conditioning: target and original sizes.
    parser.add_argument("--target_width", type=int, default=512)
    parser.add_argument("--target_height", type=int, default=512)
    parser.add_argument("--og_width", type=int, default=1920)
    parser.add_argument("--og_height", type=int, default=1080)
    parser.add_argument("--video_length", type=int, default=8)
    parser.add_argument("--video_duration", type=int, default=1000)
    parser.add_argument("--low_vram_mode", action="store_true")
    parser.add_argument('--scheduler', type=str, default='EulerAncestralDiscreteScheduler',
                        help='Name of the scheduler to use')
    # ControlNet options: require a reference --gif to condition on.
    parser.add_argument("--control_type", type=str, default=None, choices=["depth", "canny"])
    parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
    parser.add_argument("--control_guidance_start", type=float, default=0.0)
    parser.add_argument("--control_guidance_end", type=float, default=1.0)
    parser.add_argument("--gif", type=str, default=None)
    # Weight precision; --autocast optionally wraps inference in torch autocast.
    parser.add_argument("--precision", type=str, default='f16', choices=[
        'f16', 'f32', 'bf16'
    ])
    parser.add_argument("--autocast", type=str, default=None, choices=[
        'f16', 'bf16'
    ])
    return parser.parse_args(argv)
to_pil = transforms.ToPILImage() | |
def to_pil_images(video_frames: torch.Tensor, output_type='pil'):
    """Flatten a batch of videos into a flat list of frames.

    `video_frames` is laid out as (batch, channels, frames, w, h). Frames are
    returned batch-major, frame-minor. With output_type='pil' each frame is
    converted to a PIL image; otherwise the raw per-frame tensors are returned.
    """
    # (b, c, f, w, h) -> (b, f, c, w, h): put the frame axis before channels.
    frames_first = video_frames.permute(0, 2, 1, 3, 4)
    images = []
    for video in frames_first:
        for frame in video:
            images.append(to_pil(frame) if output_type == "pil" else frame)
    return images
@contextmanager
def maybe_auto_cast(data_type):
    """Context manager that enables CUDA autocast only when a dtype is given.

    Bug fix: the original was a bare generator function used as
    ``with maybe_auto_cast(...)`` — without the ``@contextmanager`` decorator
    (imported at the top of the file but never applied) a generator object has
    no ``__enter__``/``__exit__``, so every run crashed.

    Args:
        data_type: torch dtype for autocast (e.g. torch.half / torch.bfloat16),
            or None/falsy to run the body without autocast.
    """
    if data_type:
        with autocast("cuda", dtype=data_type):
            yield
    else:
        yield
def main():
    """Run Hotshot-XL text-to-video (optionally ControlNet-guided) inference.

    Loads the pipeline per the CLI flags, generates `video_length` frames and
    writes the result to ``args.output`` as gif/mp4 (or JPEG for one frame).
    """
    args = parse_args()

    # ControlNet needs a reference gif to extract conditioning frames from.
    if args.control_type and not args.gif:
        raise ValueError("Controlnet specified but you didn't specify a gif!")

    if args.gif and not args.control_type:
        print("warning: gif was specified but no control type was specified. gif will be ignored.")

    # Ensure the output directory exists before the (slow) generation runs.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda")

    control_net_model_pretrained_path = None
    if args.control_type:
        control_type_to_model_map = {
            "canny": "diffusers/controlnet-canny-sdxl-1.0",
            "depth": "diffusers/controlnet-depth-sdxl-1.0",
        }
        control_net_model_pretrained_path = control_type_to_model_map[args.control_type]

    # Resolve the requested weight precision (default f16).
    data_type = torch.float32
    if args.precision == 'f16':
        data_type = torch.half
    elif args.precision == 'f32':
        data_type = torch.float32
    elif args.precision == 'bf16':
        data_type = torch.bfloat16

    pipe_line_args = {
        "torch_dtype": data_type,
        "use_safetensors": True
    }

    PipelineClass = HotshotXLPipeline

    if control_net_model_pretrained_path:
        PipelineClass = HotshotXLControlNetPipeline
        pipe_line_args['controlnet'] = \
            ControlNetModel.from_pretrained(control_net_model_pretrained_path, torch_dtype=data_type)

    if args.spatial_unet_base:
        # Load the stock 3D unet only to harvest its temporal layers, then
        # graft them onto a unet built from the user-supplied spatial base.
        unet_3d = UNet3DConditionModel.from_pretrained(
            args.pretrained_path, subfolder="unet", torch_dtype=data_type).to(device)
        unet = UNet3DConditionModel.from_pretrained_spatial(args.spatial_unet_base).to(device, dtype=data_type)

        unet_3d_sd = unet_3d.state_dict()
        temporal_layers = {k: v for k, v in unet_3d_sd.items() if 'temporal' in k}
        # strict=False: only the temporal subset of keys is being loaded.
        unet.load_state_dict(temporal_layers, strict=False)

        pipe_line_args['unet'] = unet

        # Release the donor unet before loading the full pipeline.
        del unet_3d_sd
        del unet_3d
        del temporal_layers

    pipe = PipelineClass.from_pretrained(args.pretrained_path, **pipe_line_args).to(device)

    if args.lora:
        pipe.load_lora_weights(args.lora)

    SchedulerClass = SCHEDULERS[args.scheduler]
    # None means "keep the scheduler the pipeline shipped with".
    if SchedulerClass is not None:
        pipe.scheduler = SchedulerClass.from_config(pipe.scheduler.config)

    if args.xformers:
        pipe.enable_xformers_memory_efficient_attention()

    # Bug fix: compare against None so `--seed 0` still seeds the generator
    # (0 is falsy, so the original silently dropped it).
    generator = torch.Generator().manual_seed(args.seed) if args.seed is not None else None

    autocast_type = None
    if args.autocast == 'f16':
        autocast_type = torch.half
    elif args.autocast == 'bf16':
        autocast_type = torch.bfloat16

    # The ControlNet pipeline does not accept the low_vram_mode kwarg.
    if type(pipe) is HotshotXLControlNetPipeline:
        kwargs = {}
    else:
        kwargs = {
            "low_vram_mode": args.low_vram_mode
        }

    if args.gif and type(pipe) is HotshotXLControlNetPipeline:
        # Sample `video_length` conditioning frames from the reference gif
        # and scale them to the generation resolution.
        kwargs['control_images'] = [
            scale_aspect_fill(img, args.width, args.height).convert("RGB")
            for img in
            extract_gif_frames_from_midpoint(args.gif, fps=args.video_length, target_duration=args.video_duration)
        ]
        kwargs['controlnet_conditioning_scale'] = args.controlnet_conditioning_scale
        kwargs['control_guidance_start'] = args.control_guidance_start
        kwargs['control_guidance_end'] = args.control_guidance_end

    with maybe_auto_cast(autocast_type):
        images = pipe(args.prompt,
                      negative_prompt=args.negative_prompt,
                      width=args.width,
                      height=args.height,
                      original_size=(args.og_width, args.og_height),
                      target_size=(args.target_width, args.target_height),
                      num_inference_steps=args.steps,
                      video_length=args.video_length,
                      generator=generator,
                      output_type="tensor", **kwargs).videos

    images = to_pil_images(images, output_type="pil")

    if args.video_length > 1:
        # Split the total clip duration evenly across frames.
        if args.output.split(".")[-1] == "gif":
            save_as_gif(images, args.output, duration=args.video_duration // args.video_length)
        else:
            save_as_mp4(images, args.output, duration=args.video_duration // args.video_length)
    else:
        images[0].save(args.output, format='JPEG', quality=95)
# Script entry point: run inference only when executed directly.
if __name__ == "__main__":
    main()