import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import uuid
import moviepy.editor as mp
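# NOTE: `moviepy.editor` is the MoviePy 1.x import path; MoviePy 2.x removed it
# in favor of importing from `moviepy` directly, so this script assumes MoviePy 1.x.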
# Allow float32 matmuls to use lower-precision internals for speed ("medium" permits bfloat16)
torch.set_float32_matmul_precision("medium")
# Set the device to use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the pre-trained image segmentation model
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(device)
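# from_pretrained() already returns the model in eval mode; calling eval() here is
# just a defensive guard so dropout/batch-norm never run in training mode.
birefnet.eval()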
# Define the image transformation pipeline
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
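# 1024x1024 with ImageNet mean/std mirrors the preprocessing used in BiRefNet's
# reference demo; other input sizes may degrade mask quality.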
def process(image, color):
image_size = image.size
input_images = transform_image(image).unsqueeze(0).to(device)
    # Run BiRefNet; it returns a list of side outputs, and the last entry is the final mask
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
    # Convert the prediction to a grayscale PIL mask and restore the original resolution
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
    # Convert the "#RRGGBB" hex string from the ColorPicker to an RGB tuple
color_rgb = tuple(int(color[i:i+2], 16) for i in (1, 3, 5))
background = Image.new("RGBA", image_size, color_rgb + (255,))
# Composite the image onto the background using the mask
image = Image.composite(image.convert("RGBA"), background, mask)
return image
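# Example of standalone usage ("frame.jpg" is a placeholder path):
#   img = Image.open("frame.jpg").convert("RGB")
#   out = process(img, "#00FF00")   # RGBA PIL image composited onto a green background
#   out.save("frame_green.png")     # PNG preserves the RGBA mode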
def change_background(vid, color="#00FF00", fps=0):
try:
# Load the video using moviepy
video = mp.VideoFileClip(vid)
        # Fall back to the source video's FPS when the slider is left at 0
if fps == 0:
fps = video.fps
        # Keep a handle on the original audio track (None if the video is silent)
audio = video.audio
# Extract frames at the specified FPS
frames = video.iter_frames(fps=fps)
processed_frames = []
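        # Show the streaming preview and hide the final video slot while frames render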
yield gr.update(visible=True), gr.update(visible=False)
# Process each frame for background removal
for frame in frames:
pil_image = Image.fromarray(frame)
processed_image = process(pil_image, color)
            # Flatten to RGB before encoding: the mp4/libx264 pipeline has no alpha channel
            processed_frames.append(np.array(processed_image.convert("RGB")))
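            # Stream the current frame to the preview pane; the final video slot stays empty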
yield processed_image, None
# Create a new video from the processed frames
processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
        # Re-attach the original audio track, if the source had one
        if audio is not None:
            processed_video = processed_video.set_audio(audio)
# Save the processed video to a temporary file
temp_dir = "temp"
os.makedirs(temp_dir, exist_ok=True)
unique_filename = str(uuid.uuid4()) + ".mp4"
temp_filepath = os.path.join(temp_dir, unique_filename)
        # "aac" keeps the audio widely playable in .mp4 containers (MoviePy's default is libmp3lame)
        processed_video.write_videofile(temp_filepath, codec="libx264", audio_codec="aac")
yield gr.update(visible=False), gr.update(visible=True)
        yield None, temp_filepath  # Clear the preview and deliver the final video path
    except Exception as e:
        print(f"Error: {e}")
        yield gr.update(visible=False), gr.update(visible=True)
        # Raise a Gradio error so the message shows in the UI; yielding a plain string
        # to a gr.Video output would fail, since it expects a file path
        raise gr.Error(f"Error processing video: {e}")
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
gr.Markdown("# Video Background Remover & Changer\n### You can replace the video background with any solid color.")
with gr.Row():
in_video = gr.Video(label="Input Video", interactive=True)
stream_image = gr.Image(label="Streaming Output", visible=False)
out_video = gr.Video(label="Final Output Video")
submit_button = gr.Button("Change Background", interactive=True)
with gr.Row():
fps_slider = gr.Slider(
minimum=0,
maximum=60,
step=1,
value=0,
label="Output FPS (0 will inherit the original fps value)",
interactive=True
)
color_picker = gr.ColorPicker(label="Background Color", value="#00FF00", interactive=True)
submit_button.click(
        change_background,
inputs=[in_video, color_picker, fps_slider],
outputs=[stream_image, out_video],
)
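    # Because change_background is a generator, every yield streams an update
    # to [stream_image, out_video] while the video is still being processed.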
if __name__ == "__main__":
demo.launch(show_error=True)