Spaces:
Runtime error
Runtime error
import spaces | |
import argparse | |
import torch | |
import tempfile | |
import os | |
import cv2 | |
import numpy as np | |
import gradio as gr | |
import torchvision.transforms.functional as F | |
import matplotlib.pyplot as plt | |
import matplotlib as mpl | |
from omegaconf import OmegaConf | |
from mpl_toolkits.mplot3d.art3d import Poly3DCollection | |
from inference_cameractrl import get_relative_pose, ray_condition, get_pipeline | |
from cameractrl.utils.util import save_videos_grid | |
# Select matplotlib's non-interactive Agg backend; figures are only rendered
# to images for the Gradio plot component.
mpl.use('agg')
# Restrict OpenCV to a single worker thread.
cv2.setNumThreads(1)
#### Description ####
# HTML/Markdown snippets rendered at the top and bottom of the Gradio page.
title = r"""<h1 align="center">CameraCtrl: Enabling Camera Control for Video Diffusion Models</h1>"""
subtitle = r"""<h2 align="center">CameraCtrl Image2Video with <a href='https://arxiv.org/abs/2311.15127' target='_blank'> <b>Stable Video Diffusion (SVD)</b> </a> <a href='https://huggingface.co/stabilityai/stable-video-diffusion-img2vid' target='_blank'> <b> model </b> </a> </h2>"""
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/hehao13/CameraCtrl' target='_blank'><b>CameraCtrl: Enabling Camera Control for Video Diffusion Models</b></a>.<br>
CameraCtrl is capable of precisely controlling the camera trajectory during the video generation process.<br>
Note that, with SVD, CameraCtrl only supports Image2Video for now.<br>
"""
# Blank lines around each `---` matter: a `---` placed directly under a text
# line is parsed by Markdown as a setext heading underline, not a rule.
# NOTE(review): the contact address below reads "[email protected]", which
# looks like e-mail-obfuscation residue from a scraped web page; restore the
# real address.
closing_words = r"""
---

If you are interested in this demo or CameraCtrl is helpful for you, please give the <a href='https://github.com/hehao13/CameraCtrl' target='_blank'> CameraCtrl</a> GitHub repo a ⭐!
[![GitHub Stars](https://img.shields.io/github/stars/hehao13/CameraCtrl)](https://github.com/hehao13/CameraCtrl)

---

📝 **Citation**
<br>
If you find our paper or code useful for your research, please consider citing:
```bibtex
@article{he2024cameractrl,
  title={CameraCtrl: Enabling Camera Control for Text-to-Video Generation},
  author={Hao He and Yinghao Xu and Yuwei Guo and Gordon Wetzstein and Bo Dai and Hongsheng Li and Ceyuan Yang},
  journal={arXiv preprint arXiv:2404.02101},
  year={2024}
}
```

📧 **Contact**
<br>
If you have any questions, please feel free to contact me at <b>[email protected]</b>.

**Acknowledgement**
<br>
We thank <a href='https://wzhouxiff.github.io/projects/MotionCtrl/' target='_blank'><b>MotionCtrl</b></a> and <a href='https://huggingface.co/spaces/lllyasviel/IC-Light' target='_blank'><b>IC-Light</b></a> for their gradio codes.<br>
"""
# Labels of the image-resize buttons and of the trajectory-source buttons.
RESIZE_MODES = ['Resize then Center Crop', 'Directly resize']
CAMERA_TRAJECTORY_MODES = ["Provided Camera Trajectories", "Custom Camera Trajectories"]

# Sampling resolution (pixels) and number of generated frames.
height, width, num_frames = 320, 576, 14
# Use the GPU whenever torch can see one.
device = "cpu" if not torch.cuda.is_available() else "cuda"

# Model configuration file, base SVD weights id and CameraCtrl checkpoint path.
config, model_id, ckpt = (
    "configs/train_cameractrl/svd_320_576_cameractrl.yaml",
    "stabilityai/stable-video-diffusion-img2vid",
    "checkpoints/CameraCtrl_svdxt.ckpt",
)
# Fetch the CameraCtrl checkpoint on first start-up.
if not os.path.exists(ckpt):
    import subprocess

    os.makedirs(os.path.dirname(ckpt) or ".", exist_ok=True)
    _ckpt_url = "https://huggingface.co/hehao13/CameraCtrl_SVD_ckpts/resolve/main/CameraCtrl_svd.ckpt?download=true"
    _partial_ckpt = ckpt + ".part"
    # Download into a side file and move it into place only after wget exits
    # successfully. An interrupted download is then resumed (-c) on the next
    # start instead of being mistaken for a complete checkpoint by the
    # exists() check above. The argument list avoids shell globbing of '?'.
    subprocess.run(["wget", "-c", "-O", _partial_ckpt, _ckpt_url], check=True)
    os.replace(_partial_ckpt, ckpt)
# Build the CameraCtrl SVD pipeline from the training config and checkpoint.
model_config = OmegaConf.load(config)
# Config entries forwarded positionally to get_pipeline, in this order.
_PIPELINE_CONFIG_KEYS = ('down_block_types', 'up_block_types',
                         'pose_encoder_kwargs', 'attention_processor_kwargs')
pipeline = get_pipeline(
    model_id,
    "unet",
    *(model_config[key] for key in _PIPELINE_CONFIG_KEYS),
    ckpt,
    True,
    device,
)
# Quick-start examples: [condition image path, pose file path, dropdown name].
_EXAMPLE_IMAGE_DIR = "assets/example_condition_images"
_EXAMPLE_POSE_DIR = "assets/pose_files"
examples = [
    [f"{_EXAMPLE_IMAGE_DIR}/{image_name}", f"{_EXAMPLE_POSE_DIR}/{pose_id}.txt", trajectory_name]
    for image_name, pose_id, trajectory_name in (
        ("A_tiny_finch_on_a_branch_with_spring_flowers_on_background..png",
         "0bf152ef84195293", "Trajectory 1"),
        ("A_beautiful_fluffy_domestic_hen_sitting_on_white_eggs_in_a_brown_nest,_eggs_are_under_the_hen..png",
         "0c9b371cc6225682", "Trajectory 2"),
        ("Rocky_coastline_with_crashing_waves..png",
         "0c11dbe781b1c11c", "Trajectory 3"),
        ("A_lion_standing_on_a_surfboard_in_the_ocean..png",
         "0f47577ab3441480", "Trajectory 4"),
        ("An_exploding_cheese_house..png",
         "0f47577ab3441480", "Trajectory 4"),
        ("Dolphins_leaping_out_of_the_ocean_at_sunset..png",
         "0f68374b76390082", "Trajectory 5"),
        ("Leaves_are_falling_from_trees..png",
         "2c80f9eb0d3b2bb4", "Trajectory 6"),
        ("A_serene_mountain_lake_at_sunrise,_with_mist_hovering_over_the_water..png",
         "2f25826f0d0ef09a", "Trajectory 7"),
        ("Fireworks_display_illuminating_the_night_sky..png",
         "3f79dc32d575bcdc", "Trajectory 8"),
        ("A_car_running_on_Mars..png",
         "4a2d6753676df096", "Trajectory 9"),
    )
]
class Camera(object):
    """Intrinsics and extrinsics parsed from one row of a pose file.

    Layout of ``entry`` as read here: ``entry[1:5]`` are fx, fy, cx, cy;
    ``entry[7:]`` holds 12 numbers reshaped row-major into the 3x4
    world-to-camera matrix. ``entry[0]``, ``entry[5]`` and ``entry[6]`` are
    not used by this class.

    Attributes:
        fx, fy, cx, cy: intrinsics taken verbatim from the entry.
        w2c_mat: 4x4 world-to-camera matrix (3x4 block padded with [0, 0, 0, 1]).
        c2w_mat: matrix inverse of ``w2c_mat``.
    """

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[1:5]
        extrinsic = np.eye(4)
        extrinsic[:3, :] = np.array(entry[7:]).reshape(3, 4)
        self.w2c_mat = extrinsic
        self.c2w_mat = np.linalg.inv(extrinsic)
class CameraPoseVisualizer:
    """Plot camera poses as translucent pyramids on a matplotlib 3-D axis.

    Attributes:
        fig: the figure created with ``plt.figure``.
        ax: its single 3-D subplot, limited to the given x/y/z ranges.
        plotly_data: placeholder for plotly traces; never filled in by this class.
    """

    # Vertex indices of the pyramid faces. The first four are triangles from
    # the apex (vertex 0) to each base edge; the last is the rectangular base
    # (vertices 1-4).
    _PYRAMID_FACES = ((0, 1, 2), (0, 2, 3), (0, 3, 4), (0, 4, 1), (1, 2, 3, 4))

    def __init__(self, xlim, ylim, zlim):
        self.fig = plt.figure(figsize=(18, 7))
        self.ax = self.fig.add_subplot(projection='3d')
        self.plotly_data = None  # plotly data traces (not populated here)
        self.ax.set_aspect("auto")
        self.ax.set_xlim(xlim)
        self.ax.set_ylim(ylim)
        self.ax.set_zlim(zlim)
        self.ax.set_xlabel('x')
        self.ax.set_ylabel('y')
        self.ax.set_zlabel('z')

    def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9 / 16, base_xval=1, zval=3):
        """Draw one camera as a pyramid transformed by ``extrinsic``.

        Args:
            extrinsic: 4x4 matrix applied to the pyramid's homogeneous vertices.
            color_map: a matplotlib colour string, or a value passed to
                ``plt.cm.rainbow`` to pick a colour.
            hw_ratio: height/width ratio of the pyramid base.
            base_xval: half-width of the pyramid base.
            zval: local z coordinate of the base (the apex is at the origin).
        """
        vertex_std = np.array([[0, 0, 0, 1],
                               [base_xval, -base_xval * hw_ratio, zval, 1],
                               [base_xval, base_xval * hw_ratio, zval, 1],
                               [-base_xval, base_xval * hw_ratio, zval, 1],
                               [-base_xval, -base_xval * hw_ratio, zval, 1]])
        # Vertices are rows, so ``v @ extrinsic.T`` applies ``extrinsic`` to each
        # one; the homogeneous coordinate is then dropped.
        vertices = (vertex_std @ extrinsic.T)[:, :-1]
        meshes = [[vertices[i] for i in face] for face in self._PYRAMID_FACES]
        color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
        self.ax.add_collection3d(
            Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))

    def colorbar(self, max_frame_length):
        """Attach a rainbow colour bar spanning frame indexes 0..``max_frame_length``."""
        cmap = mpl.cm.rainbow
        norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
        self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical',
                          label='Frame Indexes')

    def show(self):
        """Title this visualizer's axes and hand off to ``plt.show``."""
        # plt.title would target pyplot's *current* axes, which need not belong
        # to this figure (another figure or the colour-bar axes could be current).
        self.ax.set_title('Camera Trajectory')
        plt.show()
def get_c2w(w2cs, expected_xyz_ranges=1):
    """Convert absolute world-to-camera matrices into normalised relative camera poses.

    The first camera becomes the identity pose, and every later pose is
    ``w2cs[0] @ inv(w2c)``. Camera centres are then rescaled about the first
    camera so that the largest per-axis extent of the trajectory equals
    ``expected_xyz_ranges``. Finally every pose is left-multiplied by a fixed
    axis permutation mapping (x, y, z) -> (x, z, -y).

    Args:
        w2cs: non-empty sequence of 4x4 world-to-camera matrices.
        expected_xyz_ranges: target size of the largest axis extent after
            rescaling (default 1, the value previously hard-coded).

    Returns:
        ``np.ndarray`` of shape ``(len(w2cs), 4, 4)`` and dtype float32.
    """
    target_cam_c2w = np.eye(4)
    abs2rel = target_cam_c2w @ w2cs[0]
    ret_poses = [target_cam_c2w] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
    camera_positions = np.asarray([c2w[:3, 3] for c2w in ret_poses])  # [n_frame, 3]

    max_range = np.max(camera_positions.max(axis=0) - camera_positions.min(axis=0))
    # A static camera (or a single frame) has zero extent; dividing by it
    # would turn every position into NaN, so such trajectories stay unscaled.
    scale_ratio = expected_xyz_ranges / max_range if max_range > 0 else 1.0
    # Scaling displacements from the first camera is equivalent to summing the
    # scaled frame-to-frame steps, but linear instead of quadratic.
    origin = camera_positions[0]
    scaled_camera_positions = origin + (camera_positions - origin) * scale_ratio

    poses = np.stack(ret_poses).astype(np.float64)
    poses[:, :3, 3] = scaled_camera_positions
    poses[:, 3, :] = [0, 0, 0, 1]  # exact homogeneous row (inv() may add round-off)

    transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
    return (transform_matrix @ poses).astype(np.float32)
def visualize_trajectory(trajectory_file):
    """Render the camera trajectory stored in a pose file as a 3-D pyramid plot.

    The first line of the file is skipped. Every other non-blank line is split
    on whitespace, and its fields from index 7 onward are read as a 3x4
    world-to-camera matrix.

    Args:
        trajectory_file: path of the pose text file.

    Returns:
        The matplotlib figure, already closed in pyplot's registry. It can
        still be rendered or saved, but repeated calls do not accumulate open
        figures.

    Raises:
        ValueError: if the file contains no pose lines.
    """
    with open(trajectory_file, 'r', encoding='utf-8') as f:
        pose_lines = [line for line in f.readlines()[1:] if line.strip()]
    if not pose_lines:
        raise ValueError(f"No camera poses found in {trajectory_file}")

    last_row = np.array([[0.0, 0.0, 0.0, 1.0]])
    w2cs = [
        np.concatenate((np.asarray([float(p) for p in line.split()[7:]]).reshape(3, 4), last_row), axis=0)
        for line in pose_lines
    ]
    n_poses = len(w2cs)
    c2ws = get_c2w(w2cs)

    visualizer = CameraPoseVisualizer([-1.2, 1.2], [-1.2, 1.2], [-1.2, 1.2])
    for frame_idx, c2w in enumerate(c2ws):
        visualizer.extrinsic2pyramid(c2w, frame_idx / n_poses, hw_ratio=9 / 16, base_xval=0.02, zval=0.1)
    visualizer.colorbar(n_poses)
    # Drop the figure from pyplot's global registry. The app makes a new one on
    # every click, and unclosed pyplot figures are never freed.
    plt.close(visualizer.fig)
    return visualizer.fig
# Figure shown in the trajectory plot until the user visualises a trajectory.
vis_traj = visualize_trajectory(trajectory_file='assets/pose_files/0bf152ef84195293.txt')
def process_input_image(input_image, resize_mode):
    """Resize (and optionally centre-crop) the uploaded condition image.

    Args:
        input_image: PIL image from the step-1 ``gr.Image(type='pil')``, or
            None when nothing has been uploaded.
        resize_mode: label of the clicked button. ``RESIZE_MODES[0]`` centre-crops
            to exactly ``height`` x ``width``; any other value only resizes.

    Returns:
        Five ``gr.update`` objects: the processed image, then four visibility
        updates that reveal the step-2 widgets.

    Raises:
        gr.Error: if no image has been uploaded.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    expected_hw_ratio = height / width
    inp_w, inp_h = input_image.size
    inp_hw_ratio = inp_h / inp_w
    # Scale so the image covers the target box along its tighter dimension.
    # Integer pixel sizes are passed to F.resize. They are rounded and clamped
    # so float error can never leave the image one pixel smaller than the crop,
    # which would make center_crop pad it.
    if inp_hw_ratio > expected_hw_ratio:
        resized_height = max(height, round(inp_hw_ratio * width))
        resized_width = width
    else:
        resized_height = height
        resized_width = max(width, round(height / inp_hw_ratio))
    resized_image = F.resize(input_image, size=[resized_height, resized_width])

    if resize_mode == RESIZE_MODES[0]:
        return_image = F.center_crop(resized_image, output_size=[height, width])
    else:
        return_image = resized_image
    return gr.update(visible=True, value=return_image, height=height, width=width), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
def update_camera_trajectories(trajectory_mode):
    """Switch between the provided-trajectory and custom-trajectory widget groups.

    Returns nine ``gr.update`` objects in output order:
    - the three provided-trajectory widgets, shown only for ``CAMERA_TRAJECTORY_MODES[0]``;
    - the three custom-trajectory widgets, shown only for ``CAMERA_TRAJECTORY_MODES[1]``;
    - the camera-name box and the visualise/reset buttons, always shown.

    Returns None when ``trajectory_mode`` matches neither mode.
    """
    use_provided = trajectory_mode == CAMERA_TRAJECTORY_MODES[0]
    if not use_provided and trajectory_mode != CAMERA_TRAJECTORY_MODES[1]:
        return None
    visibility = [use_provided] * 3 + [not use_provided] * 3 + [True] * 3
    return tuple(gr.update(visible=flag) for flag in visibility)
def update_camera_args(trajectory_mode, provided_camera_trajectory, customized_trajectory_file):
    """Build the text shown in the "Camera Trajectory Name" box.

    Args:
        trajectory_mode: label of the trajectory-source button (one of
            ``CAMERA_TRAJECTORY_MODES``).
        provided_camera_trajectory: selected dropdown entry, e.g. "Trajectory 1".
        customized_trajectory_file: value of the ``gr.File`` upload, or None.
            This may be an object exposing a ``.name`` path or a plain path
            string, depending on the Gradio version/File type.

    Returns:
        "Provided <name>", "Customized trajectory file <basename>", or " " when
        no custom file has been uploaded.
    """
    if trajectory_mode == CAMERA_TRAJECTORY_MODES[0]:
        return "Provided " + str(provided_camera_trajectory)
    if customized_trajectory_file is None:
        return " "
    file_path = getattr(customized_trajectory_file, 'name', customized_trajectory_file)
    return f"Customized trajectory file {os.path.basename(file_path)}"
def update_camera_args_reset():
    """Clear the camera-trajectory name box back to its initial single-space value."""
    blank_value = " "
    return blank_value
# Pose file behind each entry of the "Provided Trajectories" dropdown.
_PROVIDED_TRAJECTORY_FILES = {
    "Trajectory 1": "assets/pose_files/0bf152ef84195293.txt",
    "Trajectory 2": "assets/pose_files/0c9b371cc6225682.txt",
    "Trajectory 3": "assets/pose_files/0c11dbe781b1c11c.txt",
    "Trajectory 4": "assets/pose_files/0f47577ab3441480.txt",
    "Trajectory 5": "assets/pose_files/0f68374b76390082.txt",
    "Trajectory 6": "assets/pose_files/2c80f9eb0d3b2bb4.txt",
    "Trajectory 7": "assets/pose_files/2f25826f0d0ef09a.txt",
    "Trajectory 8": "assets/pose_files/3f79dc32d575bcdc.txt",
    "Trajectory 9": "assets/pose_files/4a2d6753676df096.txt",
}


def update_trajectory_vis_plot(camera_trajectory_args, provided_camera_trajectory, customized_trajectory_file):
    """Resolve the selected trajectory file, plot it, and reveal the step-3 widgets.

    Args:
        camera_trajectory_args: text of the camera-trajectory name box. A value
            starting with "Provided" selects the dropdown entry; anything else
            selects the uploaded file.
        provided_camera_trajectory: dropdown entry, e.g. "Trajectory 3".
            Unrecognised names fall back to "Trajectory 9".
        customized_trajectory_file: ``gr.File`` value, either an object with a
            ``.name`` path or a plain path string, or None.

    Returns:
        11 values for the outputs in order: a visibility update for the plot,
        the figure itself, eight visibility updates (step-3 title/description,
        steps, min/max guidance, fps, seed, set button), then the resolved
        trajectory file path.

    Raises:
        gr.Error: when a custom trajectory is expected but none was uploaded.
    """
    if camera_trajectory_args.startswith('Provided'):
        trajectory_file_path = _PROVIDED_TRAJECTORY_FILES.get(
            provided_camera_trajectory, _PROVIDED_TRAJECTORY_FILES["Trajectory 9"])
    elif customized_trajectory_file is None:
        raise gr.Error("Please select a provided trajectory or upload a trajectory file first.")
    else:
        trajectory_file_path = getattr(customized_trajectory_file, 'name', customized_trajectory_file)

    vis_traj = visualize_trajectory(trajectory_file_path)
    step3_updates = [gr.update(visible=True) for _ in range(8)]
    return (gr.update(visible=True), vis_traj, *step3_updates, trajectory_file_path)
def update_set_button():
    """Reveal the step-4 widgets: title, description, start button and video output."""
    return tuple(gr.update(visible=True) for _ in range(4))
def update_buttons_for_example(example_image, example_traj_path, provided_traj_name):
    """Put the UI into the "provided trajectory" state after an example is picked.

    Only ``example_image`` is used. The other two parameters exist because
    ``gr.Examples`` passes one argument per input component.

    Returns:
        Exactly 14 values, one per output of the ``gr.Examples`` call:
        - the processed image;
        - 7 "show" updates: step-2 title/description, both trajectory-mode
          buttons, and the provided-trajectory title/description/dropdown;
        - 3 "hide" updates: custom-trajectory title, instructions and file upload;
        - 3 "show" updates: camera-trajectory name box, visualise and reset buttons.
    """
    # The previous version returned 15 values. The extra "show" update shifted
    # every later one, so the custom-trajectory title was shown and the camera
    # name box was hidden.
    image_update = gr.update(visible=True, value=example_image, height=height, width=width)
    provided_group = [gr.update(visible=True) for _ in range(7)]
    customized_group = [gr.update(visible=False) for _ in range(3)]
    trajectory_controls = [gr.update(visible=True) for _ in range(3)]
    return (image_update, *provided_group, *customized_group, *trajectory_controls)
# @torch.inference_mode() | |
# @spaces.GPU(duration=150) | |
# def sample(condition_image, plucker_embedding, height, width, num_frames, num_inference_step, min_guidance_scale, max_guidance_scale, fps_id, generator): | |
# res = pipeline( | |
# image=condition_image, | |
# pose_embedding=plucker_embedding, | |
# height=height, | |
# width=width, | |
# num_frames=num_frames, | |
# num_inference_steps=num_inference_step, | |
# min_guidance_scale=min_guidance_scale, | |
# max_guidance_scale=max_guidance_scale, | |
# fps=fps_id, | |
# do_image_process=True, | |
# generator=generator, | |
# output_type='pt' | |
# ).frames[0].transpose(0, 1).cpu() | |
# | |
# temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name | |
# save_videos_grid(res[None], temporal_video_path, rescale=False) | |
# return temporal_video_path | |
def _load_camera_params(trajectory_file):
    """Parse a camera-trajectory text file into ``Camera`` objects.

    The first line is skipped and blank lines are ignored. Every other line
    must hold whitespace-separated numbers in the layout ``Camera`` expects.
    """
    with open(trajectory_file, 'r', encoding='utf-8') as f:
        pose_lines = f.readlines()[1:]
    return [Camera([float(x) for x in line.split()]) for line in pose_lines if line.strip()]


def _new_temp_video_path():
    """Create an empty, persistent ``.mp4`` temp file and return its path."""
    # NamedTemporaryFile(...).name deletes the file as soon as the unreferenced
    # wrapper is closed, which hands out a free name another process could
    # grab. mkstemp reserves the name instead.
    fd, path = tempfile.mkstemp(suffix='.mp4')
    os.close(fd)
    return path


def sample_video(condition_image, trajectory_file, num_inference_step, min_guidance_scale, max_guidance_scale, fps_id, seed):
    """Generate a camera-controlled video with the CameraCtrl SVD pipeline.

    Args:
        condition_image: processed condition image fed to the pipeline.
        trajectory_file: path of the pose text file chosen in step 2.
        num_inference_step: number of denoising steps; cast to int.
        min_guidance_scale, max_guidance_scale: guidance range forwarded to the pipeline.
        fps_id: forwarded to the pipeline's ``fps`` argument.
        seed: RNG seed; cast to int.

    Returns:
        Path of the written ``.mp4`` file.

    Raises:
        gr.Error: if the trajectory file contains no poses.
    """
    cam_params = _load_camera_params(trajectory_file)
    if not cam_params:
        raise gr.Error("The camera trajectory file does not contain any camera poses.")

    # Adjust the focal lengths when the aspect ratio implied by the first
    # pose's fy/fx differs from the sampling resolution's width/height.
    sample_wh_ratio = width / height
    pose_wh_ratio = cam_params[0].fy / cam_params[0].fx
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / height
    intrinsic = np.asarray([[cam_param.fx * width,
                             cam_param.fy * height,
                             cam_param.cx * width,
                             cam_param.cy * height]
                            for cam_param in cam_params], dtype=np.float32)

    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cam_params, zero_first_frame_scale=True)
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, height, width, device='cpu')  # b f h w 6
    plucker_embedding = plucker_embedding.permute(0, 1, 4, 2, 3).contiguous().to(device=device)

    generator = torch.Generator(device=device)
    generator.manual_seed(int(seed))
    with torch.no_grad():
        sample = pipeline(
            image=condition_image,
            pose_embedding=plucker_embedding,
            height=height,
            width=width,
            num_frames=num_frames,
            # gr.Number can deliver a float such as 25.0; pass an integer step count.
            num_inference_steps=int(num_inference_step),
            min_guidance_scale=min_guidance_scale,
            max_guidance_scale=max_guidance_scale,
            fps=fps_id,
            do_image_process=True,
            generator=generator,
            output_type='pt'
        ).frames[0].transpose(0, 1).cpu()

    temporal_video_path = _new_temp_video_path()
    save_videos_grid(sample[None], temporal_video_path, rescale=False)
    return temporal_video_path
# return sample(condition_image, plucker_embedding, height, width, num_frames, num_inference_step, min_guidance_scale, max_guidance_scale, fps_id, generator) | |
def main(args):
    """Build the four-step CameraCtrl Gradio demo and launch it.

    The steps are:
    1. Upload and resize a condition image.
    2. Pick a provided or custom camera trajectory and visualise it.
    3. Set the sampling hyper-parameters.
    4. Generate the video.

    Most widgets start hidden (``visible=False``) and are revealed by the
    event handlers as the user progresses.

    Args:
        args: dict of keyword arguments forwarded unchanged to ``demo.launch``.
    """
    demo = gr.Blocks().queue()
    with demo:
        gr.Markdown(title)
        gr.Markdown(subtitle)
        gr.Markdown(description)

        with gr.Column():
            # step1: Input condition image
            step1_title = gr.Markdown("---\n## Step 1: Input an Image", show_label=False, visible=True)
            step1_dec = gr.Markdown(
                "\n 1. Upload an Image by `Drag` or Click `Upload Image`; "
                f"\n 2. Click `{RESIZE_MODES[0]}` or `{RESIZE_MODES[1]}` to select the image resize mode. "
                f"\n - `{RESIZE_MODES[0]}`: First resize the input image, then center crop it into the "
                f"resolution of {height} x {width}. "
                f"\n - `{RESIZE_MODES[1]}`: Only resize the input image, and keep the original aspect ratio.",
                show_label=False, visible=True)
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    input_image = gr.Image(type='pil', interactive=True, elem_id='condition_image',
                                           elem_classes='image',
                                           visible=True)
                    with gr.Row():
                        resize_crop_button = gr.Button(RESIZE_MODES[0], visible=True)
                        directly_resize_button = gr.Button(RESIZE_MODES[1], visible=True)
                with gr.Column(scale=2):
                    processed_image = gr.Image(type='pil', interactive=False, elem_id='processed_image',
                                               elem_classes='image', visible=False)

            # step2: Select camera trajectory
            step2_camera_trajectory = gr.Markdown("---\n## Step 2: Select the camera trajectory",
                                                  show_label=False, visible=False)
            step2_camera_trajectory_des = gr.Markdown(
                f"\n - `{CAMERA_TRAJECTORY_MODES[0]}`: Includes 9 camera trajectories extracted from the test set "
                "of the RealEstate10K dataset, each with 25 frames. "
                f"\n - `{CAMERA_TRAJECTORY_MODES[1]}`: You can provide customized camera trajectories in a txt file.",
                show_label=False, visible=False)
            with gr.Row(equal_height=True):
                provide_trajectory_button = gr.Button(CAMERA_TRAJECTORY_MODES[0], visible=False)
                customized_trajectory_button = gr.Button(CAMERA_TRAJECTORY_MODES[1], visible=False)
            with gr.Row():
                with gr.Column():
                    provided_camera_trajectory = gr.Markdown(f"---\n### {CAMERA_TRAJECTORY_MODES[0]}",
                                                             show_label=False, visible=False)
                    provided_camera_trajectory_des = gr.Markdown(
                        "\n 1. Click one of the provided camera trajectories, such as `Trajectory 1`; "
                        "\n 2. Click `Visualize Camera Trajectory` to visualize the camera trajectory; "
                        "\n 3. Click `Reset Camera Trajectory` to reset the camera trajectory. ",
                        show_label=False, visible=False)
                    customized_camera_trajectory = gr.Markdown(f"---\n### {CAMERA_TRAJECTORY_MODES[1]}",
                                                               show_label=False,
                                                               visible=False)
                    customized_run_status = gr.Markdown(
                        "\n 1. Upload the txt file containing the camera trajectory; "
                        "\n 2. Click `Visualize Camera Trajectory` to visualize the camera trajectory; "
                        "\n 3. Click `Reset Camera Trajectory` to reset the camera trajectory. ",
                        show_label=False, visible=False)
                    with gr.Row():
                        provided_trajectories = gr.Dropdown(
                            ["Trajectory 1", "Trajectory 2", "Trajectory 3", "Trajectory 4", "Trajectory 5",
                             "Trajectory 6", "Trajectory 7", "Trajectory 8", "Trajectory 9"],
                            label="Provided Trajectories", interactive=True, visible=False)
                    with gr.Row():
                        customized_camera_trajectory_file = gr.File(
                            label="Upload customized camera trajectory (in .txt format).", visible=False,
                            interactive=True)
                    with gr.Row():
                        camera_args = gr.Textbox(value=" ", label="Camera Trajectory Name", visible=False)
                        # Never shown: carries the resolved trajectory file path from the
                        # visualise handler to the video-sampling handler.
                        camera_trajectory_path = gr.Textbox(value=" ", visible=False)
                    with gr.Row():
                        camera_trajectory_vis = gr.Button(value="Visualize Camera Trajectory", visible=False)
                        camera_trajectory_reset = gr.Button(value="Reset Camera Trajectory", visible=False)
                with gr.Column():
                    vis_camera_trajectory = gr.Plot(vis_traj, label='Camera Trajectory', visible=False)

            # step3: Set inference parameters
            with gr.Row():
                with gr.Column():
                    step3_title = gr.Markdown("---\n## Step 3: Set the inference hyper-parameters", visible=False)
                    step3_des = gr.Markdown(
                        "\n 1. Set the number of inference steps; "
                        "\n 2. Set the seed; "
                        "\n 3. Set the minimum guidance scale and the maximum guidance scale; "
                        "\n 4. Set the fps; "
                        "\n - Please refer to the SVD paper for the meaning of the last three parameters.",
                        visible=False)
            with gr.Row():
                with gr.Column():
                    num_inference_steps = gr.Number(value=25, label='Number Inference Steps', step=1, interactive=True,
                                                    visible=False)
                with gr.Column():
                    seed = gr.Number(value=42, label='Seed', minimum=1, interactive=True, visible=False, step=1)
                with gr.Column():
                    min_guidance_scale = gr.Number(value=1.0, label='Minimum Guidance Scale', minimum=1.0, step=0.5,
                                                   interactive=True, visible=False)
                with gr.Column():
                    max_guidance_scale = gr.Number(value=3.0, label='Maximum Guidance Scale', minimum=1.0, step=0.5,
                                                   interactive=True, visible=False)
                with gr.Column():
                    fps = gr.Number(value=7, label='FPS', minimum=1, step=1, interactive=True, visible=False)
                # Invisible placeholder buttons: layout spacers only, never wired to events.
                with gr.Column():
                    _ = gr.Button("Seed", visible=False)
                with gr.Column():
                    _ = gr.Button("Seed", visible=False)
                with gr.Column():
                    _ = gr.Button("Seed", visible=False)
            with gr.Row():
                with gr.Column():
                    _ = gr.Button("Set", visible=False)
                with gr.Column():
                    set_button = gr.Button("Set", visible=False)
                with gr.Column():
                    _ = gr.Button("Set", visible=False)

            # step 4: Generate video
            with gr.Row():
                with gr.Column():
                    step4_title = gr.Markdown("---\n## Step 4: Generate the video", show_label=False, visible=False)
                    step4_des = gr.Markdown(
                        "\n - Click the `Start generation !` button to generate the video; "
                        "\n - If the content of the generated video is not well aligned with the condition image, "
                        "try increasing the `Minimum Guidance Scale` and `Maximum Guidance Scale`. "
                        "\n - If the generated videos are distorted, try increasing `FPS`.",
                        visible=False)
                    start_button = gr.Button(value="Start generation !", visible=False)
                with gr.Column():
                    generate_video = gr.Video(value=None, label="Generated Video", visible=False)

        # A Button passed as an input delivers its label, which the handlers
        # compare against RESIZE_MODES / CAMERA_TRAJECTORY_MODES.
        resize_crop_button.click(fn=process_input_image, inputs=[input_image, resize_crop_button],
                                 outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des,
                                          provide_trajectory_button, customized_trajectory_button])
        directly_resize_button.click(fn=process_input_image, inputs=[input_image, directly_resize_button],
                                     outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des,
                                              provide_trajectory_button, customized_trajectory_button])
        provide_trajectory_button.click(fn=update_camera_trajectories, inputs=[provide_trajectory_button],
                                        outputs=[provided_camera_trajectory, provided_camera_trajectory_des,
                                                 provided_trajectories,
                                                 customized_camera_trajectory, customized_run_status,
                                                 customized_camera_trajectory_file,
                                                 camera_args, camera_trajectory_vis, camera_trajectory_reset])
        customized_trajectory_button.click(fn=update_camera_trajectories, inputs=[customized_trajectory_button],
                                           outputs=[provided_camera_trajectory, provided_camera_trajectory_des,
                                                    provided_trajectories,
                                                    customized_camera_trajectory, customized_run_status,
                                                    customized_camera_trajectory_file,
                                                    camera_args, camera_trajectory_vis, camera_trajectory_reset])
        provided_trajectories.change(fn=update_camera_args,
                                     inputs=[provide_trajectory_button, provided_trajectories,
                                             customized_camera_trajectory_file],
                                     outputs=[camera_args])
        customized_camera_trajectory_file.change(fn=update_camera_args,
                                                 inputs=[customized_trajectory_button, provided_trajectories,
                                                         customized_camera_trajectory_file],
                                                 outputs=[camera_args])
        camera_trajectory_reset.click(fn=update_camera_args_reset, inputs=None, outputs=[camera_args])
        # vis_camera_trajectory is listed twice: the handler returns a visibility
        # update for it first and the figure second.
        camera_trajectory_vis.click(fn=update_trajectory_vis_plot,
                                    inputs=[camera_args, provided_trajectories, customized_camera_trajectory_file],
                                    outputs=[vis_camera_trajectory, vis_camera_trajectory, step3_title, step3_des,
                                             num_inference_steps, min_guidance_scale, max_guidance_scale, fps,
                                             seed, set_button, camera_trajectory_path])
        set_button.click(fn=update_set_button, inputs=None,
                         outputs=[step4_title, step4_des, start_button, generate_video])
        start_button.click(fn=sample_video, inputs=[processed_image, camera_trajectory_path, num_inference_steps,
                                                    min_guidance_scale, max_guidance_scale, fps, seed],
                           outputs=[generate_video])

        # set example
        gr.Markdown("## Examples")
        gr.Markdown("\n Choose one of the following examples to get a quick start. Selecting an example "
                    "sets the condition image and camera trajectory automatically. "
                    "Then, you can click the `Visualize Camera Trajectory` button to visualize the camera trajectory.")
        gr.Examples(
            fn=update_buttons_for_example,
            run_on_click=True,
            cache_examples=False,
            examples=examples,
            inputs=[input_image, camera_args, provided_trajectories],
            outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des, provide_trajectory_button,
                     customized_trajectory_button,
                     provided_camera_trajectory, provided_camera_trajectory_des, provided_trajectories,
                     customized_camera_trajectory, customized_run_status, customized_camera_trajectory_file,
                     camera_args, camera_trajectory_vis, camera_trajectory_reset]
        )
        with gr.Row():
            gr.Markdown(closing_words)

    demo.launch(**args)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--listen', default='0.0.0.0', help='Server name/address passed to Gradio.')
    # `--broswer` was the original (misspelled) flag; keep it as an alias so
    # existing launch commands keep working.
    parser.add_argument('--browser', '--broswer', dest='browser', action='store_true',
                        help="Forwarded to Gradio's `inbrowser` launch option.")
    parser.add_argument('--share', action='store_true', help="Forwarded to Gradio's `share` launch option.")
    args = parser.parse_args()
    launch_kwargs = {'server_name': args.listen,
                     'inbrowser': args.browser,
                     'share': args.share}
    main(launch_kwargs)