from flask import Flask, request, jsonify
import torch
from PIL import Image
import base64
from io import BytesIO
from utils_ootd import get_mask_location
from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_hd import OOTDiffusionHD
from ootd.inference_ootd_dc import OOTDiffusionDC
import spaces

app = Flask(__name__)

# Load the models once at application startup (HD models on GPU 0, DC models on GPU 1)
openpose_model_hd = OpenPose(0)
parsing_model_hd = Parsing(0)
ootd_model_hd = OOTDiffusionHD(0)

openpose_model_dc = OpenPose(1)
parsing_model_dc = Parsing(1)
ootd_model_dc = OOTDiffusionDC(1)

# Select the GPU if available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

category_dict = ['upperbody', 'lowerbody', 'dress']
category_dict_utils = ['upper_body', 'lower_body', 'dresses']


# @app.route must be the outermost decorator: Flask registers the function it
# receives, so with @spaces.GPU on top the GPU wrapper would never be invoked.
@app.route("/process_hd", methods=["POST"])
@spaces.GPU
def process_hd():
    # Images arrive as multipart file fields; numeric parameters as form fields
    data = request.files
    vton_img = data['vton_img']
    garm_img = data['garm_img']
    n_samples = int(request.form['n_samples'])
    n_steps = int(request.form['n_steps'])
    image_scale = float(request.form['image_scale'])
    seed = int(request.form['seed'])

    model_type = 'hd'
    category = 0  # 0: upperbody; 1: lowerbody; 2: dress

    # Move the models to GPU memory; inference runs without gradient tracking
    with torch.no_grad():
        openpose_model_hd.preprocessor.body_estimation.model.to(device)
        ootd_model_hd.pipe.to(device)
        ootd_model_hd.image_encoder.to(device)
        ootd_model_hd.text_encoder.to(device)
        
        garm_img = Image.open(garm_img).resize((768, 1024))
        vton_img = Image.open(vton_img).resize((768, 1024))
        keypoints = openpose_model_hd(vton_img.resize((384, 512)))
        model_parse, _ = parsing_model_hd(vton_img.resize((384, 512)))

        mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
        
        # Gray out the try-on region of the person image; the diffusion model inpaints it with the garment
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_hd(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

        # Encode the generated images as base64 JPEG strings for the JSON response
        base64_images = []
        for img in images:
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
            base64_images.append(img_str)

    return jsonify(images=base64_images)
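
# A minimal client sketch for the endpoint above (illustrative, not part of the
# server; the host, file names, and parameter values are assumptions):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/process_hd",
#       files={"vton_img": open("person.jpg", "rb"),
#              "garm_img": open("garment.jpg", "rb")},
#       data={"n_samples": 1, "n_steps": 20, "image_scale": 2.0, "seed": -1},
#   )
#   b64_images = resp.json()["images"]  # list of base64-encoded JPEGs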

@app.route("/process_dc", methods=["POST"])
@spaces.GPU
def process_dc():
    data = request.files
    vton_img = data['vton_img']
    garm_img = data['garm_img']
    category = request.form['category']
    n_samples = int(request.form['n_samples'])
    n_steps = int(request.form['n_steps'])
    image_scale = float(request.form['image_scale'])
    seed = int(request.form['seed'])

    model_type = 'dc'
    # Map the category label from the form to an index into category_dict / category_dict_utils
    if category == 'Upper-body':
        category = 0
    elif category == 'Lower-body':
        category = 1
    else:
        category = 2

    # Move the models to GPU memory
    with torch.no_grad():
        openpose_model_dc.preprocessor.body_estimation.model.to(device)
        ootd_model_dc.pipe.to(device)
        ootd_model_dc.image_encoder.to(device)
        ootd_model_dc.text_encoder.to(device)
        
        garm_img = Image.open(garm_img).resize((768, 1024))
        vton_img = Image.open(vton_img).resize((768, 1024))
        keypoints = openpose_model_dc(vton_img.resize((384, 512)))
        model_parse, _ = parsing_model_dc(vton_img.resize((384, 512)))

        mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
        
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_dc(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )
        base64_images = []
        for img in images:
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
            base64_images.append(img_str)

    return jsonify(images=base64_images)
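
# /process_dc is called like /process_hd but takes an extra "category" form
# field, e.g. (sketch; parameter values are assumptions):
#
#   data={"category": "Upper-body", "n_samples": 1, "n_steps": 20,
#         "image_scale": 2.0, "seed": -1}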

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)