0Tick committed
Commit 5ac3442
Parent: fc12774

Update app.py

Files changed (1):
  app.py +182 -1
app.py CHANGED
@@ -1,3 +1,184 @@
+ import html
+ import os
+ import time
+
+ import torch
+ import transformers
+
  import gradio as gr

- gr.Interface.load("models/0Tick/e621TagAutocomplete").launch()
+ class FormRow(gr.Row):
+     """Same as gr.Row but fits inside gradio forms"""
+     # The webui version also mixes in FormComponent (from modules.ui_components),
+     # which does not exist in a standalone Space; subclassing gr.Row is enough here.
+
+     def get_block_name(self):
+         return "row"
+
+ def wrap_gradio_gpu_call(func, extra_outputs=None):
+     # The webui's wrap_gradio_call helper is likewise unavailable in a
+     # standalone Space, so this wrapper simply passes the call through.
+     def f(*args, **kwargs):
+         return func(*args, **kwargs)
+     return f
+
+
+ class Model:
+     name = None
+     model = None
+     tokenizer = None
+
+ available_models = []
+ current = Model()
+ job_count = 1
+
+ # scripts.basedir() is a webui helper; in a standalone Space the app's own
+ # directory serves as the base.
+ base_dir = os.path.dirname(os.path.abspath(__file__))
+ models_dir = os.path.join(base_dir, "models")
+
+
+ def device():
+     # The webui's devices.cpu is just the CPU device; inference runs on CPU here.
+     return torch.device("cpu")
+
+
+ def list_available_models():
+     # Without the global declaration this assignment would only create a local,
+     # leaving available_models empty.
+     global available_models
+     available_models = ["0Tick/e621TagAutocomplete", "0Tick/danbooruTagAutocomplete"]
+
+
+ def get_model_path(name):
+     # Prefer a local copy under models/ if one exists; otherwise fall back
+     # to the hub id so from_pretrained downloads it.
+     dirname = os.path.join(models_dir, name)
+     if not os.path.isdir(dirname):
+         return name
+
+     return dirname
+
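+ # The two sampling modes are mutually exclusive: "Top K" restricts each step
+ # to the top_k most likely tokens (top_p is disabled with None), while "Top P"
+ # samples from the smallest token set whose cumulative probability exceeds top_p.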
+ def generate_batch(input_ids, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p):
+     top_p = float(top_p) if sampling_mode == 'Top P' else None
+     top_k = int(top_k) if sampling_mode == 'Top K' else None
+
+     outputs = current.model.generate(
+         input_ids,
+         do_sample=True,
+         temperature=max(float(temperature), 1e-6),
+         repetition_penalty=repetition_penalty,
+         length_penalty=length_penalty,
+         top_p=top_p,
+         top_k=top_k,
+         num_beams=int(num_beams),
+         min_length=min_length,
+         max_length=max_length,
+         pad_token_id=current.tokenizer.pad_token_id or current.tokenizer.eos_token_id
+     )
+     texts = current.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+     return texts
+
+
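+ # Selecting "None" drops the loaded model and tokenizer so their memory can
+ # be reclaimed; generate() lazily reloads a checkpoint on the next run.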
+ def model_selection_changed(model_name):
+     if model_name == "None":
+         current.tokenizer = None
+         current.model = None
+         current.name = None
+
+         # devices.torch_gc() is a webui helper; emptying the CUDA cache (when
+         # present) is the closest standalone equivalent.
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
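+ # Click handler: (re)loads the selected checkpoint if the dropdown changed,
+ # tiles the prompt batch_size times, runs batch_count generation rounds, and
+ # returns the results as an HTML table plus an info string.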
+ def generate(id_task, model_name, batch_count, batch_size, text, *args):
+     # id_task is unused; it keeps the signature aligned with the webui code
+     # this app is adapted from.
+     global job_count
+     job_count = batch_count
+
+     if current.name != model_name:
+         current.tokenizer = None
+         current.model = None
+         current.name = None
+
+         if model_name != 'None':
+             path = get_model_path(model_name)
+             current.tokenizer = transformers.AutoTokenizer.from_pretrained(path)
+             current.model = transformers.AutoModelForCausalLM.from_pretrained(path)
+             current.name = model_name
+
+     assert current.model, 'No model available'
+     assert current.tokenizer, 'No tokenizer available'
+
+     current.model.to(device())
+
+     input_ids = current.tokenizer(text, return_tensors="pt").input_ids
+     if input_ids.shape[1] == 0:
+         input_ids = torch.asarray([[current.tokenizer.bos_token_id]], dtype=torch.long)
+     input_ids = input_ids.to(device())
+     input_ids = input_ids.repeat((batch_size, 1))
+
+     markup = '<table><tbody>'
+
+     index = 0
+     for i in range(batch_count):
+         texts = generate_batch(input_ids, *args)
+         for generated_text in texts:
+             index += 1
+             # The copy button copies the generated text via the Clipboard API;
+             # document.getElementById is used since the webui-only gradioApp()
+             # helper does not exist in this Space.
+             markup += f"""
+                 <tr>
+                     <td>
+                         <div class="prompt gr-box gr-text-input">
+                             <p id='promptgen_res_{index}'>{html.escape(generated_text)}</p>
+                         </div>
+                     </td>
+                     <td>
+                         <a class='gr-button gr-button-lg gr-button-secondary' onclick="navigator.clipboard.writeText(document.getElementById('promptgen_res_{index}').textContent)">copy</a>
+                     </td>
+                 </tr>
+             """
+
+     markup += '</tbody></table>'
+
+     return markup, ''
+
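+ # Build the UI: prompt box and Generate button on top; sampling controls in
+ # the left column, generated results (the HTML table with copy buttons) on the right.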
+
+
+ list_available_models()
+
+ with gr.Blocks(analytics_enabled=False) as space:
+     with gr.Row():
+         with gr.Column(scale=80):
+             prompt = gr.Textbox(label="Prompt", elem_id="promptgen_prompt", show_label=False, lines=2, placeholder="Beginning of the prompt (press Ctrl+Enter or Alt+Enter to generate)").style(container=False)
+         with gr.Column(scale=10):
+             submit = gr.Button('Generate', elem_id="promptgen_generate", variant='primary')
+
+     with gr.Row(elem_id="promptgen_main"):
+         with gr.Column(variant="compact"):
+             selected_text = gr.TextArea(elem_id='promptgen_selected_text', visible=False)
+
+             with FormRow():
+                 model_selection = gr.Dropdown(label="Model", elem_id="promptgen_model", value=available_models[0], choices=["None"] + available_models)
+
+             with FormRow():
+                 sampling_mode = gr.Radio(label="Sampling mode", elem_id="promptgen_sampling_mode", value="Top K", choices=["Top K", "Top P"])
+                 top_k = gr.Slider(label="Top K", elem_id="promptgen_top_k", value=12, minimum=1, maximum=50, step=1)
+                 top_p = gr.Slider(label="Top P", elem_id="promptgen_top_p", value=0.15, minimum=0, maximum=1, step=0.001)
+
+             with gr.Row():
+                 num_beams = gr.Slider(label="Number of beams", elem_id="promptgen_num_beams", value=1, minimum=1, maximum=8, step=1)
+                 temperature = gr.Slider(label="Temperature", elem_id="promptgen_temperature", value=1, minimum=0, maximum=4, step=0.01)
+                 repetition_penalty = gr.Slider(label="Repetition penalty", elem_id="promptgen_repetition_penalty", value=1, minimum=1, maximum=4, step=0.01)
+
+             with FormRow():
+                 length_penalty = gr.Slider(label="Length preference", elem_id="promptgen_length_preference", value=1, minimum=-10, maximum=10, step=0.1)
+                 min_length = gr.Slider(label="Min length", elem_id="promptgen_min_length", value=20, minimum=1, maximum=400, step=1)
+                 max_length = gr.Slider(label="Max length", elem_id="promptgen_max_length", value=150, minimum=1, maximum=400, step=1)
+
+             with FormRow():
+                 batch_count = gr.Slider(label="Batch count", elem_id="promptgen_batch_count", value=1, minimum=1, maximum=100, step=1)
+                 batch_size = gr.Slider(label="Batch size", elem_id="promptgen_batch_size", value=10, minimum=1, maximum=100, step=1)
+
+         with gr.Column():
+             with gr.Group(elem_id="promptgen_results_column"):
+                 res = gr.HTML()
+                 res_info = gr.HTML()
+
+     submit.click(
+         # generate is wrapped, not called: invoking generate(...) at wiring time
+         # would run it immediately instead of registering a handler.
+         # model_selection appears twice because the first input feeds the unused
+         # id_task slot; the webui-only _js="submit_promptgen" hook is omitted
+         # since that JS function does not exist in this Space.
+         fn=wrap_gradio_gpu_call(generate, extra_outputs=['']),
+         inputs=[model_selection, model_selection, batch_count, batch_size, prompt, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p],
+         outputs=[res, res_info],
+     )
+
+     model_selection.change(
+         fn=model_selection_changed,
+         inputs=[model_selection],
+         outputs=[],
+     )
+
+
+ space.launch()
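
For reference, the two checkpoints listed in list_available_models() can also be queried directly with transformers, outside of gradio. A minimal sketch (the prompt string is illustrative; top_k and max_length mirror the UI defaults above):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("0Tick/e621TagAutocomplete")
model = transformers.AutoModelForCausalLM.from_pretrained("0Tick/e621TagAutocomplete")

# Complete a partial tag string (illustrative input).
input_ids = tokenizer("solo, ", return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    do_sample=True,
    top_k=12,        # the UI's "Top K" default
    max_length=150,  # the UI's "Max length" default
    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])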