"""RemoveBG: a small Gradio app that removes image backgrounds with BiRefNet."""

from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Set up CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_float32_matmul_precision("high")

# Load the model
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
# Inference only: make eval mode explicit so dropout/batch-norm layers (if any)
# behave deterministically, regardless of what from_pretrained already did.
birefnet.to(device).eval()

# Spatial size the image is resized to before being fed to the model.
# NOTE(review): the upstream BiRefNet example resizes to (1024, 1024); 256x256
# is faster but coarser — confirm this trade-off is intended.
INPUT_SIZE = (256, 256)

# Seconds to wait for a remote image before giving up.
DOWNLOAD_TIMEOUT = 30

# Define image transformations (ImageNet mean/std normalization).
transform_image = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def load_img(image_path_or_url):
    """Load an image from a local path or an http(s) URL and return it as RGB.

    Args:
        image_path_or_url: Filesystem path, or a URL starting with
            ``http://`` or ``https://``.

    Returns:
        A fully loaded ``PIL.Image.Image`` in ``"RGB"`` mode.

    Raises:
        requests.HTTPError: If the URL responds with an error status.
        requests.Timeout: If the download exceeds ``DOWNLOAD_TIMEOUT`` seconds.
        PIL.UnidentifiedImageError: If the data is not a readable image.
    """
    # Match the scheme exactly so local paths like "httpdocs/x.png" are not
    # mistaken for URLs.
    if image_path_or_url.startswith(("http://", "https://")):
        response = requests.get(image_path_or_url, timeout=DOWNLOAD_TIMEOUT)
        # Fail clearly on 4xx/5xx instead of trying to decode an error page.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path_or_url
    # The context manager closes the underlying file once convert() has
    # forced the lazy image data to load.
    with Image.open(source) as img:
        return img.convert("RGB")


def process(image):
    """Remove the background of ``image`` using the BiRefNet mask.

    Args:
        image: A ``PIL.Image.Image`` of any mode, or ``None``.

    Returns:
        An ``"RGBA"`` image the size of the input, whose alpha channel is the
        predicted foreground mask. Returns ``None`` when ``image`` is ``None``,
        for example when the form is submitted empty.
    """
    if image is None:
        return None
    # The normalization above expects exactly 3 channels; RGBA, grayscale and
    # palette images would otherwise fail inside Normalize.
    image = image.convert("RGB")
    image_size = image.size

    input_images = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The last output of the model is used as the final prediction.
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # ToPILImage maps the [0, 1] float map to an 8-bit "L" image, which is
    # then scaled back up to the original resolution.
    mask = transforms.ToPILImage()(pred).resize(image_size)

    # Create an RGBA copy of the image and use the mask as its transparency.
    transparent_image = image.convert("RGBA")
    transparent_image.putalpha(mask)
    return transparent_image


def remove_background_gradio(image):
    """Gradio callback: return ``image`` with its background made transparent."""
    return process(image)


# Create the Gradio interface with drag-and-drop and paste functionality
demo = gr.Interface(
    fn=remove_background_gradio,
    inputs=gr.Image(type="pil"),  # the callback receives a PIL image
    outputs=gr.Image(type="pil"),
    title="RemoveBG",
    description="Upload an image to remove its background (drag-and-drop or upload)."
)

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. `gradio app.py` reload
    # mode) does not start a second server.
    demo.launch(share=True)  # Launch the interface and get a shareable link