# Sanshruth — "Update app.py" (commit 4cb6889, verified; file size 6.26 kB)
# NOTE(review): the lines above were Hugging Face file-viewer chrome
# ("raw / history blame") accidentally captured with the source; kept here
# as a comment so the module parses.
from pathlib import Path

import gradio as gr
import matplotlib
matplotlib.use('Agg')  # select the non-interactive Agg backend before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import PIL
import requests
import torch
from PIL import Image

from engine import SegmentAnythingModel, StableDiffusionInpaintingPipeline
from utils import show_anns, create_image_grid
# Check for CUDA availability
if not torch.cuda.is_available():
# If CUDA isn't available, create the Gradio interface but disable all elements and show a warning
with gr.Blocks() as demo:
gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")
# Show warning message for users with a link to the GitHub repository
gr.Markdown("**CUDA is not available.** Please run it on Google Colab. You can find the Colab here: [Colab Link](https://github.com/SanshruthR/Stable-Diffusion-Inpainting_with_SAM)")
# Step 1: Segment Image Tab
with gr.Tab("Step 1: Segment Image"):
with gr.Row():
input_image = gr.Image(label="Input Image", interactive=False)
mask_output = gr.Plot(label="Available Masks")
segment_btn = gr.Button("Generate Masks", interactive=False)
# Step 2: Inpainting Tab
with gr.Tab("Step 2: Inpaint"):
with gr.Row():
with gr.Column():
mask_index = gr.Slider(minimum=0, maximum=20, step=1,
label="Mask Index (select based on mask numbers from Step 1)", interactive=False)
prompt1 = gr.Textbox(label="Prompt 1", placeholder="Enter first inpainting prompt", interactive=False)
prompt2 = gr.Textbox(label="Prompt 2", placeholder="Enter second inpainting prompt", interactive=False)
prompt3 = gr.Textbox(label="Prompt 3", placeholder="Enter third inpainting prompt", interactive=False)
prompt4 = gr.Textbox(label="Prompt 4", placeholder="Enter fourth inpainting prompt", interactive=False)
inpaint_output = gr.Plot(label="Inpainting Results", interactive=False)
inpaint_btn = gr.Button("Generate Inpainting", interactive=False)
demo.launch(share=True, debug=True)
exit() # Exit the program if CUDA is not available
# Download SAM checkpoint
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
response = requests.get(url)
with open("sam_vit_h_4b8939.pth", "wb") as file:
file.write(response.content)
# Initialize models
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda" # Default device
sam_model = SegmentAnythingModel(sam_checkpoint, model_type, device)
model_dir = "stabilityai/stable-diffusion-2-inpainting"
sd_pipeline = StableDiffusionInpaintingPipeline(model_dir)
# Global variable to store masks
current_masks = None
current_image = None
def segment_image(image):
global current_masks, current_image
current_image = image
# Convert to numpy array
image_array = np.array(image)
# Generate masks
current_masks = sam_model.generate_masks(image_array)
# Create visualization of masks
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1)
# Display the original image first
ax.imshow(sam_model.preprocess_image(image))
# Overlay masks
show_anns(current_masks, ax)
ax.axis('off')
plt.tight_layout()
return fig
def inpaint_image(mask_index, prompt1, prompt2, prompt3, prompt4):
global current_masks, current_image
if current_masks is None or current_image is None:
return None
# Get selected mask
segmentation_mask = current_masks[mask_index]['segmentation']
stable_diffusion_mask = PIL.Image.fromarray((segmentation_mask * 255).astype(np.uint8))
# Generate inpainted images
prompts = [p for p in [prompt1, prompt2, prompt3, prompt4] if p.strip()]
generator = torch.Generator(device="cuda").manual_seed(42) # Fixed seed for consistency
encoded_images = []
for prompt in prompts:
img = sd_pipeline.inpaint(
prompt=prompt,
image=Image.fromarray(np.array(current_image)),
mask_image=stable_diffusion_mask,
guidance_scale=7.5, # Lower guidance scale for more creative results
num_inference_steps=50, # Good balance between quality and speed
generator=generator
)
encoded_images.append(img)
# Create result grid
result_grid = create_image_grid(Image.fromarray(np.array(current_image)),
encoded_images,
prompts,
2, 3)
return result_grid
# Create Gradio interface with two tabs
with gr.Blocks() as demo:
gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")
with gr.Tab("Step 1: Segment Image"):
with gr.Row():
input_image = gr.Image(label="Input Image")
mask_output = gr.Plot(label="Available Masks")
segment_btn = gr.Button("Generate Masks")
segment_btn.click(fn=segment_image, inputs=[input_image], outputs=[mask_output])
with gr.Tab("Step 2: Inpaint"):
with gr.Row():
with gr.Column():
mask_index = gr.Slider(minimum=0, maximum=20, step=1,
label="Mask Index (select based on mask numbers from Step 1)")
prompt1 = gr.Textbox(label="Prompt 1", placeholder="Enter first inpainting prompt")
prompt2 = gr.Textbox(label="Prompt 2", placeholder="Enter second inpainting prompt")
prompt3 = gr.Textbox(label="Prompt 3", placeholder="Enter third inpainting prompt")
prompt4 = gr.Textbox(label="Prompt 4", placeholder="Enter fourth inpainting prompt")
inpaint_output = gr.Plot(label="Inpainting Results")
inpaint_btn = gr.Button("Generate Inpainting")
inpaint_btn.click(fn=inpaint_image,
inputs=[mask_index, prompt1, prompt2, prompt3, prompt4],
outputs=[inpaint_output])
if __name__ == "__main__":
demo.launch(share=True, debug=True)