import gradio as gr
import torch
from PIL import Image
import numpy as np
from engine import SegmentAnythingModel, StableDiffusionInpaintingPipeline
from utils import show_anns, create_image_grid
import matplotlib.pyplot as plt
import PIL
import requests
import matplotlib
matplotlib.use('Agg')  # Use the non-interactive Agg backend so figures render without a display
# Check for CUDA availability
if not torch.cuda.is_available():
    # If CUDA isn't available, create the Gradio interface but disable all elements and show a warning
    with gr.Blocks() as demo:
        gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")
        # Show warning message for users with a link to the GitHub repository
        gr.Markdown("**CUDA is not available.** Please run it on Google Colab. You can find the Colab here: [Colab Link](https://github.com/SanshruthR/Stable-Diffusion-Inpainting_with_SAM)")
        # Step 1: Segment Image Tab
        with gr.Tab("Step 1: Segment Image"):
            with gr.Row():
                input_image = gr.Image(label="Input Image", interactive=False)
                mask_output = gr.Plot(label="Available Masks")
            segment_btn = gr.Button("Generate Masks", interactive=False)
        # Step 2: Inpainting Tab
        with gr.Tab("Step 2: Inpaint"):
            with gr.Row():
                with gr.Column():
                    mask_index = gr.Slider(minimum=0, maximum=20, step=1,
                                           label="Mask Index (select based on mask numbers from Step 1)",
                                           interactive=False)
                    prompt1 = gr.Textbox(label="Prompt 1", placeholder="Enter first inpainting prompt", interactive=False)
                    prompt2 = gr.Textbox(label="Prompt 2", placeholder="Enter second inpainting prompt", interactive=False)
                    prompt3 = gr.Textbox(label="Prompt 3", placeholder="Enter third inpainting prompt", interactive=False)
                    prompt4 = gr.Textbox(label="Prompt 4", placeholder="Enter fourth inpainting prompt", interactive=False)
                inpaint_output = gr.Plot(label="Inpainting Results")
            inpaint_btn = gr.Button("Generate Inpainting", interactive=False)
    demo.launch(share=True, debug=True)
    exit()  # Exit the program if CUDA is not available
# Download SAM checkpoint
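# Note: the ViT-H checkpoint is large (roughly 2.4 GB), so this download can take a few minutes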
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
response = requests.get(url)
with open("sam_vit_h_4b8939.pth", "wb") as file:
    file.write(response.content)
# Initialize models
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda" # Default device
sam_model = SegmentAnythingModel(sam_checkpoint, model_type, device)
model_dir = "stabilityai/stable-diffusion-2-inpainting"
sd_pipeline = StableDiffusionInpaintingPipeline(model_dir)
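# Both model wrappers (SAM mask generation and SD inpainting) are defined locally in engine.py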
# Global variables to store the current masks and input image between the two steps
current_masks = None
current_image = None
def segment_image(image):
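    """Run SAM on the uploaded image and return a matplotlib figure with the generated masks overlaid."""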
    global current_masks, current_image
    current_image = image
    # Convert to numpy array
    image_array = np.array(image)
    # Generate masks
    current_masks = sam_model.generate_masks(image_array)
    # Create visualization of masks
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    # Display the original image first
    ax.imshow(sam_model.preprocess_image(image))
    # Overlay masks
    show_anns(current_masks, ax)
    ax.axis('off')
    plt.tight_layout()
    return fig
def inpaint_image(mask_index, prompt1, prompt2, prompt3, prompt4):
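    """Inpaint the selected mask region once per non-empty prompt and return a grid comparing the original image with the results."""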
    global current_masks, current_image
    if current_masks is None or current_image is None:
        return None
    # Slider values arrive as floats; cast to int and clamp to the number of available masks
    mask_index = min(int(mask_index), len(current_masks) - 1)
    # Get selected mask
    segmentation_mask = current_masks[mask_index]['segmentation']
    stable_diffusion_mask = PIL.Image.fromarray((segmentation_mask * 255).astype(np.uint8))
    # Generate one inpainted image per non-empty prompt
    prompts = [p for p in [prompt1, prompt2, prompt3, prompt4] if p.strip()]
    generator = torch.Generator(device="cuda").manual_seed(42)  # Fixed seed for consistency
    encoded_images = []
    for prompt in prompts:
        img = sd_pipeline.inpaint(
            prompt=prompt,
            image=Image.fromarray(np.array(current_image)),
            mask_image=stable_diffusion_mask,
            guidance_scale=7.5,  # Balances prompt adherence against creative freedom
            num_inference_steps=50,  # Good balance between quality and speed
            generator=generator
        )
        encoded_images.append(img)
    # Arrange the original image and the inpainted results in a grid
    result_grid = create_image_grid(Image.fromarray(np.array(current_image)),
                                    encoded_images,
                                    prompts,
                                    2, 3)
    return result_grid
# Create Gradio interface with two tabs
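# Tab 1 generates and numbers the SAM masks; Tab 2 inpaints the chosen mask with up to four prompts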
with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")
    with gr.Tab("Step 1: Segment Image"):
        with gr.Row():
            input_image = gr.Image(label="Input Image")
            mask_output = gr.Plot(label="Available Masks")
        segment_btn = gr.Button("Generate Masks")
        segment_btn.click(fn=segment_image, inputs=[input_image], outputs=[mask_output])
    with gr.Tab("Step 2: Inpaint"):
        with gr.Row():
            with gr.Column():
                mask_index = gr.Slider(minimum=0, maximum=20, step=1,
                                       label="Mask Index (select based on mask numbers from Step 1)")
                prompt1 = gr.Textbox(label="Prompt 1", placeholder="Enter first inpainting prompt")
                prompt2 = gr.Textbox(label="Prompt 2", placeholder="Enter second inpainting prompt")
                prompt3 = gr.Textbox(label="Prompt 3", placeholder="Enter third inpainting prompt")
                prompt4 = gr.Textbox(label="Prompt 4", placeholder="Enter fourth inpainting prompt")
            inpaint_output = gr.Plot(label="Inpainting Results")
        inpaint_btn = gr.Button("Generate Inpainting")
        inpaint_btn.click(fn=inpaint_image,
                          inputs=[mask_index, prompt1, prompt2, prompt3, prompt4],
                          outputs=[inpaint_output])
if __name__ == "__main__":
    demo.launch(share=True, debug=True, ssr_mode=False)