# Hugging Face Space (status: Running on Zero)
# Standard Libraries | |
import time | |
from io import BytesIO | |
import base64 | |
# Data Handling and Image Processing | |
import numpy as np | |
from PIL import Image | |
# Machine Learning and AI Models | |
import torch | |
from transformers import pipeline | |
from diffusers import AutoPipelineForInpainting | |
from ultralytics import YOLO | |
# Text and Data Manipulation | |
import difflib | |
# UI and Application Framework | |
import gradio as gr | |
import spaces | |
# Constants
# Prefer the GPU when PyTorch can see one; both pipelines below are placed on
# this device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load
# The three models are created once at import time and reused by
# full_pipeline() for every request.
yoloModel = YOLO('yolov8x-seg.pt')
sdxl = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    torch_dtype=torch.float32
).to(DEVICE)
image_captioner = pipeline(
    "image-to-text",
    model="Abdou/vit-swin-base-224-gpt2-image-captioning",
    device=DEVICE,
)
def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 string of its JPEG bytes.

    Args:
        image: The image to encode.

    Returns:
        The JPEG-encoded image as a base64 string (UTF-8 decoded).
    """
    # JPEG has no alpha channel or palette; Pillow refuses to write such modes
    # (e.g. RGBA, LA, P), so those are converted to RGB before saving. Modes
    # the JPEG writer already accepts are left untouched.
    jpeg_modes = {"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"}
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def get_most_similar_string(target_string, string_array):
    """Return the candidate that is most similar to ``target_string``.

    Similarity is ``difflib.SequenceMatcher(None, target, candidate).ratio()``.
    When several candidates share the best ratio, the earliest one wins; when
    nothing matches at all, the first candidate is returned.

    Args:
        target_string: The string to match against.
        string_array: Candidate strings (any iterable; must not be empty).

    Returns:
        The best-matching candidate string.

    Raises:
        IndexError: If ``string_array`` is empty.
    """
    candidates = list(string_array)
    if not candidates:
        raise IndexError(
            "get_most_similar_string() needs at least one candidate string"
        )

    def similarity(candidate):
        return difflib.SequenceMatcher(None, target_string, candidate).ratio()

    # max() keeps the first item among equal keys, which matches the
    # "only replace on a strictly better ratio" rule.
    return max(candidates, key=similarity)
# YOLO segmentation helpers
def getClasses(model, img1):
    """Run ``model`` on a single image and return its result plus a preview.

    Args:
        model: Callable taking a list of images and returning a list of
            results, each exposing ``plot()``.
        img1: The image to analyse.

    Returns:
        A tuple ``(result, preview, results)`` where ``result`` is the last
        entry of ``results``, ``preview`` is ``result.plot()`` with its last
        axis reversed, and ``results`` is the raw list returned by the model.
    """
    results = model([img1])
    # Only one image is submitted, so the last (only) result is the one used.
    result = results[-1]
    annotated = result.plot()
    # Reverse the channel axis -- presumably BGR -> RGB; confirm against the
    # plotting backend.
    return result, annotated[..., ::-1], results
def getMasks(out):
    """Merge per-detection masks into one mask per class name.

    Args:
        out: Iterable of dicts with a ``'name'`` (class label) and an
            ``'img'`` (mask image) entry.

    Returns:
        A dict mapping each class name to its mask. When a class occurs more
        than once, its masks are combined with an element-wise maximum.
    """
    merged = {}
    for entry in out:
        label = entry['name']
        mask = entry['img']
        if label not in merged:
            # First mask seen for this class: keep it as-is.
            merged[label] = mask
            continue
        # Union of the masks collected so far and the new one.
        combined = np.maximum(np.array(merged[label]), np.array(mask))
        merged[label] = Image.fromarray(combined)
    return merged
def joinClasses(classes):
    """Build one 8-bit mask image per detected class.

    Args:
        classes: Iterable of detection results. Only index 0 of each item's
            boxes and masks is read -- presumably every item holds a single
            detection; confirm against the caller.

    Returns:
        A dict mapping class name to a mode-"L" PIL mask, as produced by
        ``getMasks``.
    """
    out = []
    for r in classes:
        name0 = r.names[int(r.boxes.cls.cpu().numpy()[0])]
        mask = r.masks[0].data[0].cpu().numpy()

        # Stretch the mask to the full 0-255 range.
        low, high = mask.min(), mask.max()
        if high > low:
            mask_normalized = ((mask - low) * (255 / (high - low))).astype(np.uint8)
        else:
            # Constant mask: stretching would divide by zero, so fill it
            # uniformly (white when the constant is positive, black otherwise).
            fill = 255 if high > 0 else 0
            mask_normalized = np.full(mask.shape, fill, dtype=np.uint8)

        mask_img = Image.fromarray(mask_normalized, "L")
        out.append({'name': name0, 'img': mask_img})
    return getMasks(out)
def getSegments(yoloModel, img1):
    """Segment ``img1`` with ``yoloModel`` and return one mask per class name."""
    # Only the detection result is needed; the preview image and the raw
    # result list returned by getClasses are discarded.
    detections, _preview, _raw_results = getClasses(yoloModel, img1)
    return joinClasses(detections)
# Captioning, prompt editing and inpainting steps behind the Gradio UI
def captionMaker(base64_img):
    """Return the caption text the module-level captioner produces for the image."""
    predictions = image_captioner(base64_img)
    top_prediction = predictions[0]
    return top_prediction['generated_text']
def getDescript(image_captioner, img1):
    """Caption ``img1`` using the supplied captioner.

    Args:
        image_captioner: Callable that takes a base64-encoded image string and
            returns a list of dicts with a ``'generated_text'`` entry.
        img1: PIL image to describe.

    Returns:
        The generated caption text of the first prediction.
    """
    base64_img = image_to_base64(img1)
    # Use the captioner that was passed in rather than silently falling back
    # to the module-level one.
    return image_captioner(base64_img)[0]['generated_text']
def rmGPT(caption, remove_class):
    """Drop the word closest to ``remove_class`` and its neighbours from a caption.

    The caption is split on single spaces; the word most similar to
    ``remove_class`` is located, and it is removed together with up to two
    words on either side. The remaining words are re-joined with spaces.
    """
    words = caption.split(' ')
    closest_word = get_most_similar_string(remove_class, words)
    centre = words.index(closest_word)
    # Keep everything before the 5-word window and everything after it; the
    # left edge is clamped so a match near the start does not wrap around.
    kept = words[:max(centre - 2, 0)] + words[centre + 3:]
    return ' '.join(kept)
def ChangeOBJ(sdxl_m, img1, response, mask1):
    """Inpaint the masked region of ``img1`` and return it at the original size.

    Args:
        sdxl_m: Inpainting pipeline, called with ``prompt``, ``image`` and
            ``mask_image`` keyword arguments.
        img1: Source image; its ``size`` sets the output dimensions.
        response: Text prompt for the inpainting.
        mask1: Mask image marking the region to repaint.

    Returns:
        The first generated image, resized to ``img1``'s size.
    """
    # Remember the input dimensions before running the pipeline.
    original_width, original_height = img1.size[0], img1.size[1]
    generated = sdxl_m(prompt=response, image=img1, mask_image=mask1).images[0]
    return generated.resize((original_width, original_height))
def full_pipeline(image, target):
    """Remove the object named by ``target`` from ``image``.

    Steps: segment the image, pick the detected class whose name is closest
    to ``target``, caption the image, strip the words around that class from
    the caption, and inpaint the class's mask using the edited caption as the
    prompt.

    Args:
        image: Image as a numpy array, as delivered by the Gradio image input.
        target: Free-text name of the object to delete.

    Returns:
        A tuple ``(result_image, caption, edited_caption)``.

    Raises:
        gr.Error: If no image or target was supplied, or if no objects were
            detected in the image.
    """
    # NOTE(review): `spaces` is imported at the top of the file but nothing is
    # decorated with `@spaces.GPU`; if this app runs on ZeroGPU hardware this
    # entry point probably needs that decorator -- confirm.
    if image is None:
        raise gr.Error("Please upload an image first.")
    if not target or not target.strip():
        raise gr.Error("Please enter the name of the object to delete.")

    img1 = Image.fromarray(image.astype('uint8'), 'RGB')
    allMask = getSegments(yoloModel, img1)
    if not allMask:
        # Without any detected class there is no mask to inpaint.
        raise gr.Error("No objects were detected in the image, so there is nothing to delete.")

    target_to_remove = get_most_similar_string(target, list(allMask.keys()))
    caption = getDescript(image_captioner, img1)
    response = rmGPT(caption, target_to_remove)
    mask1 = allMask[target_to_remove]
    remimg = ChangeOBJ(sdxl, img1, response, mask1)
    return remimg, caption, response
# Input widgets, in the positional order full_pipeline expects: (image, target).
demo_inputs = [
    gr.Image(label="Upload Image"),
    gr.Textbox(label="What to delete?"),
]

# Output widgets, matching full_pipeline's return tuple:
# (result image, original caption, edited caption used as the prompt).
demo_outputs = [
    gr.Image(label="Result Image", type="numpy"),
    gr.Textbox(label="Caption"),
    gr.Textbox(label="Message"),
]

iface = gr.Interface(
    fn=full_pipeline,
    inputs=demo_inputs,
    outputs=demo_outputs,
    live=False,
)

# Start the web app as soon as the module is executed.
iface.launch()