# FoleyCrafter / app.py — revision 7c9dc5d (11.6 kB)
# (Remaining page-header text kept for reference: unknown · cuda · raw · history blame)
import torch
import torchvision
import os
import os.path as osp
import spaces
import random
from argparse import ArgumentParser
from datetime import datetime
import gradio as gr
from foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy
from foleycrafter.pipelines.auffusion_pipeline import denormalize_spectrogram
from foleycrafter.pipelines.auffusion_pipeline import Generator
from foleycrafter.models.time_detector.model import VideoOnsetNet
from foleycrafter.models.specvqgan.onset_baseline.utils import torch_utils
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from huggingface_hub import snapshot_download
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
import soundfile as sf
from moviepy.editor import AudioFileClip, VideoFileClip
# Directory Gradio uses for uploaded/temporary files.
# NOTE(review): this is set after `import gradio`; it relies on Gradio reading
# GRADIO_TEMP_DIR lazily at use time — confirm for the pinned Gradio version.
os.environ['GRADIO_TEMP_DIR'] = './tmp'
sample_idx = 0  # not referenced elsewhere in this file

# Sampler names shown in the "Sampling method" dropdown -> diffusers scheduler classes.
scheduler_dict = {
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
}

# Compact square button style, applied via elem_classes="toolbutton" (the seed dice button).
# The property was previously misspelled "margin-buttom" (with an invalid 4-value form),
# which browsers silently ignore.
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

parser = ArgumentParser()
parser.add_argument("--config", type=str, default="example/config/base.yaml",
                    help="config file path (not read anywhere in this script)")
parser.add_argument("--server-name", type=str, default="0.0.0.0",
                    help="host/interface passed to demo.launch(server_name=...)")
parser.add_argument("--port", type=int, default=7860,
                    help="port passed to demo.launch(server_port=...)")
parser.add_argument("--share", action="store_true",
                    help="create a public Gradio share link")
parser.add_argument("--save-path", default="samples",
                    help="directory (relative to the working directory) for generated samples")
args = parser.parse_args()

# Default value of the "Negative prompt" textbox (empty).
N_PROMPT = ""
class FoleyController:
    """Own the FoleyCrafter models and run video-to-audio generation for the Gradio demo.

    All models are loaded once during construction (``load_model``); ``foley`` is the
    per-request entry point wired to the "Generate" button.
    """

    def __init__(self):
        # Output layout: <cwd>/<--save-path>/Gradio-<launch timestamp>/sample/...
        self.basedir = os.getcwd()
        self.model_dir = os.path.join(self.basedir, "models")
        self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)
        self.device = "cuda"
        self.pipeline = None
        self.loaded = False
        self.load_model()

    def load_model(self):
        """Download (if needed) and load every model, then move them to ``self.device``.

        Sets ``vocoder``, ``time_detector``, ``pipeline``, ``image_processor`` and
        ``image_encoder`` on ``self``, sets ``self.loaded = True`` and returns ``"Load"``.
        """
        gr.Info("Start Load Models...")
        print("Start Load Models...")

        # Download checkpoints unless a local directory named like the repo id already exists.
        pretrained_model_name_or_path = 'auffusion/auffusion-full-no-adapter'
        if not os.path.isdir(pretrained_model_name_or_path):
            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, local_dir='models/auffusion')
        fc_ckpt = 'ymzhang319/FoleyCrafter'
        if not os.path.isdir(fc_ckpt):
            fc_ckpt = snapshot_download(fc_ckpt, local_dir='models/')

        # The FoleyCrafter snapshot lands in models/, which is where these files are read from.
        temporal_ckpt_path = osp.join(self.model_dir, 'temporal_adapter.ckpt')
        time_detector_ckpt = osp.join(self.model_dir, 'timestamp_detector.pth.tar')

        # Vocoder: converts the generated (denormalized) spectrogram into a waveform in `foley`.
        vocoder_config_path = osp.join(self.model_dir, "auffusion")
        self.vocoder = Generator.from_pretrained(vocoder_config_path, subfolder="vocoder")

        # Onset (timestamp) detector used to build the temporal ControlNet condition.
        time_detector = VideoOnsetNet(False)
        self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True)

        self.pipeline = build_foleycrafter()

        # Temporal adapter weights are loaded into the pipeline's ControlNet.
        # map_location='cpu' keeps loading independent of the device the checkpoint was
        # saved from; everything is moved to self.device below.
        # NOTE: torch.load unpickles its input — only point it at trusted checkpoints.
        ckpt = torch.load(temporal_ckpt_path, map_location='cpu')
        if 'state_dict' in ckpt:
            ckpt = ckpt['state_dict']
        # Strip a leading 'module.' (presumably from a DataParallel-wrapped training run).
        load_gligen_ckpt = {
            (key[len('module.'):] if key.startswith('module.') else key): value
            for key, value in ckpt.items()
        }
        m, u = self.pipeline.controlnet.load_state_dict(load_gligen_ckpt, strict=False)
        print(f"### Control Net missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

        # Semantic (IP-adapter) conditioning from CLIP image embeddings.
        self.image_processor = CLIPImageProcessor()
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='models/image_encoder')
        self.pipeline.load_ip_adapter(fc_ckpt, subfolder='semantic', weight_name='semantic_adapter.bin', image_encoder_folder=None)

        # Move to GPU. The detector is only used for inference, so put it in eval mode:
        # in train mode any BatchNorm/Dropout layers would use batch statistics (and
        # update running stats) on every request.
        self.time_detector = self.time_detector.to(self.device).eval()
        self.pipeline = self.pipeline.to(self.device)
        self.vocoder = self.vocoder.to(self.device)
        self.image_encoder = self.image_encoder.to(self.device)

        gr.Info("Load Finish!")
        print("Load Finish!")
        self.loaded = True
        return "Load"

    @spaces.GPU
    def foley(
        self,
        input_video,
        prompt_textbox,
        negative_prompt_textbox,
        ip_adapter_scale,
        temporal_scale,
        sampler_dropdown,
        sample_step_slider,
        cfg_scale_slider,
        seed_textbox,
    ):
        """Generate a soundtrack for ``input_video`` and return the path of the muxed .mp4.

        Args:
            input_video: path of the uploaded video file.
            prompt_textbox: text prompt for the pipeline.
            negative_prompt_textbox: negative text prompt for the pipeline.
            ip_adapter_scale: weight of the visual (IP-adapter) conditioning.
            temporal_scale: ControlNet conditioning scale for the onset-timing condition.
            sampler_dropdown: key of ``scheduler_dict`` selecting the diffusion scheduler.
            sample_step_slider: number of inference steps.
            cfg_scale_slider: classifier-free guidance scale.
            seed_textbox: integer seed as text; empty/None leaves generation unseeded.

        Returns:
            Path of the generated .mp4: the input video trimmed to at most 10 s, with the
            generated audio as its soundtrack.

        Raises:
            gr.Error: if the seed is not an integer or the sampler name is unknown.
        """
        import uuid  # local: only used to give each request its own output file names

        generator = torch.Generator()
        seed = "" if seed_textbox is None else str(seed_textbox).strip()
        if seed:
            try:
                seed_value = int(seed)
            except ValueError:
                raise gr.Error(f"Seed must be an integer, got {seed_textbox!r}") from None
            torch.manual_seed(seed_value)
            generator.manual_seed(seed_value)

        # Honour the "Sampling method" dropdown (previously ignored).
        scheduler_cls = scheduler_dict.get(sampler_dropdown)
        if scheduler_cls is None:
            raise gr.Error(f"Unknown sampling method: {sampler_dropdown!r}")
        if not isinstance(self.pipeline.scheduler, scheduler_cls):
            # Rebuild from the current scheduler's config so its settings carry over.
            self.pipeline.scheduler = scheduler_cls.from_config(self.pipeline.scheduler.config)

        max_frame_nums = 15
        frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums)
        # At most 10 s of audio is generated (1024 spectrogram columns, 160000 samples @ 16 kHz).
        duration = min(duration, 10)

        guidance_scale = float(cfg_scale_slider)
        with torch.no_grad():  # pure inference: don't build autograd graphs
            onset_probs = self._detect_onsets(frames)
            time_condition = self._onset_condition(onset_probs, duration, max_frame_nums)
            image_embeddings = self._visual_embeddings(frames)
            if guidance_scale <= 1:
                # diffusers only applies classifier-free guidance when guidance_scale > 1 and
                # then expects the [negative, positive] pair; without CFG pass only the
                # positive embedding. NOTE(review): confirm against the FoleyCrafter pipeline.
                image_embeddings = image_embeddings[:, 1:]

            self.pipeline.set_ip_adapter_scale(ip_adapter_scale)
            sample = self.pipeline(
                prompt=prompt_textbox,
                negative_prompt=negative_prompt_textbox,
                ip_adapter_image_embeds=image_embeddings,
                image=time_condition,
                controlnet_conditioning_scale=float(temporal_scale),
                num_inference_steps=int(sample_step_slider),
                guidance_scale=guidance_scale,
                height=256,
                width=1024,
                output_type="pt",
                generator=generator,
            )
            spectrogram = denormalize_spectrogram(sample.images[0])
            audio = self.vocoder.inference(spectrogram, lengths=160000)[0]

        # Unique per-request names: a fixed name let concurrent requests overwrite each other.
        name = f"output-{uuid.uuid4().hex[:12]}"
        audio_save_path = osp.join(self.savedir_sample, 'audio')
        os.makedirs(audio_save_path, exist_ok=True)
        wav_path = osp.join(audio_save_path, f'{name}.wav')
        sf.write(wav_path, audio[:int(duration * 16000)], 16000)

        save_sample_path = os.path.join(self.savedir_sample, f"{name}.mp4")
        self._mux_audio(wav_path, input_video, duration, save_sample_path)
        return save_sample_path

    def _detect_onsets(self, frames):
        """Run the onset detector on ``frames``; return sigmoid probabilities for batch item 0.

        The returned tensor is indexed by frame position in ``_onset_condition``.
        """
        video_transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.CenterCrop((112, 112)),
            # NOTE(review): ImageNet [0, 1]-scale statistics, but frames are not divided by
            # 255 first — confirm this matches how the detector was trained.
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # (T, H, W, C) -> (T, C, H, W) for the transforms, then (1, C, T, H, W) for the model.
        time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2).to(self.device)
        time_frames = video_transform(time_frames)
        time_frames = {'frames': time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
        preds = torch.sigmoid(self.time_detector(time_frames))
        return preds[0]

    @staticmethod
    def _onset_condition(onset_probs, duration, max_frame_nums):
        """Build the (1, 1, 256, 1024) ControlNet timing map from onset probabilities.

        1024 columns span 10 s, so the first ``int(102.4 * duration)`` columns are mapped
        onto the ``max_frame_nums`` frames: -1 where that frame's probability is < 0.5,
        otherwise +1. Remaining columns are padded with -1.
        """
        columns_per_second = 1024 / 10
        n_active = int(columns_per_second * duration)
        # One device->host transfer instead of one per column.
        is_quiet = (onset_probs < 0.5).tolist()
        condition = [
            -1 if is_quiet[int(i / (columns_per_second * duration) * max_frame_nums)] else 1
            for i in range(n_active)
        ]
        condition += [-1] * (1024 - len(condition))
        # w -> b c h w
        return torch.FloatTensor(condition).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 256, 1)

    def _visual_embeddings(self, frames):
        """Return [zero (negative), mean CLIP image embedding] stacked along dim 1."""
        images = self.image_processor(images=frames, return_tensors="pt").to(self.device)
        image_embeddings = self.image_encoder(**images).image_embeds
        # Average over frames into one embedding, then add the two leading singleton dims.
        image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
        neg_image_embeddings = torch.zeros_like(image_embeddings)
        return torch.cat([neg_image_embeddings, image_embeddings], dim=1)

    @staticmethod
    def _mux_audio(wav_path, video_path, duration, out_path):
        """Write ``video_path`` trimmed to ``duration`` seconds, with ``wav_path`` as audio, to ``out_path``."""
        audio_clip = AudioFileClip(wav_path)
        try:
            video_clip = VideoFileClip(video_path)
            try:
                video_clip.audio = audio_clip.subclip(0, duration)
                video_clip.subclip(0, duration).write_videofile(out_path)
            finally:
                # Release the clip's readers; previously they leaked on every request.
                video_clip.close()
        finally:
            audio_clip.close()
# Single shared controller, created at import time: its constructor calls load_model(),
# so all models are loaded before the UI is built; ui() binds controller.foley to "Generate".
controller = FoleyController()
def ui():
    """Build and return (without launching) the Gradio Blocks UI for the FoleyCrafter demo."""
    with gr.Blocks(css=css) as demo:
        gr.HTML(
            '<h1 style="height: 136px; display: flex; align-items: center; justify-content: space-around;"><span style="height: 100%; width:136px;"><img src="file/foleycrafter.png" alt="logo" style="height: 100%; width:auto; object-fit: contain; margin: 0px 0px; padding: 0px 0px;"></span><strong style="font-size: 40px;">FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds</strong></h1>'
        )
        with gr.Row():
            gr.Markdown(
                "<div align='center'><font size='5'><a href='https://foleycrafter.github.io/'>Project Page</a> &ensp;"  # noqa
                # TODO: placeholder arXiv id — replace with the paper's real abs URL.
                "<a href='https://arxiv.org/abs/xxxx.xxxxx/'>Paper</a> &ensp;"
                "<a href='https://github.com/open-mmlab/foleycrafter'>Code</a> &ensp;"
                "<a href='https://huggingface.co/spaces/ymzhang319/FoleyCrafter'>Demo</a> </font></div>"
            )
        with gr.Column(variant="panel"):
            with gr.Row(equal_height=False):
                # Left column: inputs and generation settings.
                with gr.Column():
                    with gr.Row():
                        init_img = gr.Video(label="Input Video")
                    with gr.Row():
                        prompt_textbox = gr.Textbox(value='', label="Prompt", lines=1)
                    with gr.Row():
                        negative_prompt_textbox = gr.Textbox(value=N_PROMPT, label="Negative prompt", lines=1)
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(
                            label="Sampling method",
                            choices=list(scheduler_dict.keys()),
                            value=list(scheduler_dict.keys())[0],
                        )
                        sample_step_slider = gr.Slider(
                            label="Sampling steps", value=25, minimum=10, maximum=100, step=1
                        )
                    cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
                    ip_adapter_scale = gr.Slider(label="Visual Content Scale", value=1.0, minimum=0, maximum=1)
                    temporal_scale = gr.Slider(label="Temporal Align Scale", value=0., minimum=0., maximum=1.0)
                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=42)
                        seed_button = gr.Button(value="\U0001f3b2", elem_classes="toolbutton")
                    # The click has no inputs, so Gradio calls fn with zero arguments; the
                    # previous `lambda x:` raised TypeError. randint also needs int bounds
                    # (a float bound such as 1e8 is rejected on Python 3.12+).
                    seed_button.click(fn=lambda: random.randint(1, 10**8), outputs=[seed_textbox], queue=False)
                    generate_button = gr.Button(value="Generate", variant="primary")
                # Right column: generated result.
                result_video = gr.Video(label="Generated Audio", interactive=False)
            generate_button.click(
                fn=controller.foley,
                inputs=[
                    init_img,
                    prompt_textbox,
                    negative_prompt_textbox,
                    ip_adapter_scale,
                    temporal_scale,
                    sampler_dropdown,
                    sample_step_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video],
            )
    return demo
if __name__ == "__main__":
    # Build the UI, enable request queuing, then serve it with the CLI options.
    # NOTE(review): the positional 10 is `concurrency_count` in Gradio 3.x but
    # `status_update_rate` in Gradio 4.x — confirm which the pinned Gradio expects.
    demo = ui()
    demo.queue(10)
    launch_kwargs = dict(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
        # The header HTML references the logo file, so it must be servable.
        allowed_paths=["./foleycrafter.png"],
    )
    demo.launch(**launch_kwargs)