from flask import Flask, request, jsonify
import torch
from PIL import Image, ImageOps
import base64
from io import BytesIO
from utils_ootd import get_mask_location
from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_hd import OOTDiffusionHD
from ootd.inference_ootd_dc import OOTDiffusionDC
import spaces

app = Flask(__name__)

# Load the models once at application start-up
openpose_model_hd = OpenPose(0)  # HD pipeline (GPU index 0)
parsing_model_hd = Parsing(0)
ootd_model_hd = OOTDiffusionHD(0)

openpose_model_dc = OpenPose(1)  # DC pipeline (GPU index 1)
parsing_model_dc = Parsing(1)
ootd_model_dc = OOTDiffusionDC(1)
# Select the compute device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

category_dict = ['upperbody', 'lowerbody', 'dress']
category_dict_utils = ['upper_body', 'lower_body', 'dresses']

@app.route("/process_hd", methods=["POST"])
@spaces.GPU  # keep @app.route outermost so Flask registers the GPU-wrapped handler
def process_hd():
    data = request.files
    vton_img = data['vton_img']
    garm_img = data['garm_img']
    n_samples = int(request.form['n_samples'])
    n_steps = int(request.form['n_steps'])
    image_scale = float(request.form['image_scale'])
    seed = int(request.form['seed'])
    model_type = 'hd'
    category = 0  # 0: upper body; 1: lower body; 2: dress
    # Move the models to GPU memory
    with torch.no_grad():
        openpose_model_hd.preprocessor.body_estimation.model.to(device)
        ootd_model_hd.pipe.to(device)
        ootd_model_hd.image_encoder.to(device)
        ootd_model_hd.text_encoder.to(device)

        garm_img = Image.open(garm_img).resize((768, 1024))
        vton_img = Image.open(vton_img).resize((768, 1024))
        keypoints = openpose_model_hd(vton_img.resize((384, 512)))
        model_parse, _ = parsing_model_hd(vton_img.resize((384, 512)))

        mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_hd(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )
    # Encode the generated images as base64 JPEG strings for the JSON response
    base64_images = []
    for img in images:
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
        base64_images.append(img_str)
    return jsonify(images=base64_images)

@app.route("/process_dc", methods=["POST"])
@spaces.GPU  # keep @app.route outermost so Flask registers the GPU-wrapped handler
def process_dc():
    data = request.files
    vton_img = data['vton_img']
    garm_img = data['garm_img']
    category = request.form['category']
    n_samples = int(request.form['n_samples'])
    n_steps = int(request.form['n_steps'])
    image_scale = float(request.form['image_scale'])
    seed = int(request.form['seed'])
    model_type = 'dc'

    # Map the category label from the form to an index into category_dict
    if category == 'Upper-body':
        category = 0
    elif category == 'Lower-body':
        category = 1
    else:
        category = 2
    # Move the models to GPU memory
    with torch.no_grad():
        openpose_model_dc.preprocessor.body_estimation.model.to(device)
        ootd_model_dc.pipe.to(device)
        ootd_model_dc.image_encoder.to(device)
        ootd_model_dc.text_encoder.to(device)

        garm_img = Image.open(garm_img).resize((768, 1024))
        vton_img = Image.open(vton_img).resize((768, 1024))
        keypoints = openpose_model_dc(vton_img.resize((384, 512)))
        model_parse, _ = parsing_model_dc(vton_img.resize((384, 512)))

        mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_dc(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )
    base64_images = []
    for img in images:
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
        base64_images.append(img_str)
    return jsonify(images=base64_images)
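
# For reference, a minimal client sketch (illustrative only). It assumes the
# server is reachable at http://localhost:7860; the file names and parameter
# values below are placeholders, not part of the API.
#
#   import base64
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/process_hd",
#       files={"vton_img": open("person.jpg", "rb"),
#              "garm_img": open("shirt.jpg", "rb")},
#       data={"n_samples": 1, "n_steps": 20, "image_scale": 2.0, "seed": -1},
#   )
#   for i, b64 in enumerate(resp.json()["images"]):
#       with open(f"result_{i}.jpg", "wb") as f:
#           f.write(base64.b64decode(b64))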

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)