import os
import uuid

import gradio as gr
import moviepy.editor as mp
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Trade a little float32 matmul precision for speed
torch.set_float32_matmul_precision("medium")

# Use the GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pre-trained BiRefNet image segmentation model
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(device)
birefnet.eval()  # inference only

# Image transformation pipeline: resize to the model's input size,
# then normalize with ImageNet statistics
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def process(image, color):
    """Replace the background of a single frame with a solid color."""
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to(device)

    # Predict the segmentation mask
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Convert the prediction to a PIL mask at the original resolution
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)

    # Convert the "#RRGGBB" hex color to an opaque RGBA background
    color_rgb = tuple(int(color[i : i + 2], 16) for i in (1, 3, 5))
    background = Image.new("RGBA", image_size, color_rgb + (255,))

    # Composite the frame onto the background using the predicted mask
    return Image.composite(image.convert("RGBA"), background, mask)


def fn(vid, color="#00FF00", fps=0):
    try:
        # Load the video with moviepy
        video = mp.VideoFileClip(vid)

        # Inherit the original FPS when the slider is left at 0
        if fps == 0:
            fps = video.fps

        # Keep the original audio track so it can be re-attached later
        audio = video.audio

        # Iterate over frames at the requested FPS
        frames = video.iter_frames(fps=fps)

        processed_frames = []
        # Show the streaming preview, hide the final video while processing
        yield gr.update(visible=True), gr.update(visible=False)

        # Remove the background frame by frame, streaming each result
        for frame in frames:
            pil_image = Image.fromarray(frame)
            processed_image = process(pil_image, color)
            # Drop the (fully opaque) alpha channel so libx264 gets RGB frames
            processed_frames.append(np.array(processed_image.convert("RGB")))
            yield processed_image, None

        # Assemble the processed frames into a new clip
        processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)

        # Re-attach the original audio
        processed_video = processed_video.set_audio(audio)

        # Write the result to a uniquely named temporary file
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)
        unique_filename = str(uuid.uuid4()) + ".mp4"
        temp_filepath = os.path.join(temp_dir, unique_filename)
        processed_video.write_videofile(temp_filepath, codec="libx264")

        # Swap the preview out for the final video
        yield gr.update(visible=False), gr.update(visible=True)
        yield None, temp_filepath
    except Exception as e:
        print(f"Error: {e}")
        yield gr.update(visible=False), gr.update(visible=True)
        # Surface the failure in the UI instead of yielding a bogus file path
        raise gr.Error(f"Error processing video: {e}")


with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        "# Video Background Remover & Changer\n"
        "### Replace the video background with any solid color."
    )
    with gr.Row():
        in_video = gr.Video(label="Input Video", interactive=True)
        stream_image = gr.Image(label="Streaming Output", visible=False)
        out_video = gr.Video(label="Final Output Video")
    submit_button = gr.Button("Change Background", interactive=True)
    with gr.Row():
        fps_slider = gr.Slider(
            minimum=0,
            maximum=60,
            step=1,
            value=0,
            label="Output FPS (0 inherits the original FPS)",
            interactive=True,
        )
        color_picker = gr.ColorPicker(
            label="Background Color", value="#00FF00", interactive=True
        )

    submit_button.click(
        fn,
        inputs=[in_video, color_picker, fps_slider],
        outputs=[stream_image, out_video],
    )

if __name__ == "__main__":
    demo.launch(show_error=True)