|
import os
import sys

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL
import requests
import torch
from PIL import Image

from engine import SegmentAnythingModel, StableDiffusionInpaintingPipeline
from utils import show_anns, create_image_grid

# Non-interactive raster backend: figures are only rendered into images for Gradio.
matplotlib.use('Agg')
|
|
|
|
|
if not torch.cuda.is_available():
    # The real app loads SAM with device "cuda" and seeds a CUDA generator for
    # inpainting, so it cannot run without a GPU. Serve a placeholder with the
    # same layout (input controls disabled) and a pointer to where to run it,
    # then stop the script before any model loading happens.
    with gr.Blocks() as demo:
        gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")
        gr.Markdown("**CUDA is not available.** Please run it on Google Colab. You can find the Colab here: [Colab Link](https://github.com/SanshruthR/Stable-Diffusion-Inpainting_with_SAM)")

        with gr.Tab("Step 1: Segment Image"):
            with gr.Row():
                input_image = gr.Image(label="Input Image", interactive=False)
                mask_output = gr.Plot(label="Available Masks")
            segment_btn = gr.Button("Generate Masks", interactive=False)

        with gr.Tab("Step 2: Inpaint"):
            with gr.Row():
                with gr.Column():
                    mask_index = gr.Slider(minimum=0, maximum=20, step=1,
                                           label="Mask Index (select based on mask numbers from Step 1)", interactive=False)
                    prompt1 = gr.Textbox(label="Prompt 1", placeholder="Enter first inpainting prompt", interactive=False)
                    prompt2 = gr.Textbox(label="Prompt 2", placeholder="Enter second inpainting prompt", interactive=False)
                    prompt3 = gr.Textbox(label="Prompt 3", placeholder="Enter third inpainting prompt", interactive=False)
                    prompt4 = gr.Textbox(label="Prompt 4", placeholder="Enter fourth inpainting prompt", interactive=False)
                inpaint_output = gr.Plot(label="Inpainting Results", interactive=False)
            inpaint_btn = gr.Button("Generate Inpainting", interactive=False)

    demo.launch(share=True, debug=True)
    # sys.exit() rather than the site-injected exit() helper, which does not
    # exist under `python -S` or in some embedded/frozen interpreters.
    sys.exit()
|
|
|
|
|
# Segment Anything ViT-H checkpoint; fetched from the release URL only when it
# is not already present in the working directory.
sam_checkpoint = "sam_vit_h_4b8939.pth"
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"


def _download_if_missing(src_url, dest_path, chunk_size=1024 * 1024, timeout=60):
    """Download ``src_url`` to ``dest_path`` unless ``dest_path`` already exists.

    The response body is streamed to disk in ``chunk_size`` pieces rather than
    held in memory, and written to ``dest_path + ".part"``. That file is renamed
    into place only after the transfer completes, so an interrupted download is
    never picked up as a valid checkpoint on the next start.

    Args:
        src_url: URL to fetch.
        dest_path: Final file path.
        chunk_size: Bytes per streamed chunk.
        timeout: Seconds passed to ``requests.get`` as its connect/read timeout.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    if os.path.exists(dest_path):
        return
    partial_path = dest_path + ".part"
    with requests.get(src_url, stream=True, timeout=timeout) as response:
        # Without this check an HTML error page would be saved as the .pth file.
        response.raise_for_status()
        with open(partial_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                file.write(chunk)
    os.replace(partial_path, dest_path)


_download_if_missing(url, sam_checkpoint)

model_type = "vit_h"
device = "cuda"
sam_model = SegmentAnythingModel(sam_checkpoint, model_type, device)

# Presumably a Hugging Face Hub model id resolved by the pipeline wrapper.
model_dir = "stabilityai/stable-diffusion-2-inpainting"
sd_pipeline = StableDiffusionInpaintingPipeline(model_dir)

# Written by segment_image and read by inpaint_image.
# NOTE(review): module-level state is shared by every session of the
# (share=True) app; per-session gr.State would keep users apart.
current_masks = None
current_image = None
|
|
|
def segment_image(image):
    """Generate SAM masks for ``image`` and plot them over the image.

    The image and its masks are stored in the module-level ``current_image`` /
    ``current_masks`` so that ``inpaint_image`` can reuse them.

    Args:
        image: The value delivered by the ``gr.Image`` input; ``None`` when the
            button is clicked before anything has been uploaded.

    Returns:
        A matplotlib ``Figure`` with the masks drawn over the image, or
        ``None`` (clearing the plot) when no image was supplied.
    """
    global current_masks, current_image

    if image is None:
        # Nothing uploaded yet; leave the previous segmentation untouched.
        return None

    # Generate first, publish second: if mask generation raises, the shared
    # state is not left pairing a new image with masks from an older one.
    masks = sam_model.generate_masks(np.array(image))
    current_image = image
    current_masks = masks

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(sam_model.preprocess_image(image))
    show_anns(masks, ax)
    ax.axis('off')
    # Lay out this figure explicitly instead of pyplot's implicit "current" one.
    fig.tight_layout()
    # Unregister the figure from pyplot: Gradio renders the Figure object
    # itself, and leaving it registered leaks one figure per click.
    plt.close(fig)
    return fig
|
|
|
def inpaint_image(mask_index, prompt1, prompt2, prompt3, prompt4):
    """Inpaint one SAM mask of the last segmented image, once per prompt.

    Uses the ``current_image`` / ``current_masks`` stored by ``segment_image``.

    Args:
        mask_index: Index into the masks from the last segmentation (the
            Step 2 slider value; converted with ``int``).
        prompt1, prompt2, prompt3, prompt4: Text prompts. Empty, blank or
            ``None`` prompts are skipped.

    Returns:
        The result of ``create_image_grid`` built from the original image and
        one inpainted image per remaining prompt. Returns ``None`` if no
        segmentation has been run yet or ``mask_index`` is not a valid index.
    """
    # Snapshot the shared state once so a concurrent segmentation cannot swap
    # the image or the masks out from under this request halfway through.
    masks, source = current_masks, current_image

    if masks is None or source is None:
        return None

    index = int(mask_index)
    # The slider's range is not guaranteed to match the number of masks, so
    # reject indices past either end instead of raising IndexError.
    if not 0 <= index < len(masks):
        return None

    segmentation_mask = masks[index]['segmentation']
    # Scale the mask to 0/255 uint8 before handing it to PIL
    # (presumably a 2-D boolean array, giving a grayscale mask image).
    stable_diffusion_mask = PIL.Image.fromarray((segmentation_mask * 255).astype(np.uint8))

    prompts = [p for p in (prompt1, prompt2, prompt3, prompt4) if p and p.strip()]

    # Built once and reused for every prompt and for the grid.
    base_image = Image.fromarray(np.array(source))

    # A single generator seeded with 42 is shared by all prompts: output is
    # reproducible run to run, and each prompt draws from the same stream in turn.
    generator = torch.Generator(device="cuda").manual_seed(42)

    inpainted_images = [
        sd_pipeline.inpaint(
            prompt=prompt,
            image=base_image,
            mask_image=stable_diffusion_mask,
            guidance_scale=7.5,
            num_inference_steps=50,
            generator=generator,
        )
        for prompt in prompts
    ]

    return create_image_grid(base_image, inpainted_images, prompts, 2, 3)
|
|
|
|
|
def _sync_mask_slider():
    """Fit the Step 2 mask slider to the masks from the latest segmentation.

    The slider is created with a fixed maximum of 20, while the number of
    entries in ``current_masks`` can be larger or smaller. Without this, masks
    past index 20 cannot be selected and indices past the last mask are offered.

    Returns:
        A ``gr.update`` that sets the slider maximum to the last valid mask
        index and resets its value to 0. Returns an empty update (no change)
        when there are no masks yet.
    """
    if not current_masks:
        return gr.update()
    return gr.update(maximum=len(current_masks) - 1, value=0)


with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")

    with gr.Tab("Step 1: Segment Image"):
        with gr.Row():
            input_image = gr.Image(label="Input Image")
            mask_output = gr.Plot(label="Available Masks")
        segment_btn = gr.Button("Generate Masks")
        segment_event = segment_btn.click(fn=segment_image, inputs=[input_image], outputs=[mask_output])

    with gr.Tab("Step 2: Inpaint"):
        with gr.Row():
            with gr.Column():
                mask_index = gr.Slider(minimum=0, maximum=20, step=1,
                                       label="Mask Index (select based on mask numbers from Step 1)")
                prompt1 = gr.Textbox(label="Prompt 1", placeholder="Enter first inpainting prompt")
                prompt2 = gr.Textbox(label="Prompt 2", placeholder="Enter second inpainting prompt")
                prompt3 = gr.Textbox(label="Prompt 3", placeholder="Enter third inpainting prompt")
                prompt4 = gr.Textbox(label="Prompt 4", placeholder="Enter fourth inpainting prompt")
            inpaint_output = gr.Plot(label="Inpainting Results")
        inpaint_btn = gr.Button("Generate Inpainting")
        inpaint_btn.click(fn=inpaint_image,
                          inputs=[mask_index, prompt1, prompt2, prompt3, prompt4],
                          outputs=[inpaint_output])

    # Wired here because the slider only exists once the Step 2 tab is built:
    # after segmentation finishes, resize the slider to the new mask count.
    segment_event.then(fn=_sync_mask_slider, inputs=None, outputs=[mask_index])

if __name__ == "__main__":
    demo.launch(share=True, debug=True)
|
|