# Standard Libraries
import time
from io import BytesIO
import base64

# Data Handling and Image Processing
import numpy as np
from PIL import Image

# Machine Learning and AI Models
import torch
from transformers import pipeline
from diffusers import AutoPipelineForInpainting
from ultralytics import YOLO

# Text and Data Manipulation
import difflib

# UI and Application Framework
import gradio as gr
import spaces


# Constants
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load models: YOLOv8 segmentation, SDXL inpainting, and an image captioner
yoloModel = YOLO('yolov8x-seg.pt')
sdxl = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    torch_dtype=torch.float32
).to(DEVICE)
image_captioner = pipeline("image-to-text", model="Abdou/vit-swin-base-224-gpt2-image-captioning", device=DEVICE)
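
# On a CUDA device the inpainting pipeline could instead be loaded in half
# precision to reduce VRAM use (a sketch; float32 is kept above for CPU
# compatibility):
# sdxl = AutoPipelineForInpainting.from_pretrained(
#     "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
#     torch_dtype=torch.float16, variant="fp16",
# ).to(DEVICE)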


def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string."""
    buffered = BytesIO()
    image.convert("RGB").save(buffered, format="JPEG")  # JPEG has no alpha channel
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
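
# Round-trip sketch: Image.open(BytesIO(base64.b64decode(image_to_base64(img))))
# recovers the (JPEG-compressed) image.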


def get_most_similar_string(target_string, string_array):
    """Return the string in string_array most similar to target_string,
    scored by difflib.SequenceMatcher ratio."""
    best_match = string_array[0]
    best_match_ratio = 0
    for candidate_string in string_array:
        similarity_ratio = difflib.SequenceMatcher(None, target_string, candidate_string).ratio()
        if similarity_ratio > best_match_ratio:
            best_match = candidate_string
            best_match_ratio = similarity_ratio
    return best_match
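
# Example: get_most_similar_string("dog", ["a", "dig", "cat"]) returns "dig",
# the candidate with the highest SequenceMatcher ratio (~0.67) against "dog".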


# YOLO segmentation
@spaces.GPU
def getClasses(model, img1):
    """Run YOLO segmentation on a single image. Returns the Results object,
    an RGB rendering of the detections, and the raw results list."""
    results = model([img1])
    result = results[0]
    annotated = result.plot()[..., ::-1]  # plot() returns BGR; flip to RGB
    return result, annotated, results


def getMasks(out):
    """Merge per-detection masks into one mask per class name, taking the
    pixel-wise maximum (the union) when a class appears more than once."""
    class_masks = {}
    for a in out:
        class_name = a['name']
        mask = a['img']
        if class_name in class_masks:
            class_masks[class_name] = Image.fromarray(
                np.maximum(np.array(class_masks[class_name]), np.array(mask))
            )
        else:
            class_masks[class_name] = mask
    return class_masks


def joinClasses(classes):
    """Convert each detection's mask tensor to a grayscale PIL image and
    merge the masks per class via getMasks()."""
    out = []
    for r in classes:
        name0 = r.names[int(r.boxes.cls.cpu().numpy()[0])]
        mask = r.masks[0].data[0].cpu().numpy()
        # Scale mask values to 0-255 for an 8-bit grayscale ("L") image
        mask_normalized = ((mask - mask.min()) * (255 / (mask.max() - mask.min()))).astype(np.uint8)
        mask_img = Image.fromarray(mask_normalized, "L")
        out.append({'name': name0, 'img': mask_img})

    return getMasks(out)


def getSegments(yoloModel, img1):
    classes, image, results1 = getClasses(yoloModel, img1)
    allMask = joinClasses(classes)
    return allMask
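
# Usage sketch (assumes a local image file):
# masks = getSegments(yoloModel, Image.open("photo.jpg"))
# print(list(masks.keys()))  # e.g. ['person', 'dog']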


# Captioning
@spaces.GPU
def captionMaker(base64_img):
    return image_captioner(base64_img)[0]['generated_text']


def getDescript(image_captioner, img1):
    base64_img = image_to_base64(img1)
    caption = captionMaker(base64_img)
    return caption


def rmGPT(caption, remove_class):
    """Build the inpainting prompt: drop a five-word window (two words either
    side) around the caption word closest to remove_class."""
    arstr = caption.split(' ')
    popular = get_most_similar_string(remove_class, arstr)
    ind = arstr.index(popular)
    new = [word for i, word in enumerate(arstr) if i not in range(ind - 2, ind + 3)]
    return ' '.join(new)
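
# Example: rmGPT("a brown dog sitting on the grass", "dog") matches "dog" at
# index 2, drops words 0-4, and returns "the grass".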


@spaces.GPU
def ChangeOBJ(sdxl_m, img1, response, mask1):
    """Inpaint the masked region with SDXL, then restore the original size."""
    size = img1.size
    image = sdxl_m(prompt=response, image=img1, mask_image=mask1).images[0]
    return image.resize(size)
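
# Optional knobs (library defaults are used above): the diffusers inpainting
# call also accepts num_inference_steps and guidance_scale, e.g.
# sdxl_m(prompt=response, image=img1, mask_image=mask1,
#        num_inference_steps=30, guidance_scale=7.5).images[0]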


def full_pipeline(image, target):
    img1 = Image.fromarray(image.astype('uint8'), 'RGB')
    allMask = getSegments(yoloModel, img1)
    target_to_remove = get_most_similar_string(target, list(allMask.keys()))
    caption = getDescript(image_captioner, img1)

    response = rmGPT(caption, target_to_remove)
    mask1 = allMask[target_to_remove]

    remimg = ChangeOBJ(sdxl, img1, response, mask1)

    return remimg, caption, response


# Gradio UI
iface = gr.Interface(
    fn=full_pipeline,
    inputs=[
        gr.Image(label="Upload Image"),
        gr.Textbox(label="What to delete?"),
    ],
    outputs=[
        gr.Image(label="Result Image", type="numpy"),
        gr.Textbox(label="Caption"),
        gr.Textbox(label="Message"),
    ],
    live=False
)

iface.launch()