Renecto committed on
Commit 22d033b • 1 Parent(s): ac60c3c

Update app.py

Files changed (1)
  1. app.py +7 -381
app.py CHANGED
@@ -1,386 +1,12 @@
-import spaces
-import logging
-import math
 import gradio as gr
-from PIL import Image
-from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
-from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
-from src.unet_hacked_tryon import UNet2DConditionModel
-from transformers import (
-    CLIPImageProcessor,
-    CLIPVisionModelWithProjection,
-    CLIPTextModel,
-    CLIPTextModelWithProjection,
-)
-from diffusers import DDPMScheduler,AutoencoderKL
-from typing import List
-
-import torch
-import os
-from transformers import AutoTokenizer
-import numpy as np
-from utils_mask import get_mask_location
-from torchvision import transforms
-import apply_net
-from preprocess.humanparsing.run_parsing import Parsing
-from preprocess.openpose.run_openpose import OpenPose
-from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
-from torchvision.transforms.functional import to_pil_image
-from src.background_processor import BackgroundProcessor
-
-def pil_to_binary_mask(pil_image, threshold=0):
-    np_image = np.array(pil_image)
-    grayscale_image = Image.fromarray(np_image).convert("L")
-    binary_mask = np.array(grayscale_image) > threshold
-    mask = np.zeros(binary_mask.shape, dtype=np.uint8)
-    for i in range(binary_mask.shape[0]):
-        for j in range(binary_mask.shape[1]):
-            if binary_mask[i,j] == True:
-                mask[i,j] = 1
-    mask = (mask*255).astype(np.uint8)
-    output_mask = Image.fromarray(mask)
-    return output_mask
-
-
-base_path = 'yisol/IDM-VTON'
-example_path = os.path.join(os.path.dirname(__file__), 'example')
-
-unet = UNet2DConditionModel.from_pretrained(
-    base_path,
-    subfolder="unet",
-    torch_dtype=torch.float16,
-)
-unet.requires_grad_(False)
-tokenizer_one = AutoTokenizer.from_pretrained(
-    base_path,
-    subfolder="tokenizer",
-    revision=None,
-    use_fast=False,
-)
-tokenizer_two = AutoTokenizer.from_pretrained(
-    base_path,
-    subfolder="tokenizer_2",
-    revision=None,
-    use_fast=False,
-)
-noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
-
-text_encoder_one = CLIPTextModel.from_pretrained(
-    base_path,
-    subfolder="text_encoder",
-    torch_dtype=torch.float16,
-)
-text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
-    base_path,
-    subfolder="text_encoder_2",
-    torch_dtype=torch.float16,
-)
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-    base_path,
-    subfolder="image_encoder",
-    torch_dtype=torch.float16,
-)
-vae = AutoencoderKL.from_pretrained(base_path,
-    subfolder="vae",
-    torch_dtype=torch.float16,
-)
-
-# "stabilityai/stable-diffusion-xl-base-1.0",
-UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
-    base_path,
-    subfolder="unet_encoder",
-    torch_dtype=torch.float16,
-)
 
-parsing_model = Parsing(0)
-openpose_model = OpenPose(0)
+def greet(name, intensity):
+    return "Hello, " + name + "!" * int(intensity)
 
-UNet_Encoder.requires_grad_(False)
-image_encoder.requires_grad_(False)
-vae.requires_grad_(False)
-unet.requires_grad_(False)
-text_encoder_one.requires_grad_(False)
-text_encoder_two.requires_grad_(False)
-tensor_transfrom = transforms.Compose(
-    [
-        transforms.ToTensor(),
-        transforms.Normalize([0.5], [0.5]),
-    ]
-)
-
-pipe = TryonPipeline.from_pretrained(
-    base_path,
-    unet=unet,
-    vae=vae,
-    feature_extractor= CLIPImageProcessor(),
-    text_encoder = text_encoder_one,
-    text_encoder_2 = text_encoder_two,
-    tokenizer = tokenizer_one,
-    tokenizer_2 = tokenizer_two,
-    scheduler = noise_scheduler,
-    image_encoder=image_encoder,
-    torch_dtype=torch.float16,
+demo = gr.Interface(
+    fn=greet,
+    inputs=["text", "slider"],
+    outputs=["text"],
 )
-pipe.unet_encoder = UNet_Encoder
-
-# Standard size of shein images
-#WIDTH = int(4160/5)
-#HEIGHT = int(6240/5)
-# Standard size on which model is trained
-WIDTH = int(768)
-HEIGHT = int(1024)
-POSE_WIDTH = int(WIDTH/2) # int(WIDTH/2)
-POSE_HEIGHT = int(HEIGHT/2) #int(HEIGHT/2)
-ARM_WIDTH = "dc" # "hd" # hd -> full sleeve, dc for half sleeve
-CATEGORY = "upper_body" # "lower_body"
-
-
-def is_cropping_required(width, height):
-    # If aspect ratio is 1.33, which is same as standard 3x4 ( 768x1024 ), then no need to crop, else crop
-    aspect_ratio = round(height/width, 2)
-    if aspect_ratio == 1.33:
-        return False
-    return True
-
-
-@spaces.GPU
-def start_tryon(human_img_dict,garm_img,garment_des, background_img, is_checked,is_checked_crop,denoise_steps,seed):
-    logging.info("Starting try on")
-    #device = "cuda"
-    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-
-    openpose_model.preprocessor.body_estimation.model.to(device)
-    pipe.to(device)
-    pipe.unet_encoder.to(device)
-
-    human_img_orig = human_img_dict["background"].convert("RGB") # ImageEditor
-    #human_img_orig = human_img_dict.convert("RGB") # Image
-
-
-    """
-    # Derive HEIGHT & WIDTH such that width is not more than 1000. This will cater to both Shein images (4160x6240) of 2:3 AR and model standard images ( 768x1024 ) of 3:4 AR
-    WIDTH, HEIGHT = human_img_orig.size
-    division_factor = math.ceil(WIDTH/1000)
-    WIDTH = int(WIDTH/division_factor)
-    HEIGHT = int(HEIGHT/division_factor)
-    POSE_WIDTH = int(WIDTH/2)
-    POSE_HEIGHT = int(HEIGHT/2)
-    """
-    # is_checked_crop as True if original AR is not same as 2x3 as expected by model
-    w, h = human_img_orig.size
-    is_checked_crop = is_cropping_required(w, h)
-
-    garm_img = garm_img.convert("RGB").resize((WIDTH,HEIGHT))
-    if is_checked_crop:
-        # This will crop the image to make it Aspect Ratio of 3 x 4. And then at the end revert it back to original dimentions
-        width, height = human_img_orig.size
-        target_width = int(min(width, height * (3 / 4)))
-        target_height = int(min(height, width * (4 / 3)))
-
-        left = (width - target_width) / 2
-        right = (width + target_width) / 2
-        # for Landmark, model sizes are 594x879, so we need to reduce the height. In some case the garment on the model is
-        # also getting removed when reducing size from bottom. So we will only reduce height from top for now
-        top = (height - target_height) #top = (height - target_height) / 2
-        bottom = height #bottom = (height + target_height) / 2
-        cropped_img = human_img_orig.crop((left, top, right, bottom))
-
-        crop_size = cropped_img.size
-        human_img = cropped_img.resize((WIDTH, HEIGHT))
-    else:
-        human_img = human_img_orig.resize((WIDTH, HEIGHT))
-
-    # Commenting out naize harmonization for now. We will have to integrate with Deep Learning based Harmonization methods
-    # Do color transfer from background image for better image harmonization
-    #if background_img:
-    #    human_img = BackgroundProcessor.intensity_transfer(human_img, background_img)
-
-
-    if is_checked:
-        # internally openpose_model is resizing human_img to resolution 384 if not passed as input
-        keypoints = openpose_model(human_img.resize((POSE_WIDTH, POSE_HEIGHT)))
-        model_parse, _ = parsing_model(human_img.resize((POSE_WIDTH, POSE_HEIGHT)))
-        # internally get mask location function is resizing model_parse to 384x512 if width & height not passed
-        mask, mask_gray = get_mask_location(ARM_WIDTH, CATEGORY, model_parse, keypoints)
-        mask = mask.resize((WIDTH, HEIGHT))
-        logging.info("Mask location on model identified")
-    else:
-        mask = pil_to_binary_mask(human_img_dict['layers'][0].convert("RGB").resize((WIDTH, HEIGHT)))
-        # mask = transforms.ToTensor()(mask)
-        # mask = mask.unsqueeze(0)
-    mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
-    mask_gray = to_pil_image((mask_gray+1.0)/2.0)
-
-
-    human_img_arg = _apply_exif_orientation(human_img.resize((POSE_WIDTH,POSE_HEIGHT)))
-    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
-
-
-
-    args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', device))
-    # verbosity = getattr(args, "verbosity", None)
-    pose_img = args.func(args,human_img_arg)
-    pose_img = pose_img[:,:,::-1]
-    pose_img = Image.fromarray(pose_img).resize((WIDTH,HEIGHT))
-
-    with torch.no_grad():
-        # Extract the images
-        with torch.cuda.amp.autocast():
-            with torch.no_grad():
-                prompt = "model is wearing " + garment_des
-                negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-                with torch.inference_mode():
-                    (
-                        prompt_embeds,
-                        negative_prompt_embeds,
-                        pooled_prompt_embeds,
-                        negative_pooled_prompt_embeds,
-                    ) = pipe.encode_prompt(
-                        prompt,
-                        num_images_per_prompt=1,
-                        do_classifier_free_guidance=True,
-                        negative_prompt=negative_prompt,
-                    )
-
-                    prompt = "a photo of " + garment_des
-                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-                    if not isinstance(prompt, List):
-                        prompt = [prompt] * 1
-                    if not isinstance(negative_prompt, List):
-                        negative_prompt = [negative_prompt] * 1
-                    with torch.inference_mode():
-                        (
-                            prompt_embeds_c,
-                            _,
-                            _,
-                            _,
-                        ) = pipe.encode_prompt(
-                            prompt,
-                            num_images_per_prompt=1,
-                            do_classifier_free_guidance=False,
-                            negative_prompt=negative_prompt,
-                        )
-
-
-
-                    pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
-                    garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
-                    generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
-                    images = pipe(
-                        prompt_embeds=prompt_embeds.to(device,torch.float16),
-                        negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
-                        pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
-                        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
-                        num_inference_steps=denoise_steps,
-                        generator=generator,
-                        strength = 1.0,
-                        pose_img = pose_img.to(device,torch.float16),
-                        text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
-                        cloth = garm_tensor.to(device,torch.float16),
-                        mask_image=mask,
-                        image=human_img,
-                        height=HEIGHT,
-                        width=WIDTH,
-                        ip_adapter_image = garm_img.resize((WIDTH,HEIGHT)),
-                        guidance_scale=2.0,
-                    )[0]
-
-    if is_checked_crop:
-        out_img = images[0].resize(crop_size)
-        human_img_orig.paste(out_img, (int(left), int(top)))
-        final_image = human_img_orig
-        # return human_img_orig, mask_gray
-    else:
-        final_image = images[0]
-        # return images[0], mask_gray
-
-    # apply background to final image
-    if background_img:
-        logging.info("Adding background")
-        final_image = BackgroundProcessor.replace_background_with_removebg(final_image, background_img)
-    return final_image, mask_gray
-    # return images[0], mask_gray
-
-garm_list = os.listdir(os.path.join(example_path,"cloth"))
-garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
-
-human_list = os.listdir(os.path.join(example_path,"human"))
-human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
-
-human_ex_list = []
-#human_ex_list = human_list_path # Image
-#""" if using ImageEditor instead of Image while taking input, use this - ImageEditor
-for ex_human in human_list_path:
-    ex_dict = {}
-    ex_dict['background'] = ex_human
-    ex_dict['layers'] = None
-    ex_dict['composite'] = None
-    human_ex_list.append(ex_dict)
-#"""
-##default human
-
-
-# api_open=True will allow this API to be hit using curl
-image_blocks = gr.Blocks().queue(api_open=True)
-with image_blocks as demo:
-    gr.Markdown("## Virtual Try-On 👕👔👚")
-    gr.Markdown("Upload an image of a person and an image of a garment ✨.")
-    with gr.Row():
-        with gr.Column():
-            # changing from ImageEditor to Image to allow easy passing of data through API
-            # instead of passing {"dictionary": <>} ( which is failing ), we can directly pass the image
-            imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
-            #imgs = gr.Image(sources='upload', type='pil', label='Human. Mask with pen or use auto-masking')
-            with gr.Row():
-                is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)", value=True)
-            with gr.Row():
-                is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing", value=False)
-
-            example = gr.Examples(
-                inputs=imgs,
-                examples_per_page=10,
-                examples=human_ex_list
-            )
-
-        with gr.Column():
-            garm_img = gr.Image(label="Garment", sources='upload', type="pil")
-            with gr.Row(elem_id="prompt-container"):
-                with gr.Row():
-                    prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
-            example = gr.Examples(
-                inputs=garm_img,
-                examples_per_page=8,
-                examples=garm_list_path)
-
-        with gr.Column():
-            background_img = gr.Image(label="Background", sources='upload', type="pil")
-
-        with gr.Column():
-            with gr.Row():
-                image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)
-            with gr.Row():
-                masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)
-        """
-        with gr.Column():
-            # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
-            masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)
-        with gr.Column():
-            # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
-            image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)
-        """
-
-
-
-    with gr.Column():
-        try_button = gr.Button(value="Try-on")
-        with gr.Accordion(label="Advanced Settings", open=False):
-            with gr.Row():
-                denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
-                seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
-
-
-
-    try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, background_img, is_checked, is_checked_crop, denoise_steps, seed], outputs=[image_out, masked_img], api_name='tryon')
 
-
-image_blocks.launch()
+demo.launch()
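For reference, the added and unchanged lines above assemble into the new app.py shown below, which swaps the IDM-VTON try-on pipeline for a minimal Gradio quick-start-style demo; the inline comment is added here for clarity and is not part of the committed file. The slider value is cast to int and controls how many exclamation marks follow the greeting, so greet("World", 3) returns "Hello, World!!!".

import gradio as gr

def greet(name, intensity):
    # Gradio delivers the slider value as a number; cast to int so "!" can be repeated.
    return "Hello, " + name + "!" * int(intensity)

demo = gr.Interface(
    fn=greet,
    inputs=["text", "slider"],
    outputs=["text"],
)

demo.launch()

Note that the previous build exposed its endpoint with api_name='tryon' (and queued with api_open=True); the new Interface sets neither, so any client that was calling the tryon API against this Space will need to be updated.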