Object_Remove / app.py
ifmain's picture
Update app.py
319cffc verified
history blame
4.43 kB
# Standard Libraries
import time
from io import BytesIO
import base64
# Data Handling and Image Processing
import numpy as np
from PIL import Image
# Machine Learning and AI Models
import torch
from transformers import pipeline
from diffusers import AutoPipelineForInpainting
from ultralytics import YOLO
# Text and Data Manipulation
import difflib
# UI and Application Framework
import gradio as gr
import spaces
# Constants
# Run the models on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load
# NOTE(review): DEVICE is assigned a second time with the identical
# expression; the repeat looks redundant.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Model loaded from the 'yolov8x-seg.pt' weights; its results are read for
# `.masks`, `.boxes` and `.names` by joinClasses().
yoloModel = YOLO('yolov8x-seg.pt')
# Inpainting pipeline used by ChangeOBJ().
# NOTE(review): the from_pretrained(...) call below is cut off in this copy of
# the file (arguments and closing parenthesis are missing) - restore them from
# the original source before running.
sdxl = AutoPipelineForInpainting.from_pretrained(
image_captioner = pipeline("image-to-text", model="Abdou/vit-swin-base-224-gpt2-image-captioning", device=DEVICE)
def image_to_base64(image: Image.Image):
    """Encode a PIL image as a base64 JPEG string.

    Args:
        image: The image to encode.

    Returns:
        The JPEG-encoded bytes of ``image``, base64-encoded and returned as
        a UTF-8 ``str``.
    """
    # The JPEG writer only accepts a handful of modes; anything else (for
    # example RGBA or palette images) would make save() raise, so those are
    # converted to RGB first. Modes JPEG can already store are left alone.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def get_most_similar_string(target_string, string_array):
    """Return the candidate that is most similar to ``target_string``.

    Similarity is ``difflib.SequenceMatcher(None, target, candidate).ratio()``.

    Args:
        target_string: The string to match against.
        string_array: Iterable of candidate strings.

    Returns:
        The candidate with the highest similarity ratio. When several
        candidates share the best ratio, the earliest one wins.

    Raises:
        ValueError: If ``string_array`` contains no candidates.
    """
    candidates = list(string_array)
    if not candidates:
        raise ValueError("string_array must contain at least one candidate")

    def similarity(candidate):
        return difflib.SequenceMatcher(None, target_string, candidate).ratio()

    # max() returns the first of several equally similar candidates, so ties
    # resolve in input order.
    return max(candidates, key=similarity)
# Yolo
def getClasses(model, img1):
    """Run ``model`` on a single image and return its result.

    Args:
        model: Callable model; it is invoked as ``model([img1])``.
        img1: The image to analyse.

    Returns:
        A 3-tuple ``(result, annotated, results)`` where ``result`` is the
        last entry of ``results``, ``annotated`` is ``result.plot()`` with its
        last axis reversed, and ``results`` is everything the model returned.

    Raises:
        ValueError: If the model returns no results.
    """
    results = model([img1])
    if len(results) == 0:
        raise ValueError("model returned no results for the image")
    # Only one image is passed in, so a single result is expected back.
    result = results[-1]
    annotated = result.plot()
    # Reverse the last axis of the plotted array (channel order flip).
    return result, annotated[..., ::-1], results
def getMasks(out):
    """Merge per-detection masks into one mask per class name.

    Args:
        out: Iterable of dicts, each with a ``'name'`` key (class name) and
            an ``'img'`` key (the mask image for one detection).

    Returns:
        A dict mapping each class name to its mask, in first-seen order.
        When a class occurs more than once, its masks are combined with an
        element-wise maximum and wrapped back into a PIL image.
    """
    class_masks = {}
    for detection in out:
        class_name = detection['name']
        mask = detection['img']
        if class_name in class_masks:
            # Element-wise maximum keeps every pixel set in either mask.
            merged = np.maximum(np.array(class_masks[class_name]), np.array(mask))
            class_masks[class_name] = Image.fromarray(merged)
        else:
            class_masks[class_name] = mask
    return class_masks
def joinClasses(classes):
    """Build one 8-bit mask image per detected class.

    Args:
        classes: Iterable of detection results. Each item is read for
            ``.names``, ``.boxes.cls`` and ``.masks``.

    Returns:
        The dict produced by ``getMasks``: class name -> merged mask image.
    """
    out = []
    for r in classes:
        name0 = r.names[int(r.boxes.cls.cpu().numpy()[0])]
        mask = r.masks[0].data[0].cpu().numpy()
        # Normalize the mask values to 0-255 if needed
        low = mask.min()
        span = mask.max() - low
        if span > 0:
            mask_normalized = ((mask - low) * (255 / span)).astype(np.uint8)
        else:
            # A constant mask has no range to stretch (dividing by the zero
            # span would yield NaN/inf), so map it directly: non-zero -> 255.
            mask_normalized = np.where(mask > 0, 255, 0).astype(np.uint8)
        mask_img = Image.fromarray(mask_normalized, "L")
        out.append({'name': name0, 'img': mask_img})
    return getMasks(out)
def getSegments(yoloModel, img1):
    """Segment ``img1`` with ``yoloModel`` and return the per-class masks.

    Args:
        yoloModel: Model forwarded to ``getClasses``.
        img1: Image to segment.

    Returns:
        Whatever ``joinClasses`` returns for the detections.
    """
    # getClasses yields three values; only the first feeds the mask builder.
    detections, _annotated, _raw_results = getClasses(yoloModel, img1)
    return joinClasses(detections)
# Gradio UI
def captionMaker(base64_img):
    """Caption an image with the module-level ``image_captioner`` pipeline.

    Args:
        base64_img: Image input handed to the pipeline unchanged.

    Returns:
        The ``'generated_text'`` field of the first prediction.
    """
    predictions = image_captioner(base64_img)
    first_prediction = predictions[0]
    return first_prediction['generated_text']
def getDescript(image_captioner, img1):
    """Generate a text caption for ``img1``.

    Args:
        image_captioner: Captioning callable. It is called with the
            base64-encoded image and must return a sequence whose first item
            has a ``'generated_text'`` key.
        img1: Image to describe.

    Returns:
        The generated caption text.
    """
    base64_img = image_to_base64(img1)
    # Use the captioner that was passed in rather than silently falling back
    # to the module-level one, so the parameter actually has an effect.
    return image_captioner(base64_img)[0]['generated_text']
def rmGPT(caption, remove_class):
    """Remove the mention of ``remove_class`` from ``caption``.

    The caption is split on single spaces, the word most similar to
    ``remove_class`` is located, and that word together with up to two words
    on either side of it is dropped.

    Args:
        caption: Caption text.
        remove_class: Name of the object class to remove.

    Returns:
        The remaining words joined with single spaces.
    """
    words = caption.split(' ')
    closest = get_most_similar_string(remove_class, words)
    ind = words.index(closest)
    # Keep only words more than two positions away from the matched word.
    kept = [word for i, word in enumerate(words) if abs(i - ind) > 2]
    return ' '.join(kept)
def ChangeOBJ(sdxl_m, img1, response, mask1):
    """Inpaint the masked region of ``img1`` guided by ``response``.

    Args:
        sdxl_m: Inpainting pipeline; called with ``prompt``, ``image`` and
            ``mask_image`` keyword arguments.
        img1: Source image.
        response: Text prompt for the inpainting.
        mask1: Mask image marking the region to repaint.

    Returns:
        The first generated image, resized to the size of ``img1``.
    """
    dims = img1.size
    target_size = (dims[0], dims[1])
    generated = sdxl_m(prompt=response, image=img1, mask_image=mask1)
    inpainted = generated.images[0]
    return inpainted.resize(target_size)
def full_pipeline(image, target):
    """Remove the object named by ``target`` from ``image``.

    Steps: segment the image, pick the detected class whose name is most
    similar to ``target``, caption the image, strip the object's mention from
    the caption, then inpaint the object's mask using the edited caption as
    the prompt.

    Args:
        image: Input image as an array (converted with ``astype('uint8')``
            and interpreted as RGB).
        target: Free-text name of the object to remove.

    Returns:
        A 3-tuple ``(result_image, caption, response)``: the inpainted image,
        the original caption, and the caption with the object removed.

    Raises:
        gr.Error: If no image was supplied or no objects were detected.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    img1 = Image.fromarray(image.astype('uint8'), 'RGB')
    allMask = getSegments(yoloModel, img1)
    # With no detections there is nothing to match the target against.
    if not allMask:
        raise gr.Error("No objects were detected in the image.")
    target_to_remove = get_most_similar_string(target, list(allMask.keys()))
    caption = getDescript(image_captioner, img1)
    response = rmGPT(caption, target_to_remove)
    mask1 = allMask[target_to_remove]
    remimg = ChangeOBJ(sdxl, img1, response, mask1)
    return remimg, caption, response
iface = gr.Interface(
gr.Image(label="Upload Image"),
gr.Textbox(label="What to delete?"),
gr.Image(label="Result Image", type="numpy"),