File size: 4,374 Bytes
17aaf2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7b7ab1
a567a68
17aaf2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64712db
17aaf2d
 
 
 
 
a967a05
17aaf2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import time

from diffusers import AutoPipelineForInpainting
from transformers import pipeline
from ultralytics import YOLO
from PIL import Image
import numpy as np
import torch
import base64
from io import BytesIO
import gradio as gr
from gradio import components
import difflib

# Constants
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Helpers and model loading

def image_to_base64(image: "Image.Image"):
    """Encode an image as base64 text of its JPEG bytes.

    JPEG cannot store an alpha channel or a palette, so saving an image in
    a mode such as ``RGBA``, ``LA`` or ``P`` fails inside ``save``. Images
    in a mode the JPEG writer does not accept are converted to ``RGB``
    first; images already in an accepted mode are encoded unchanged.

    Args:
        image: Image to encode. Must expose ``mode``, ``convert`` and
            ``save`` as a PIL image does.

    Returns:
        The JPEG-encoded image as a base64 string.
    """
    # Modes the Pillow JPEG writer accepts as-is; anything else is converted.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

def get_most_similar_string(target_string, string_array):
    """Return the candidate that is most similar to ``target_string``.

    Similarity is ``difflib.SequenceMatcher(None, target_string,
    candidate).ratio()``. When several candidates share the best score the
    earliest one wins, so if nothing matches at all the first candidate is
    returned.

    Args:
        target_string: String to look for.
        string_array: Non-empty sequence of candidate strings.

    Returns:
        The element of ``string_array`` with the highest similarity ratio.

    Raises:
        IndexError: If ``string_array`` is empty.
    """
    if not string_array:
        raise IndexError("string_array must contain at least one candidate string")

    def similarity(candidate_string):
        return difflib.SequenceMatcher(None, target_string, candidate_string).ratio()

    # max() returns the first of equally-scored items, which matches a
    # front-to-back scan that only replaces the best on a strictly higher score.
    return max(string_array, key=similarity)

def loadModels():
    """Instantiate the three models used by the app.

    Returns:
        A tuple ``(yolo_model, inpainting_pipeline, image_captioner)``:
        the ``yolov8x-seg.pt`` YOLO model, the SDXL inpainting pipeline
        loaded in float32 and moved to ``DEVICE``, and an ``image-to-text``
        transformers pipeline created on ``DEVICE``.
    """
    segmenter = YOLO('yolov8x-seg.pt')

    inpainter = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        torch_dtype=torch.float32,
    )
    inpainter = inpainter.to(DEVICE)

    captioner = pipeline(
        "image-to-text",
        model="Abdou/vit-swin-base-224-gpt2-image-captioning",
        device=DEVICE,
    )
    return segmenter, inpainter, captioner

# Yolo

def getClasses(model, img1):
    """Run the segmentation model on one image and render its predictions.

    Args:
        model: Callable invoked as ``model([img1])``; it must return a
            sequence of result objects that each expose ``plot()``.
        img1: The image handed to the model.

    Returns:
        A tuple ``(result, rendered, results)`` where ``result`` is the last
        result object, ``rendered`` is that result's ``plot()`` array with
        its last axis reversed, and ``results`` is everything the model
        returned.

    Raises:
        ValueError: If the model returns no results.
    """
    results = model([img1])
    if len(results) == 0:
        raise ValueError("segmentation model returned no results")

    # Only the final result is handed back, so only that one is rendered.
    last_result = results[-1]
    rendered = last_result.plot()

    # Reverse the channel axis (plot() presumably yields BGR, making this RGB).
    return last_result, rendered[..., ::-1], results

def getMasks(out):
    """Collapse per-detection masks into one mask per class name.

    Args:
        out: Iterable of dicts with a ``'name'`` key (class name) and an
            ``'img'`` key (mask image).

    Returns:
        A dict mapping each class name to its mask. When a class appears
        more than once, its masks are combined with an element-wise maximum
        and stored as a new image; classes keep first-seen order.
    """
    merged = {}
    for entry in out:
        label = entry['name']
        current = entry['img']
        if label not in merged:
            # First mask for this class: keep it untouched.
            merged[label] = current
            continue
        # Union of the mask so far and the new one.
        combined = np.maximum(np.array(merged[label]), np.array(current))
        merged[label] = Image.fromarray(combined)
    return merged

def joinClasses(classes):
    """Build one grayscale mask image per detected class.

    Args:
        classes: Iterable of per-detection results. Each item must expose
            ``masks``, ``names`` and ``boxes.cls`` (presumably ultralytics
            result objects; confirm against the caller).

    Returns:
        The dict produced by ``getMasks``: class name mapped to a merged
        mode-``"L"`` mask image.
    """
    out = []
    for r in classes:
        # Only the first box and first mask of each item are used.
        class_index = int(r.boxes.cls.cpu().numpy()[0])
        name0 = r.names[class_index]
        mask = r.masks[0].data[0].cpu().numpy()

        # Stretch the mask values to the 0-255 range.
        low = mask.min()
        span = mask.max() - low
        if span > 0:
            mask_normalized = ((mask - low) * (255 / span)).astype(np.uint8)
        else:
            # A constant mask has zero range; dividing by it would produce
            # NaNs. Treat any positive value as fully selected instead.
            mask_normalized = (mask > 0).astype(np.uint8) * 255

        mask_img = Image.fromarray(mask_normalized, "L")
        out.append({'name': name0, 'img': mask_img})

    return getMasks(out)

def getSegments(yoloModel, img1):
    """Segment an image and return one merged mask per detected class.

    Args:
        yoloModel: Segmentation model passed through to ``getClasses``.
        img1: Image to segment.

    Returns:
        The class-name-to-mask dict built by ``joinClasses``.
    """
    # The rendered preview and the raw result list are not needed here.
    detections, _preview, _raw_results = getClasses(yoloModel, img1)
    return joinClasses(detections)

# Captioning and prompt editing

def getDescript(image_captioner, img1):
    """Generate a text caption for an image.

    Args:
        image_captioner: Callable that takes a base64-encoded image string
            and returns a list of dicts with a ``'generated_text'`` key.
        img1: Image to describe.

    Returns:
        The ``'generated_text'`` of the first prediction.
    """
    encoded = image_to_base64(img1)
    predictions = image_captioner(encoded)
    return predictions[0]['generated_text']

def rmGPT(caption, remove_class):
    """Cut the removed object, and its neighbouring words, out of a caption.

    The caption is split on single spaces, the word most similar to
    ``remove_class`` is located, and that word plus up to two words on each
    side of it are dropped.

    Args:
        caption: Caption text.
        remove_class: Name of the class being removed.

    Returns:
        The remaining words joined with single spaces.
    """
    words = caption.split(' ')
    closest = get_most_similar_string(remove_class, words)
    center = words.index(closest)
    # Keep only words more than two positions away from the matched word.
    kept = [word for pos, word in enumerate(words) if abs(pos - center) > 2]
    return ' '.join(kept)

# SDXL

def ChangeOBJ(sdxl_m, img1, response, mask1):
    """Inpaint the masked region and return the result at the input size.

    Args:
        sdxl_m: Callable invoked with ``prompt``, ``image`` and
            ``mask_image`` keywords; its result must expose ``images``.
        img1: Source image; its ``size`` sets the output size.
        response: Prompt text for the inpainting call.
        mask1: Mask image marking the region to repaint.

    Returns:
        The first generated image, resized to ``img1``'s size.
    """
    source_size = img1.size
    generated = sdxl_m(prompt=response, image=img1, mask_image=mask1)
    inpainted = generated.images[0]
    return inpainted.resize((source_size[0], source_size[1]))



# Load all three models once, at import time. full_pipeline reads these
# module-level names on every call, so this must run before the UI starts.
yoloModel,sdxl,image_captioner=loadModels()

def full_pipeline(image, target):
    """Remove the requested object from an image (Gradio callback).

    Steps: segment the image, pick the detected class whose name is closest
    to ``target``, caption the image, cut that object out of the caption,
    then inpaint the object's mask using the edited caption as the prompt.

    Args:
        image: Array from the image input, or ``None`` when nothing was
            uploaded. It is cast to ``uint8`` and read as RGB (assumes an
            H x W x 3 array; confirm against the input component).
        target: Free-text name of the object to remove.

    Returns:
        A tuple ``(result_image, caption, prompt)``: the inpainted image,
        the original caption, and the edited caption used as the prompt.

    Raises:
        gr.Error: If no image was provided or no objects were detected.
    """
    # Submitting without an upload passes None; fail with a readable message.
    if image is None:
        raise gr.Error("Please upload an image first.")

    img1 = Image.fromarray(image.astype('uint8'), 'RGB')
    allMask = getSegments(yoloModel, img1)

    # With no detections there is no mask to choose from.
    if not allMask:
        raise gr.Error("No objects were detected in the image.")

    target_to_remove = get_most_similar_string(target, list(allMask.keys()))
    caption = getDescript(image_captioner, img1)

    response = rmGPT(caption, target_to_remove)
    mask1 = allMask[target_to_remove]

    remimg = ChangeOBJ(sdxl, img1, response, mask1)

    return remimg, caption, response



# Gradio UI: one image upload plus a text box feed full_pipeline, whose three
# return values fill the result image, caption and message outputs in order.
iface = gr.Interface(
    fn=full_pipeline, 
    inputs=[
        gr.Image(label="Upload Image"),
        gr.Textbox(label="What to delete?"),
    ], 
    outputs=[
        gr.Image(label="Result Image", type="numpy"),
        gr.Textbox(label="Caption"),
        gr.Textbox(label="Message"),
    ],
    live=False
)


# Start the web app; this runs whenever the module is executed or imported.
iface.launch()