# OOTDiffusion / run/gradio_ootd.py
# Last updated by Saad0KH in commit 4ef381b ("Update run/gradio_ootd.py"), 4.98 kB.
from flask import Flask, request, jsonify , send_file
import torch
from PIL import Image, ImageOps
import base64
from io import BytesIO
from utils_ootd import get_mask_location
from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_hd import OOTDiffusionHD
from ootd.inference_ootd_dc import OOTDiffusionDC
import spaces
# Flask application exposing the /process_hd and /process_dc try-on endpoints.
app = Flask(__name__)
# Load the models once, at application start-up, rather than per request.
# The integer constructor argument is presumably a GPU index: the "hd" models
# are built with 0 and the "dc" models with 1 — TODO confirm against the
# OpenPose / Parsing / OOTDiffusion* constructors.
openpose_model_hd = OpenPose(0)
parsing_model_hd = Parsing(0)
ootd_model_hd = OOTDiffusionHD(0)
openpose_model_dc = OpenPose(1)
parsing_model_dc = Parsing(1)
ootd_model_dc = OOTDiffusionDC(1)
# Target device for the per-request `.to(device)` calls: CUDA when available,
# otherwise CPU.
# NOTE(review): 'cuda' means the current CUDA device, which may not be the
# index-1 device the dc models were constructed with — confirm intent.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Category names indexed by the integer `category` (0, 1, 2) used in the
# handlers; passed as `category=` to the OOTDiffusion models.
category_dict = ['upperbody', 'lowerbody', 'dress']
# The same three categories, spelled as passed to get_mask_location.
category_dict_utils = ['upper_body', 'lower_body', 'dresses']
# Route registration must be the OUTERMOST decorator: Flask registers whatever
# callable it receives, so with @spaces.GPU on top the GPU wrapper was never
# the function actually served. An explicit endpoint name avoids relying on
# the wrapper preserving __name__.
@app.route("/process_hd", methods=["POST"], endpoint="process_hd")
@spaces.GPU
def process_hd():
    """Run HD virtual try-on on an uploaded person/garment image pair.

    Expects a multipart POST with:
      * files ``vton_img`` (person image) and ``garm_img`` (garment image);
      * form fields ``n_samples`` (int), ``n_steps`` (int),
        ``image_scale`` (float) and ``seed`` (int).

    The garment category is hard-coded to 0 (upper body) for this endpoint.

    Returns:
        JSON ``{"images": [...]}`` holding one base64-encoded JPEG per image
        returned by ``ootd_model_hd``, or HTTP 400 with ``{"error": ...}``
        when a numeric form field cannot be parsed.
    """
    vton_file = request.files['vton_img']
    garm_file = request.files['garm_img']
    try:
        n_samples = int(request.form['n_samples'])
        n_steps = int(request.form['n_steps'])
        image_scale = float(request.form['image_scale'])
        seed = int(request.form['seed'])
    except ValueError as err:
        # A malformed number is a client error, not a server crash (500).
        return jsonify(error=f"invalid numeric parameter: {err}"), 400

    model_type = 'hd'
    category = 0  # 0: upperbody; 1: lowerbody; 2: dress

    with torch.no_grad():
        # Move the models onto the target device before inference.
        openpose_model_hd.preprocessor.body_estimation.model.to(device)
        ootd_model_hd.pipe.to(device)
        ootd_model_hd.image_encoder.to(device)
        ootd_model_hd.text_encoder.to(device)

        # Force RGB: uploads may be RGBA / palette / grayscale. The diffusion
        # pipeline presumably expects 3-channel input, and JPEG cannot encode
        # alpha.
        with Image.open(garm_file) as img:
            garm_img = img.convert('RGB').resize((768, 1024))
        with Image.open(vton_file) as img:
            vton_img = img.convert('RGB').resize((768, 1024))

        # Pose estimation and human parsing run at half resolution; the
        # resized copy is computed once and shared.
        vton_small = vton_img.resize((384, 512))
        keypoints = openpose_model_hd(vton_small)
        model_parse, _ = parsing_model_hd(vton_small)
        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category], model_parse, keypoints
        )
        # Scale masks back to working resolution; NEAREST keeps them binary-crisp.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
        # Where `mask` is set, take mask_gray; elsewhere keep the person image.
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_hd(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

    base64_images = []
    for img in images:
        buffered = BytesIO()
        # JPEG only supports modes like RGB/L; convert defensively.
        img.convert('RGB').save(buffered, format="JPEG")
        base64_images.append(base64.b64encode(buffered.getvalue()).decode('utf-8'))
    return jsonify(images=base64_images)
# Route registration must be the OUTERMOST decorator: Flask registers whatever
# callable it receives, so with @spaces.GPU on top the GPU wrapper was never
# the function actually served. An explicit endpoint name avoids relying on
# the wrapper preserving __name__.
@app.route("/process_dc", methods=["POST"], endpoint="process_dc")
@spaces.GPU
def process_dc():
    """Run DC (full-body) virtual try-on on an uploaded person/garment pair.

    Expects a multipart POST with:
      * files ``vton_img`` (person image) and ``garm_img`` (garment image);
      * form fields ``category`` ("Upper-body", "Lower-body" or anything else
        for dress), ``n_samples`` (int), ``n_steps`` (int),
        ``image_scale`` (float) and ``seed`` (int).

    Returns:
        JSON ``{"images": [...]}`` holding one base64-encoded JPEG per image
        returned by ``ootd_model_dc``, or HTTP 400 with ``{"error": ...}``
        when a numeric form field cannot be parsed.
    """
    vton_file = request.files['vton_img']
    garm_file = request.files['garm_img']
    # Map the client label to an index into category_dict / category_dict_utils.
    # Matching ignores case and '_' vs '-', so e.g. "upper-body" no longer
    # silently falls through to dress. Any other label still maps to 2 (dress),
    # as before, so existing clients keep working.
    label = request.form['category'].strip().lower().replace('_', '-')
    category = {
        'upper-body': 0,
        'upperbody': 0,
        'lower-body': 1,
        'lowerbody': 1,
    }.get(label, 2)
    try:
        n_samples = int(request.form['n_samples'])
        n_steps = int(request.form['n_steps'])
        image_scale = float(request.form['image_scale'])
        seed = int(request.form['seed'])
    except ValueError as err:
        # A malformed number is a client error, not a server crash (500).
        return jsonify(error=f"invalid numeric parameter: {err}"), 400

    model_type = 'dc'

    with torch.no_grad():
        # Move the models onto the target device before inference.
        openpose_model_dc.preprocessor.body_estimation.model.to(device)
        ootd_model_dc.pipe.to(device)
        ootd_model_dc.image_encoder.to(device)
        ootd_model_dc.text_encoder.to(device)

        # Force RGB: uploads may be RGBA / palette / grayscale. The diffusion
        # pipeline presumably expects 3-channel input, and JPEG cannot encode
        # alpha.
        with Image.open(garm_file) as img:
            garm_img = img.convert('RGB').resize((768, 1024))
        with Image.open(vton_file) as img:
            vton_img = img.convert('RGB').resize((768, 1024))

        # Pose estimation and human parsing run at half resolution; the
        # resized copy is computed once and shared.
        vton_small = vton_img.resize((384, 512))
        keypoints = openpose_model_dc(vton_small)
        model_parse, _ = parsing_model_dc(vton_small)
        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category], model_parse, keypoints
        )
        # Scale masks back to working resolution; NEAREST keeps them binary-crisp.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
        # Where `mask` is set, take mask_gray; elsewhere keep the person image.
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_dc(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

    base64_images = []
    for img in images:
        buffered = BytesIO()
        # JPEG only supports modes like RGB/L; convert defensively.
        img.convert('RGB').save(buffered, format="JPEG")
        base64_images.append(base64.b64encode(buffered.getvalue()).decode('utf-8'))
    return jsonify(images=base64_images)
if __name__ == "__main__":
    # Script entry point: serve on every interface, port 7860.
    app.run(port=7860, host="0.0.0.0")