File size: 8,579 Bytes
1457e31
 
f762231
 
1457e31
 
 
 
 
 
 
 
 
 
 
 
 
f30af9d
 
3d0272c
 
 
 
 
 
 
 
 
 
 
 
 
 
f30af9d
 
 
 
 
 
 
 
 
 
 
 
e68a153
 
 
 
 
 
 
eebc5e4
 
 
1457e31
f30af9d
 
e68a153
3d0272c
 
1457e31
 
 
 
 
 
 
 
f762231
 
 
 
 
1457e31
f762231
b4297d4
e68a153
1457e31
 
 
 
3d0272c
f762231
b4297d4
f30af9d
 
1457e31
 
 
3d0272c
1457e31
 
b4297d4
3d0272c
1457e31
 
e68a153
1457e31
f762231
1457e31
 
 
 
 
 
a67de12
1457e31
 
 
 
 
 
 
 
f762231
1457e31
 
 
f762231
1457e31
 
 
 
f762231
1457e31
 
 
d417cb0
 
 
 
 
 
1457e31
 
 
 
 
 
 
 
933eaf9
f30af9d
933eaf9
 
1457e31
 
3d0272c
 
 
 
 
e68a153
 
3d0272c
 
f30af9d
 
933eaf9
f30af9d
1457e31
 
 
f762231
1457e31
f762231
 
1457e31
 
f762231
 
1457e31
f762231
 
 
 
1457e31
 
 
 
 
 
 
 
 
f762231
e68a153
d417cb0
 
f762231
f30af9d
 
f762231
d417cb0
 
f762231
d417cb0
 
1457e31
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206

from utils import *
from config import *



# Example galleries for the configured task type; evaluated once at import time
# in the same order as before (templates, user photos, showcase rows).
temp_examples, user_examples, showcase_examples = (
    get_temps_examples(taskType),
    get_user_examples(taskType),
    get_showcase_examples(taskType),
)

# Shared recorder used by the callbacks to check and store per-user usage.
user_recorder = UserRecorder()

# Custom CSS handed to gr.Blocks: sets the app container width to 85%.
css = "\n.gradio-container {width: 85% !important}\n"


def onClick(temp_image, user_image, caption_text, token_text,
        param4_text, param5_text, request: gr.Request):
    """Validate the inputs, submit a generation task and stream its progress.

    This is a generator callback: every ``yield`` is a ``(result_image,
    info_message)`` pair matching the two output components wired to the
    submit button. Intermediate yields carry ``None`` as the image and a
    status string; the final successful yield carries the task output.

    Args:
        temp_image: Selected template image (file path) or ``None``.
        user_image: User input. For task type ``'2'`` this is the image-editor
            value (a dict with ``'background'`` and ``'layers'``); otherwise a
            numpy image or ``None``.
        caption_text: Caption / prompt text.
        token_text: Optional API key forwarded to the usage recorder.
        param4_text: Extra parameter; parsed as a float for task type ``'7'``.
        param5_text: Extra parameter; parsed as a float for task type ``'7'``.
        request: The Gradio request, used to derive the client IP.

    Yields:
        ``(image_or_None, message)`` tuples.

    Raises:
        Exception: Any error raised while talking to the backend is re-raised
            after a failure message has been yielded.
    """
    user_mask = None
    if taskType == '2':
        # The editor delivers a dict; guard against an empty editor so the user
        # gets a readable message instead of an unhandled exception.
        background = user_image.get('background') if user_image is not None else None
        if background is None:
            yield None, "please upload a photo!!!"
            return
        layers = user_image.get('layers') or []
        user_image = np.array(Image.fromarray(background).convert('RGB'))
        if user_image.sum() == 0:
            yield None, "please upload a photo!!!"
            return
        if not layers:
            yield None, "please draw a area!!!"
            return
        # Any painted pixel of the first layer becomes 255 in the binary mask.
        user_mask = (layers[0].sum(2) > 0).astype(np.uint8) * 255
        if user_mask.sum() == 0:
            yield None, "please draw a area!!!"
            return

    if taskType == '7':
        try:
            param4_text, param5_text = str(float(param4_text)), str(float(param5_text))
        except ValueError:
            yield None, "Invalid width/height: Please enter a valid float"
            return
        if len(caption_text) == 0:
            yield None, "Please enter English caption text !!! "
            return
    else:
        param4_text, param5_text = '', ''

    if taskType in ['8', '9']:
        if len(caption_text) == 0:
            yield None, "Please enter caption !!! "
            return
        if taskType == '9':
            # Task type '9' always runs with a fixed template and style preset.
            temp_image, param4_text = '1536x1024.jpg', 'realistic_image'

    if temp_image is None:
        yield None, "please choose a template!!!"
        return
    if user_image is None and taskType not in ['8', '9']:
        yield None, "please upload a photo!!!"
        return

    try:
        # NOTE(review): the raw X-Forwarded-For header is used verbatim as the
        # client key, exactly as onLoad does; keep both in sync if this changes.
        client_ip = request.client.host
        x_forwarded_for = dict(request.headers).get('x-forwarded-for')
        if x_forwarded_for:
            client_ip = x_forwarded_for
        if not check_region_warp(client_ip):
            # A generator's return value never reaches the UI, so the message
            # has to be yielded before stopping.
            yield None, "Failed !!! Our server is under maintenance, please try again later"
            return

        # Check whether this user may still run a task.
        check_res, info = user_recorder.check_record(ip=client_ip, token=token_text)
        if not check_res:
            yield None, info
            return

        # Upload the user's photo (and mask, if any).
        yield None, "start to upload, please wait..."
        upload_url, uploadm_url = upload_user_img_mask(client_ip, user_image, user_mask, taskType)
        if len(upload_url) == 0:
            yield None, "fail to upload"
            return

        # Publish the task.
        yield None, "start to public, please wait..."
        taskId = publicSelfitTask(upload_url, uploadm_url, temp_image,
            caption_text, param4_text, param5_text)
        if not taskId:
            yield None, "fail to public task..."
            return

        # Poll the task status: up to max_try queries, wait_s seconds apart.
        max_try = 30
        wait_s = 3
        yield None, "start to process, please wait..."
        for i in range(max_try):
            time.sleep(wait_s)
            taskStatus = getTaskRes(taskId, taskType)
            if taskStatus is None:
                continue
            user_recorder.save_record(taskStatus, ip=client_ip, token=token_text)

            status = taskStatus['status']
            if status in ['FAILED', 'CANCELLED', 'TIMED_OUT']:
                yield None, f"task failed, query {i}, status {status}"
                return
            elif status in ['IN_QUEUE', 'IN_PROGRESS']:
                yield None, f"task is on processing, query {i}, status {status}, please do not exit !!!"
            elif status == 'COMPLETED':
                out = taskStatus['output']['job_results']['output1']
                yield out, "task is COMPLETED"
                return
        # The loop ended without a terminal status.
        yield None, "fail to query task.."
    except Exception as e:
        print(e)
        # Surface the failure in the info panel, then let the error propagate
        # as before so the framework still reports and logs it.
        yield None, "fail to create task"
        raise

def onLoad(token_text, request: gr.Request):
    """Return the token plus history data for the current visitor.

    Used both on page load and by the "Refresh History" button. The returned
    tuple is mapped positionally onto the output components wired to it:
    the token textbox first, then the recorder's history items, then its
    message, then the submit-button label with the remaining attempts.

    Args:
        token_text: Current value of the API-key textbox.
        request: The Gradio request; supplies the client IP, the headers and
            the URL query parameters.

    Returns:
        ``(token_text, *history_items, message, submit_label)``.
    """
    # A ``token`` query parameter overrides the textbox value. Resolve it
    # before the lookup so the history and remaining attempts are fetched
    # for the same token that is written back to the textbox.
    url_params = dict(request.query_params)
    if 'token' in url_params:
        token_text = url_params['token']

    # NOTE(review): the raw X-Forwarded-For header is used verbatim as the
    # client key, exactly as onClick does; keep both in sync if this changes.
    client_ip = request.client.host
    x_forwarded_for = dict(request.headers).get('x-forwarded-for')
    if x_forwarded_for:
        client_ip = x_forwarded_for

    his_datas, total_n, msg = user_recorder.get_record(ip=client_ip, token=token_text)
    left_n = max(0, LimitTask - total_n)

    # Build a fresh tuple rather than appending to the recorder's list, so a
    # list the recorder may still hold is never modified here.
    return (token_text, *his_datas, msg, f"Submit ({left_n} attempts left)")

# Build the Gradio UI: a three-column input/output row, a showcase gallery,
# a history tab, and the event wiring for the callbacks defined above.
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # Left column: template image, pre-filled with the first template example.
        with gr.Column():
            with gr.Column():
                temp_image = gr.Image(sources='clipboard', type="filepath", label=TempLabel, 
                    value=temp_examples[0][0], visible=TempVisible, interactive=TempInter)
                temp_example = gr.Examples(inputs=[temp_image], examples_per_page=9, 
                    examples=temp_examples, visible=TempVisible)
        # Middle column: user input; the widget depends on the task type.
        with gr.Column():
            with gr.Column():
                if taskType=='2':
                    # Upload-only image editor with a fixed red brush; onClick
                    # reads its 'background' and first layer to build a mask.
                    brush = gr.Brush(colors=['#FF0000'], color_mode='fixed')
                    user_image = gr.ImageEditor(value=None, type="numpy", 
                        eraser=False, brush=brush ,layers=False, sources=['upload',],
                        transforms=[], label=UserLabel)
                elif taskType in ['8', '9']:
                    # Non-interactive image showing the first template example;
                    # onClick does not require a user photo for these task types.
                    user_image = gr.Image(value=temp_examples[0][0], type="numpy", label=UserLabel, interactive=False)
                else:
                    user_image = gr.Image(value=None, type="numpy", label=UserLabel)
                # Extra text inputs; labels and visibility come from config names.
                param4_text = gr.Textbox(value="0.5", interactive=True, label=Param4Label, visible=Param4Visible)
                param5_text = gr.Textbox(value="0.5", interactive=True, label=Param5Label, visible=Param5Visible)
                caption_text = gr.Textbox(value="", interactive=True, label=CaptionLabel, visible=CapVisible)

        # Right column: result image, status text, submit button and API key.
        with gr.Column():
            with gr.Column():
                res_image = gr.Image(label="generate image", value=None, type="filepath")
                info_text = gr.Markdown(value="", label='Runtime Info')  # Markdown output component for status messages
                run_button = gr.Button(value="Submit")
                token_text = gr.Textbox(value="", interactive=True, 
                    label='Enter Your Api Key (optional)', visible=is_show_token)
            
    # Showcase rows: each example fills the template, user and result images.
    with gr.Column():
        show_case = gr.Examples(examples=showcase_examples,
            inputs=[temp_image, user_image, res_image, ],label=None)
    # History tab: a refresh button, a Markdown message and three rows of
    # input/output HTML slots that onLoad fills.
    with gr.Tab('history'):
        with gr.Row():  # wrap the button in a Row
            with gr.Column(scale=0.5):  # intended to make the button take half of the Row
                refresh_button = gr.Button("Refresh History", size="small")
        MK02 = gr.Markdown(value="")  # Markdown slot listed in the onLoad outputs

        with gr.Row():
            his_input1 = gr.HTML()
            his_output1 = gr.HTML()
        with gr.Row():
            his_input2 = gr.HTML()
            his_output2 = gr.HTML()
        with gr.Row():
            his_input3 = gr.HTML()
            his_output3 = gr.HTML()
            
    # outputs_onload = [his_input1, his_output1, his_input2, his_output2, his_input3, his_output3,
    #     MK02, run_button]

    # Submit: onClick streams (result image, status text) pairs.
    run_button.click(fn=onClick, inputs=[temp_image, user_image, caption_text, 
        token_text, param4_text, param5_text], outputs=[res_image, info_text])

    # Refresh and initial page load both call onLoad; its return values are
    # mapped positionally onto these outputs (token, history slots, message,
    # submit button) -- presumably six history items; verify against get_record.
    refresh_button.click(fn=onLoad, inputs=[token_text], outputs=[token_text, his_input1, his_output1, his_input2, his_output2, his_input3, his_output3,
        MK02, run_button])

    demo.load(onLoad, inputs=[token_text], outputs=[token_text, his_input1, his_output1, his_input2, his_output2, his_input3, his_output3,
        MK02, run_button])


if __name__ == "__main__":
    # Script entry point: enable the event queue with max_size=50, then start
    # the server bound to 0.0.0.0. queue() returns the Blocks instance, so the
    # two calls can be chained.
    demo.queue(max_size=50).launch(server_name='0.0.0.0')