# DiffIR2VR / app.py
# Source: Hugging Face Space file view -- commit 1de8821 ("first commit") by jimmycv07, 10.7 kB.
import os
import cv2
import torch
import spaces
import imageio
import numpy as np
import gradio as gr
# Replace torch.jit.script with an identity function so that anything decorated
# with it by the modules imported below is left as a plain Python callable.
# NOTE(review): presumably a workaround for TorchScript problems in the hosting
# environment -- confirm before removing.
torch.jit.script = lambda f: f
import argparse
from utils.batch_inference import (
BSRInferenceLoop, BIDInferenceLoop
)
# import subprocess
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# Preferred torch device: CUDA when available, otherwise CPU.
# NOTE(review): not referenced anywhere else in this file -- DiffBIR_restore
# sets args.device to "cuda" directly.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_example(task):
    """Return the example rows shown in the ``gr.Examples`` table of a tab.

    Args:
        task: ``"dn"`` for the denoising examples or ``"sr"`` for the
            super-resolution examples.

    Returns:
        A new list of single-element lists, each holding the path of one
        bundled example video.

    Raises:
        KeyError: If ``task`` is neither ``"dn"`` nor ``"sr"``.
    """
    # File-name suffix that distinguishes the two example sets.
    suffix_by_task = {"dn": "", "sr": "_sr"}
    suffix = suffix_by_task[task]

    # 'dog-gooses' is deliberately absent: it is disabled in both example sets.
    clip_names = (
        'bus',
        'koala',
        'flamingo',
        'rhino',
        'elephant',
        'sheep',
        'dog-agility',
    )
    return [[f'examples/{clip}{suffix}.mp4'] for clip in clip_names]
def update_prompt(input_video):
    """Return the default prompt for the currently selected video.

    Args:
        input_video: Path of the selected video. May be ``None`` or empty
            (presumably when the video component has been cleared).

    Returns:
        The value of ``set_default_prompt`` for the video's file name, or an
        empty string when no video is selected.
    """
    # Guard first: calling .split()/basename on None would raise.
    if not input_video:
        return ""

    # basename handles the platform's path separators instead of assuming '/'.
    video_name = os.path.basename(input_video)

    # NOTE(review): set_default_prompt is neither defined nor imported in this
    # file, so this call raises NameError as written. The only reference to
    # update_prompt is in commented-out code; supply the helper before
    # re-enabling it.
    return set_default_prompt(video_name)
# Lookup table: video file name -> one-element list holding the directory of
# that clip's pre-extracted frames. The plain clips come first, followed by
# their "_sr" counterparts, in the same order.
video_to_image = {
    f'{clip}{suffix}.mp4': [f'examples_frames/{clip}{suffix}']
    for suffix in ('', '_sr')
    for clip in (
        'bus',
        'koala',
        'dog-gooses',
        'flamingo',
        'rhino',
        'elephant',
        'sheep',
        'dog-agility',
    )
}
# The one entry that breaks the naming pattern: its frame directory is spelled
# with an underscore ("dog_gooses_sr").
# NOTE(review): looks like it could be a typo for "dog-gooses_sr" -- confirm
# against the directory on disk before changing it.
video_to_image['dog-gooses_sr.mp4'] = ['examples_frames/dog_gooses_sr']
def images_to_video(image_list, output_path, fps=10, max_frames=20):
    """Write the leading frames of ``image_list`` to a video file.

    Args:
        image_list: Iterable of frames. Each frame is converted with
            ``np.array`` and cast to ``uint8`` (the callers are expected to
            pass PIL images -- TODO confirm).
        output_path: Destination path handed to ``imageio.get_writer``.
        fps: Frame rate of the written video.
        max_frames: Upper bound on the number of leading frames written.
            Defaults to 20, the limit that used to be hard-coded; ``None``
            writes every frame.
    """
    from itertools import islice

    writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
    try:
        # Convert lazily and only the frames that are actually written,
        # instead of materialising the whole sequence first.
        for img in islice(image_list, max_frames):
            writer.append_data(np.array(img).astype(np.uint8))
    finally:
        # Always release the writer, even if a frame fails to encode.
        writer.close()
@spaces.GPU(duration=120)
def DiffBIR_restore(input_video, prompt, sr_ratio, n_frames, n_steps, guidance_scale, seed, n_prompt, task):
    """Restore a known example video with the DiffBIR inference loops.

    The video itself is not decoded here: its file name is looked up in
    ``video_to_image`` and the matching frame directory is used as input.

    Args:
        input_video: Path of the selected video; may be ``None`` or empty.
        prompt: Positive text prompt (stored as ``args.pos_prompt``).
        sr_ratio: Upscaling factor (stored as ``args.upscale``).
        n_frames: Stored as ``args.n_frames``.
        n_steps: Number of sampling steps (stored as ``args.steps``).
        guidance_scale: Stored as ``args.cfg_scale``.
        seed: Stored as ``args.seed``.
        n_prompt: Negative text prompt (stored as ``args.neg_prompt``).
        task: ``"sr"`` to run ``BSRInferenceLoop`` or ``"dn"`` to run
            ``BIDInferenceLoop``.

    Returns:
        The value returned by the inference loop's ``run()`` (presumably the
        path of the restored video -- TODO confirm), or ``None`` when no video
        is selected or its file name is not a key of ``video_to_image``.

    Raises:
        ValueError: If ``task`` is neither ``"sr"`` nor ``"dn"``.
    """
    # No video selected: nothing to look up.
    if not input_video:
        return None

    video_name = os.path.basename(input_video)
    frame_dirs = video_to_image.get(video_name)
    if frame_dirs is None:
        # Only clips with pre-extracted frames can be processed.
        print(f"[INFO] No pre-extracted frames for {video_name!r}; nothing to restore.")
        return None
    frames_path = frame_dirs[0]

    inference_loops = {"sr": BSRInferenceLoop, "dn": BIDInferenceLoop}
    if task not in inference_loops:
        # Fail with a clear message instead of an unbound local at return time.
        raise ValueError(f"Unsupported task {task!r}; expected 'sr' or 'dn'.")

    print(f"[INFO] input_video: {input_video}")
    print(f"[INFO] Frames path: {frames_path}")

    args = argparse.Namespace(
        task=task,
        upscale=sr_ratio,
        ### sampling parameters
        steps=n_steps,
        better_start=True,
        tiled=False,
        tile_size=512,
        tile_stride=256,
        pos_prompt=prompt,
        neg_prompt=n_prompt,
        cfg_scale=guidance_scale,
        ### input parameters
        input=frames_path,
        n_samples=1,
        batch_size=10,
        final_size=(480, 854),
        config="configs/inference/my_cldm.yaml",
        ### guidance parameters
        guidance=False,
        g_loss="w_mse",
        g_scale=0.0,
        g_start=1001,
        g_stop=-1,
        g_space="latent",
        g_repeat=1,
        ### output parameters
        output=" ",
        ### common parameters
        seed=seed,
        device="cuda",
        n_frames=n_frames,
        ### latent control parameters
        warp_period=[0, 0.1],
        merge_period=[0, 0],
        ToMe_period=[0, 1],
        merge_ratio=[0.6, 0],
    )

    try:
        return inference_loops[task](args).run()
    finally:
        # Release cached GPU memory whether or not inference succeeded.
        torch.cuda.empty_cache()
########
# demo #
########
# HTML banner placed at the top of the page through gr.HTML(intro): the demo
# title, links to the project page and the paper, and a short notice.
# NOTE(review): "font-weight: 1400" is outside the 1-1000 range CSS defines for
# font-weight, so browsers presumably ignore that declaration -- confirm.
intro = """
<div style="text-align:center">
<h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
DiffIR2VR - <small>Zero-Shot Video Restoration</small>
</h1>
<span>[<a target="_blank" href="https://jimmycv07.github.io/DiffIR2VR_web/">Project page</a>] [<a target="_blank" href="https://huggingface.co/papers/2406.06523">arXiv</a>]</span>
<div style="display:flex; justify-content: center;margin-top: 0.5em">Note that this page is a limited demo of DiffIR2VR. For more configurations, please visit our GitHub page. The code will be released soon!</div>
</div>
"""
# Gradio UI: one tab per restoration task. Both tabs pass their widgets to
# DiffBIR_restore; a hidden `task` textbox tells it which inference loop to run.
with gr.Blocks(css="style.css") as demo:
gr.HTML(intro)
# --- Tab 1: super-resolution (task = "sr") ---
with gr.Tab(label="Super-resolution with DiffBIR"):
with gr.Row():
input_video = gr.Video(label="Input Video")
output_video = gr.Video(label="Restored Video", interactive=False)
with gr.Row():
run_button = gr.Button("Restore your video !", visible=True)
with gr.Accordion('Advanced options', open=False):
prompt = gr.Textbox(
label="Prompt",
max_lines=1,
placeholder="describe your video content"
# value="bear, Van Gogh Style"
)
# Forwarded to DiffBIR_restore as sr_ratio.
sr_ratio = gr.Slider(label='SR ratio',
minimum=1,
maximum=16,
value=4,
step=1)
n_frames = gr.Slider(label='Frames',
minimum=1,
maximum=60,
value=10,
step=1)
n_steps = gr.Slider(label='Steps',
minimum=1,
maximum=100,
value=10,
step=1)
guidance_scale = gr.Slider(label='Guidance Scale',
minimum=0.1,
maximum=30.0,
value=4.0,
step=0.1)
# No fixed default value is given for the seed; randomize=True is set instead
# (presumably gradio picks a random initial value -- see the gr.Slider docs).
seed = gr.Slider(label='Seed',
minimum=-1,
maximum=1000,
step=1,
randomize=True)
n_prompt = gr.Textbox(
label='Negative Prompt',
value="low quality, blurry, low-resolution, noisy, unsharp, weird textures"
)
# Hidden field that supplies the constant task name as the last callback input.
task = gr.Textbox(value="sr", visible=False)
# input_video.change(
# fn = update_prompt,
# inputs = [input_video],
# outputs = [prompt],
# queue = False)
# The input order must match the positional parameters of DiffBIR_restore.
# NOTE: DiffBIR_restore returns None for any video whose file name is not a
# key of video_to_image.
run_button.click(fn = DiffBIR_restore,
inputs = [input_video,
prompt,
sr_ratio,
n_frames,
n_steps,
guidance_scale,
seed,
n_prompt,
task
],
outputs = [output_video]
)
# Bundled example clips for this tab.
# NOTE(review): no fn is passed, so the outputs list presumably has no effect
# here -- confirm against the gr.Examples documentation.
gr.Examples(
examples=get_example("sr"),
label='Examples',
inputs=[input_video],
outputs=[output_video],
examples_per_page=7
)
# --- Tab 2: denoising (task = "dn") ---
# The widget names below are rebound to new components; the click handler of
# the first tab was already registered with its own component objects.
with gr.Tab(label="Denoise with DiffBIR"):
with gr.Row():
input_video = gr.Video(label="Input Video")
output_video = gr.Video(label="Restored Video", interactive=False)
with gr.Row():
run_button = gr.Button("Restore your video !", visible=True)
with gr.Accordion('Advanced options', open=False):
prompt = gr.Textbox(
label="Prompt",
max_lines=1,
placeholder="describe your video content"
# value="bear, Van Gogh Style"
)
n_frames = gr.Slider(label='Frames',
minimum=1,
maximum=60,
value=10,
step=1)
n_steps = gr.Slider(label='Steps',
minimum=1,
maximum=100,
value=10,
step=1)
guidance_scale = gr.Slider(label='Guidance Scale',
minimum=0.1,
maximum=30.0,
value=4.0,
step=0.1)
seed = gr.Slider(label='Seed',
minimum=-1,
maximum=1000,
step=1,
randomize=True)
n_prompt = gr.Textbox(
label='Negative Prompt',
value="low quality, blurry, low-resolution, noisy, unsharp, weird textures"
)
# Hidden constants: this tab has no SR slider, so sr_ratio is fixed at 1 and
# the task name is "dn".
task = gr.Textbox(value="dn", visible=False)
sr_ratio = gr.Number(value=1, visible=False)
# input_video.change(
# fn = update_prompt,
# inputs = [input_video],
# outputs = [prompt],
# queue = False)
# Same callback and input order as the super-resolution tab.
run_button.click(fn = DiffBIR_restore,
inputs = [input_video,
prompt,
sr_ratio,
n_frames,
n_steps,
guidance_scale,
seed,
n_prompt,
task
],
outputs = [output_video]
)
gr.Examples(
examples=get_example("dn"),
label='Examples',
inputs=[input_video],
outputs=[output_video],
examples_per_page=7
)
# Turn on gradio's request queue, then start the app. share=True asks gradio
# for a public share link.
demo.queue()
demo.launch(share=True)