# Flask API server exposing the OOTDiffusion virtual try-on pipelines:
# HD (half-body, upper-body garments) and DC (full-body, upper/lower/dress).
import base64
from io import BytesIO

import spaces
import torch
from flask import Flask, request, jsonify, send_file
from PIL import Image, ImageOps

from ootd.inference_ootd_dc import OOTDiffusionDC
from ootd.inference_ootd_hd import OOTDiffusionHD
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from utils_ootd import get_mask_location
app = Flask(__name__)

# Load the models once at application startup.
# HD pipeline lives on GPU 0, DC pipeline on GPU 1 (constructor argument is
# the GPU index — presumably; confirm against the OpenPose/Parsing/OOTD APIs).
openpose_model_hd = OpenPose(0)
parsing_model_hd = Parsing(0)
ootd_model_hd = OOTDiffusionHD(0)

openpose_model_dc = OpenPose(1)
parsing_model_dc = Parsing(1)
ootd_model_dc = OOTDiffusionDC(1)

# Device the request handlers move weights to; falls back to CPU when CUDA
# is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Garment category names as the OOTD pipeline expects them ...
category_dict = ['upperbody', 'lowerbody', 'dress']
# ... and as get_mask_location() expects them. Indexed by the same 0/1/2 code.
category_dict_utils = ['upper_body', 'lower_body', 'dresses']
@app.route('/process_hd', methods=['POST'])
def process_hd():
    """Run the HD (half-body) try-on pipeline on an uploaded image pair.

    Expects multipart/form-data with:
        files: 'vton_img' (person photo), 'garm_img' (garment photo)
        form:  'n_samples' (int), 'n_steps' (int),
               'image_scale' (float), 'seed' (int)

    Returns:
        JSON ``{"images": [<base64-encoded JPEG>, ...]}`` with one entry
        per generated sample.

    Raises:
        KeyError / ValueError (-> HTTP 400/500 via Flask) when a field is
        missing or not parseable.
    """
    files = request.files
    vton_img = files['vton_img']
    garm_img = files['garm_img']
    n_samples = int(request.form['n_samples'])
    n_steps = int(request.form['n_steps'])
    image_scale = float(request.form['image_scale'])
    seed = int(request.form['seed'])

    model_type = 'hd'
    category = 0  # HD pipeline handles upper-body garments only (0:upperbody; 1:lowerbody; 2:dress)

    with torch.no_grad():
        # Move the HD model weights onto the inference device.
        openpose_model_hd.preprocessor.body_estimation.model.to(device)
        ootd_model_hd.pipe.to(device)
        ootd_model_hd.image_encoder.to(device)
        ootd_model_hd.text_encoder.to(device)

        garm_img = Image.open(garm_img).resize((768, 1024))
        vton_img = Image.open(vton_img).resize((768, 1024))

        # Pose and human-parsing run at the networks' native 384x512 input size.
        keypoints = openpose_model_hd(vton_img.resize((384, 512)))
        model_parse, _ = parsing_model_hd(vton_img.resize((384, 512)))

        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category], model_parse, keypoints)
        # NEAREST keeps the mask binary when upscaling back to full resolution.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)

        # Grey out the garment region of the person image for inpainting.
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_hd(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

    # Serialize each result as a base64 JPEG for the JSON response.
    base64_images = []
    for img in images:
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        base64_images.append(base64.b64encode(buffered.getvalue()).decode('utf-8'))

    return jsonify(images=base64_images)
@app.route('/process_dc', methods=['POST'])
def process_dc():
    """Run the DC (full-body) try-on pipeline on an uploaded image pair.

    Expects multipart/form-data with:
        files: 'vton_img' (person photo), 'garm_img' (garment photo)
        form:  'category' ('Upper-body' | 'Lower-body' | anything else -> dress),
               'n_samples' (int), 'n_steps' (int),
               'image_scale' (float), 'seed' (int)

    Returns:
        JSON ``{"images": [<base64-encoded JPEG>, ...]}`` with one entry
        per generated sample.

    Raises:
        KeyError / ValueError (-> HTTP 400/500 via Flask) when a field is
        missing or not parseable.
    """
    files = request.files
    vton_img = files['vton_img']
    garm_img = files['garm_img']
    category = request.form['category']
    n_samples = int(request.form['n_samples'])
    n_steps = int(request.form['n_steps'])
    image_scale = float(request.form['image_scale'])
    seed = int(request.form['seed'])

    model_type = 'dc'
    # Map the human-readable category to the 0/1/2 index used by the
    # category lookup tables; any unrecognized value falls through to dress.
    if category == 'Upper-body':
        category = 0
    elif category == 'Lower-body':
        category = 1
    else:
        category = 2

    with torch.no_grad():
        # Move the DC model weights onto the inference device.
        openpose_model_dc.preprocessor.body_estimation.model.to(device)
        ootd_model_dc.pipe.to(device)
        ootd_model_dc.image_encoder.to(device)
        ootd_model_dc.text_encoder.to(device)

        garm_img = Image.open(garm_img).resize((768, 1024))
        vton_img = Image.open(vton_img).resize((768, 1024))

        # Pose and human-parsing run at the networks' native 384x512 input size.
        keypoints = openpose_model_dc(vton_img.resize((384, 512)))
        model_parse, _ = parsing_model_dc(vton_img.resize((384, 512)))

        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category], model_parse, keypoints)
        # NEAREST keeps the mask binary when upscaling back to full resolution.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)

        # Grey out the garment region of the person image for inpainting.
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_dc(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

    # Serialize each result as a base64 JPEG for the JSON response.
    base64_images = []
    for img in images:
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        base64_images.append(base64.b64encode(buffered.getvalue()).decode('utf-8'))

    return jsonify(images=base64_images)
if __name__ == "__main__":
    # 7860 is the conventional port for Hugging Face Spaces apps.
    app.run(host="0.0.0.0", port=7860)