File size: 9,338 Bytes
4debf12
 
 
a6d790c
4debf12
a6d790c
4debf12
a6d790c
 
4debf12
 
a6d790c
4debf12
 
 
 
 
 
 
 
 
a6d790c
4debf12
 
a6d790c
4debf12
a6d790c
 
4debf12
 
a6d790c
 
4debf12
 
 
 
 
 
 
 
a6d790c
4debf12
 
a6d790c
 
4debf12
a6d790c
4debf12
 
 
a6d790c
 
656d6c8
 
 
a6d790c
656d6c8
 
 
98c3bc5
a6d790c
98c3bc5
 
 
 
 
656d6c8
 
 
98c3bc5
656d6c8
 
98c3bc5
4debf12
 
a6d790c
4debf12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6d790c
 
 
4debf12
 
 
98c3bc5
4debf12
98c3bc5
4debf12
1e9ddc5
6a54251
4debf12
 
4c930e4
98c3bc5
4debf12
 
a6d790c
4debf12
 
a6d790c
 
 
 
 
 
 
 
 
 
 
a8f49b5
4debf12
 
64da8ce
8d49805
a6d790c
4debf12
 
 
 
 
 
 
 
 
dc7ced4
4debf12
42e59f2
98c3bc5
5042117
4debf12
 
 
 
 
a6d790c
4debf12
 
 
a6d790c
4debf12
 
 
b94b459
 
 
a6d790c
4debf12
a6d790c
4debf12
a6d790c
4debf12
a6d790c
 
4debf12
 
 
 
64da8ce
4debf12
a6d790c
 
 
 
 
 
4debf12
 
 
 
 
 
 
b94b459
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
import datetime

import toml
import gradio as gr
from PIL import PngImagePlugin

from pnginfo import read_info_from_image, send_paras
from images_history import img_history_ui
from utils import set_token, generate_novelai_image, image_from_bytes

# The [client] table of config.toml, parsed once when the module is imported.
client_config = toml.load("config.toml")["client"]

# Per-day generation counter: ``today`` is the calendar day (YYYY-MM-DD) that
# ``today_count`` belongs to; get_count() resets both when the date changes.
today = datetime.date.today().strftime("%Y-%m-%d")
today_count = 0

def get_count():
    """Return the user-facing "images generated today" message.

    If the calendar date has moved on since ``today`` was recorded, the
    module-level counter is reset to zero first.
    """
    global today_count, today
    current_day = datetime.date.today().strftime('%Y-%m-%d')
    if current_day != today:
        today, today_count = current_day, 0
    return f"今日已生成图片{today_count}张"

# Upper bound on width * height that the two size sliders enforce together.
_MAX_PIXELS = 1024 * 1024


def _fit_side(fixed, other, max_pixels=_MAX_PIXELS, step=64):
    """Clamp ``other`` so that ``fixed * other`` stays within ``max_pixels``.

    Returns ``other`` unchanged when the product already fits; otherwise the
    largest multiple of ``step`` that is <= ``max_pixels // fixed``.
    """
    if fixed * other <= max_pixels:
        return other
    return (max_pixels // fixed // step) * step


def control_ui():
    """Create the generation controls and wire the handlers local to them.

    Returns:
        ``(gen_btn, paras, others)`` where
        - ``gen_btn`` is the primary generate button;
        - ``paras`` lists the input components in the same order as
          ``generate``'s parameters ``prompt`` ... ``i2i_noise``;
        - ``others`` is ``[save, rand_seed, reuse_seed, reuse_img_vibe,
          reuse_img_i2i]``; only ``rand_seed`` is wired up here, the rest
          are left to the caller.
    """
    prompt = gr.TextArea(elem_id='txt2img_prompt', label="提示词", lines=3)
    quality_tags = gr.TextArea(
        elem_id='txt2img_qua_prompt', label="质量词", lines=1,
        value=client_config['default_quality'],
    )
    neg_prompt = gr.TextArea(
        elem_id='txt2img_neg_prompt', label="负面词", lines=1,
        value=client_config['default_neg'],
    )
    with gr.Row():
        sampler = gr.Dropdown(
            choices=[
                "k_euler", "k_euler_ancestral", "k_dpmpp_2s_ancestral",
                "k_dpmpp_2m", "k_dpmpp_sde", "ddim_v3"
            ],
            value="k_euler",
            label="采样器",
            interactive=True
        )
        scale = gr.Slider(label="CFG Scale", value=5.0, minimum=1, maximum=10, step=0.1)
        steps = gr.Slider(label="步数", value=28, minimum=1, maximum=28, step=1)
    with gr.Row():
        # The dice button resets the seed to -1, presumably "random" for the
        # backend — TODO confirm against generate_novelai_image.
        seed = gr.Number(label="种子", value=-1, step=1, maximum=2**32-1, minimum=-1, scale=3)
        rand_seed = gr.Button('🎲️', scale=1)
        reuse_seed = gr.Button('♻️', scale=1)
    with gr.Row():
        width = gr.Slider(label="宽度", value=1024, minimum=64, maximum=2048, step=64)
        height = gr.Slider(label="高度", value=1024, minimum=64, maximum=2048, step=64)
    with gr.Row():
        with gr.Column():
            with gr.Accordion('风格迁移', open=False):
                ref_image = gr.Image(label="上传图片", value=None, sources=["upload", "clipboard", "webcam"], interactive=True, type="pil")
                info_extract = gr.Slider(label='参考信息提取', value=1, minimum=0, maximum=1, step=0.1)
                ref_str = gr.Slider(label='参考强度', value=0.6, minimum=0, maximum=1, step=0.1)
                reuse_img_vibe = gr.Button(value='使用上一次生成的图片')
            with gr.Accordion('图生图', open=False):
                i2i_image = gr.Image(label="上传图片", value=None, sources=["upload", "clipboard", "webcam"], interactive=True, type="pil")
                i2i_str = gr.Slider(label='去噪强度', value=0.7, minimum=0, maximum=0.99, step=0.01)
                i2i_noise = gr.Slider(label='噪声', value=0, minimum=0, maximum=1, step=0.1)
                reuse_img_i2i = gr.Button(value='使用上一次生成的图片')
            # TODO: an inpainting accordion (enable/overlay checkboxes, an
            # ImageEditor mask input and a strength slider) used to be sketched
            # here as dead code; it was disabled and never wired into the
            # returned component lists. Recover it from version control if needed.
    with gr.Row():
        with gr.Column():
            with gr.Accordion('高级选项', open=False):
                scheduler = gr.Dropdown(
                    choices=[
                        "native", "karras", "exponential", "polyexponential"
                    ],
                    value="native",
                    label="Scheduler",
                    interactive=True
                )
                with gr.Row():
                    smea = gr.Checkbox(False, label="SMEA")
                    dyn = gr.Checkbox(False, label="SMEA DYN")
                with gr.Row():
                    dyn_threshold = gr.Checkbox(False, label="Dynamic Thresholding")
                    cfg_rescale = gr.Slider(0, 1, 0, step=0.01, label="CFG rescale")
        with gr.Column():
            # Display-only counter, refreshed from get_count every 10 seconds.
            gr.Textbox(value=get_count, label='使用统计', every=10, interactive=False)
            save = gr.Checkbox(value=True, label='云端保存图片')
    gen_btn = gr.Button(value="生成", variant="primary")
    rand_seed.click(fn=lambda: -1, inputs=None, outputs=seed)
    # Keep width * height within the pixel budget: when one side grows past
    # it, shrink the other side to the largest fitting multiple of 64.
    width.change(lambda w, h: _fit_side(w, h), [width, height], height)
    height.change(lambda w, h: _fit_side(h, w), [width, height], width)
    return gen_btn, [prompt, quality_tags, neg_prompt, seed, scale, width, height, steps, sampler, scheduler, smea, dyn, dyn_threshold, cfg_rescale, ref_image, info_extract, ref_str, i2i_image, i2i_str, i2i_noise], [save, rand_seed, reuse_seed, reuse_img_vibe, reuse_img_i2i]

def _save_image(img, day, index, seed):
    """Write ``img`` (with its PNG text chunks) to ``<save_path>/<day>/``.

    The file is named ``<index zero-padded to 5 digits>-<seed>.png``.
    """
    today_path = os.path.join(client_config['save_path'], day)
    # Octal 0o777, not decimal 777 (which is 0o1411: sticky bit plus r/x-only
    # bits). exist_ok=True already tolerates an existing directory, so no
    # separate exists() check (and its race) is needed. The process umask
    # still applies on top of this mode.
    os.makedirs(today_path, mode=0o777, exist_ok=True)
    filename = f"{index:05d}-{seed}.png"
    # Copy the image's metadata into PNG text chunks so it survives the save.
    pnginfo_data = PngImagePlugin.PngInfo()
    for k, v in img.info.items():
        pnginfo_data.add_text(k, str(v))
    img.save(os.path.join(today_path, filename), pnginfo=pnginfo_data)


def generate(prompt, quality_tags, neg_prompt, seed, scale, width, height, steps, sampler, scheduler, smea, dyn, dyn_threshold, cfg_rescale, ref_image, info_extract, ref_str, i2i_image, i2i_str, i2i_noise, save):
    """Request one image from NovelAI and optionally save it to disk.

    The parameters mirror, in order, the input components returned by
    control_ui(), followed by ``save``. When ``save`` is truthy the image is
    written under ``client_config['save_path']/<YYYY-MM-DD>/`` as
    ``<5-digit daily counter>-<seed>.png``.

    Returns:
        ``(image, payload)``: the decoded image and the payload reported by
        generate_novelai_image. If the request did not yield bytes, an empty
        ``gr.Image`` is returned alongside that payload instead.
    """
    global today_count, today
    set_token(os.environ.get('token'))
    # Join only non-blank parts so an empty prompt or quality field does not
    # leave a dangling ", " in the request.
    full_prompt = ", ".join(part for part in (prompt, quality_tags) if part and part.strip())
    img_data, payload = generate_novelai_image(
        full_prompt, neg_prompt, seed, scale,
        width, height, steps, sampler, scheduler,
        smea, dyn, dyn_threshold, cfg_rescale, ref_image, info_extract, ref_str,
        i2i_image, i2i_str, i2i_noise
    )
    if not isinstance(img_data, bytes):
        return gr.Image(value=None), payload
    # Roll the daily counter over here as well: get_count() only resets it
    # when the UI polls, so the first image after midnight would otherwise
    # continue yesterday's numbering inside the new day's folder.
    now = datetime.date.today().strftime('%Y-%m-%d')
    if now != today:
        today = now
        today_count = 0
    today_count += 1
    img = image_from_bytes(img_data)
    if save:
        _save_image(img, today, today_count, payload['parameters']['seed'])
    return img, payload

def preview_ui():
    """Create the result pane and return ``(image, info)``.

    ``image`` is a read-only PIL image viewer; ``info`` is a JSON box that
    starts out as an empty dict.
    """
    preview_css = '#preview_image { height: 100%;}'
    with gr.Blocks(css=preview_css):
        widgets = (
            gr.Image(elem_id='preview_image', interactive=False, type='pil'),
            gr.JSON(value={}, label="生成信息"),
        )
    return widgets

def main_ui():
    """Build the text-to-image tab: controls on the left, preview on the right.

    Returns:
        ``(page, paras[:14])``: the Blocks and the first 14 input components
        (prompt through cfg_rescale), which are the targets other tabs use
        when sending parameters here.
    """
    with gr.Blocks() as page:
        with gr.Row(variant="panel"):
            with gr.Column():
                gen_btn, paras, others = control_ui()
            with gr.Column():
                image, info = preview_ui()
    # Name the components instead of indexing into the lists returned above.
    save, _rand_seed, reuse_seed, reuse_img_vibe, reuse_img_i2i = others
    seed, ref_image, i2i_image = paras[3], paras[14], paras[17]

    def reuse_last_seed(current, last_info):
        """Return the seed recorded in ``last_info``, else keep ``current``."""
        # ``last_info`` is the payload shown in the JSON box. It starts as {},
        # may be cleared to None, and after a failed request its structure
        # is not guaranteed, so fall back instead of raising.
        if not isinstance(last_info, dict):
            return current
        params = last_info.get('parameters')
        if not isinstance(params, dict):
            return current
        return params.get('seed', current)

    gen_btn.click(generate, paras + [save], [image, info])
    reuse_seed.click(reuse_last_seed, inputs=[seed, info], outputs=seed)
    # Copy the last generated image into the style-transfer / img2img inputs.
    reuse_img_vibe.click(lambda i: i, inputs=image, outputs=ref_image)
    reuse_img_i2i.click(lambda i: i, inputs=image, outputs=i2i_image)
    return page, paras[:14]

def util_ui():
    """Build the PNG-info tab.

    Returns ``(page, png2main, items, info, image)``:
    - ``page``: the Blocks;
    - ``png2main``: the "send to txt2img" button;
    - ``items``: a hidden JSON component;
    - ``info``: the HTML info panel;
    - ``image``: the upload image.
    """
    with gr.Blocks(analytics_enabled=False) as page:
        with gr.Row(equal_height=False):
            with gr.Column(variant='panel'):
                upload = gr.Image(label="上传图片", sources=["upload"], interactive=True, type="pil")
            with gr.Column(variant='panel'):
                html_info = gr.HTML()
                parsed_items = gr.JSON(value=None, visible=False)
                send_btn = gr.Button('参数发送到文生图')
    return page, send_btn, parsed_items, html_info, upload

def ui():
    """Assemble the full client: txt2img, PNG-info and cloud-gallery tabs.

    Every file in ./tagcomplete/javascript (sorted by name) is injected into
    the page <head> as a <script> tag.
    """
    script_names = sorted(os.listdir('./tagcomplete/javascript'))
    head = ''.join(
        f'<script type="text/javascript" src="file=tagcomplete/javascript/{name}"></script>\n'
        for name in script_names
    )
    # Client-side hop back to the txt2img tab: clicks the element with id
    # 'client_ui_main-button' (presumably that tab's header button).
    switch_to_main = "(x) => { if (x !== null) document.getElementById('client_ui_main-button').click(); return null; }"
    with gr.Blocks(title="NAI Client", head=head) as website:
        with gr.Tabs():
            with gr.TabItem("文生图", elem_id="client_ui_main"):
                _, paras = main_ui()
            with gr.TabItem("图片信息读取"):
                _, png2main, png_items, info, image = util_ui()
            with gr.TabItem("云端图片浏览") as tab:
                gal2main, gal_items = img_history_ui(tab)
            # Each "send" button first copies its parameters into the txt2img
            # controls, then runs the tab-switch JS (same order as before).
            senders = (
                (png2main, png_items, image),
                (gal2main, gal_items, gal_items),
            )
            for button, source, trigger in senders:
                button.click(fn=send_paras,
                             inputs=[source] + paras,
                             outputs=paras)
                button.click(fn=None,
                             js=switch_to_main,
                             inputs=trigger)
            image.change(read_info_from_image, inputs=image, outputs=[info, png_items])
    return website


if __name__ == '__main__':
    account = os.environ.get('account')
    password = os.environ.get('password')
    # Without these, auth would be (None, None): a credential pair no login
    # form submission can match, so the UI would start but be unreachable.
    # Fail fast with a clear message instead.
    if not account or not password:
        raise SystemExit("Set the 'account' and 'password' environment variables before launching.")
    website = ui()
    website.queue()
    website.launch(auth=(account, password), allowed_paths=['tagcomplete'])