hugo1234 committed
Commit: f71c47a
1 Parent(s): c33476c

Add application file

Files changed (1)
  1. app.py +629 -0
app.py ADDED
@@ -0,0 +1,629 @@
+ import re
+ import yaml
+ import gc
+ import copy
+ import time
+ from tenacity import RetryError
+ from tenacity import retry, stop_after_attempt, wait_fixed
+ import gradio as gr
+ import torch
+ from peft import PeftModel
+ from transformers import (
+     LLaMATokenizer,
+     LLaMAForCausalLM,
+     GenerationConfig,
+     AutoModelForCausalLM,
+     AutoModelForSeq2SeqLM,
+     AutoTokenizer,
+     LogitsProcessorList,
+     MinNewTokensLengthLogitsProcessor,
+     TemperatureLogitsWarper,
+     TopPLogitsWarper,
+     MinLengthLogitsProcessor
+ )
+
+ assert torch.cuda.is_available(), "Change the runtime type to GPU"
+
+ # constants
+ num_of_characters_to_keep = 1000
+
+ # regex
+ html_tag_pattern = re.compile(r"<.*?>")
+ multi_line_pattern = re.compile(r"\n+")
+ multi_space_pattern = re.compile(r"( )")
+ multi_br_tag_pattern = re.compile(r'<br>\s*(<br>\s*)*')
+
+ # repl is short for replacement
+ repl_linebreak = "\n"
+ repl_empty_str = ""
+
+ TITLE = "🦌 Stambecco 🇮🇹"
+
+ ABSTRACT = """
+ Stambecco is an Italian instruction-following model based on the [LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) model. It comes in two versions: 7B and 13B parameters. It is trained on an Italian version of the [GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) dataset, a dataset of `GPT-4`-generated instruction-following data.
+ This demo is intended to show and evaluate the conversational capabilities of the model.
+ For more information, please visit [the project's website](https://github.com/mchl-labs/stambecco).
+ NOTE: Inputs that are too long are not allowed. Please keep the context under 500 characters and the instruction under 150.
+ """
+
+ BOTTOM_LINE = """
+ By default, this demo runs in streaming mode, but you can also run it with dynamic batch generation.
+ Stambecco is built on the same concept as the Stanford Alpaca project, but using LoRA lets us train and run inference on smaller GPUs, such as an RTX 4090 for the 7B version. We can also build very small checkpoints on top of the base models thanks to the [🤗 transformers](https://huggingface.co/docs/transformers/index), [🤗 peft](https://github.com/huggingface/peft), and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes/tree/main) libraries.
+ This demo currently runs the 8-bit 7B version of the model.
+ """
+
+ DEFAULT_EXAMPLES = {
+     "Typical Questions": [
+         {
+             "title": "Parlami di Giulio Cesare.",
+             "examples": [
+                 ["1", "Scrivi un articolo su Giulio Cesare"],
+                 ["2", "Davvero?"],
+                 ["3", "Quanto era ricco Giulio Cesare?"],
+                 ["4", "Chi è stato il suo successore?"],
+             ]
+         },
+         {
+             "title": "Parigi",
+             "examples": [
+                 ["1", "Scrivi un tema sulla città di Parigi"],
+                 ["2", "Fai un elenco di 5 posti da visitare assolutamente"],
+                 ["3", "Quali eventi importanti della Storia sono avvenuti a Parigi?"],
+                 ["4", "Quale è il periodo migliore per visitare Parigi?"],
+             ]
+         },
+         {
+             "title": "Scrivi un programma in Python che stampi i primi 10 numeri di Fibonacci",
+             "examples": [
+                 ["1", "Scrivi un programma in Python che stampi i primi 10 numeri di Fibonacci"],
+                 ["2", "Potresti spiegarmi come funziona il codice?"],
+                 ["3", "Cos'è la ricorsione?"],
+             ]
+         }
+     ],
+ }
+
+ SPECIAL_STRS = {
+     "continue": "continua",
+     "summarize": "Di cosa abbiamo discusso finora? Descrivi nella user's view."
+ }
+
+ PARENT_BLOCK_CSS = """
+ #col_container {
+     width: 95%;
+     margin-left: auto;
+     margin-right: auto;
+ }
+ #chatbot {
+     height: 500px;
+     overflow: auto;
+ }
+ """
+
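+ # Load the base LLaMA weights in 8-bit (via bitsandbytes) and attach the
+ # Stambecco LoRA adapter on top of them with peft, so only the small adapter
+ # weights are added to the quantized base model.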
+ def load_model(
+     base="decapoda-research/llama-7b-hf",
+     finetuned="mchl-labs/stambecco-7b-plus",
+ ):
+     tokenizer = LLaMATokenizer.from_pretrained(base)
+     tokenizer.pad_token_id = 0
+     tokenizer.padding_side = "left"
+
+     model = LLaMAForCausalLM.from_pretrained(
+         base,
+         load_in_8bit=True,
+         device_map="auto",
+     )
+     # model = PeftModel.from_pretrained(model, finetuned, device_map={'': 0})
+
+     model = PeftModel.from_pretrained(model, finetuned)
+     return model, tokenizer
+
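+ # Generation parameters (temperature, top_p, etc.) are read from a YAML file
+ # and turned into a transformers GenerationConfig object.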
+ def get_generation_config(path):
+     with open(path, 'rb') as f:
+         generation_config = yaml.safe_load(f.read())
+
+     return GenerationConfig(**generation_config["generation_config"])
+
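+ # Build the full prompt from the chat history using the
+ # "### Istruzione: ... / ### Risposta: ..." template. If the conversation was
+ # summarized earlier, only the turns after the summary are kept. Returns the
+ # prompt and the character length of the history part, which is later used to
+ # decide when the conversation needs to be summarized again.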
+ def generate_prompt(prompt, histories, ctx=None, partial=False):
+     convs = f"""Di seguito è riportata una cronologia delle istruzioni che descrivono le tasks, abbinate a un input che fornisce ulteriore contesto. Scrivi una risposta che completi adeguatamente la richiesta ricordando la cronologia della conversazione.
+
+ """
+
+     if ctx is not None:
+         convs = f"""### Input: {ctx}
+ """
+
+     sub_convs = ""
+     start_idx = 0
+
+     for idx, history in enumerate(histories):
+         history_prompt = history[0]
+         history_response = history[1]
+         if history_response == "✅ Riepilogo della conversazione effettuato e impostato come contesto" or history_prompt == SPECIAL_STRS["summarize"]:
+             start_idx = idx
+
+     # drop the previous conversations if user has summarized
+     for history in histories[start_idx if start_idx == 0 else start_idx+1:]:
+         history_prompt = history[0]
+         history_response = history[1]
+
+         history_response = history_response.replace("<br>", "\n")
+         history_response = re.sub(
+             html_tag_pattern, repl_empty_str, history_response
+         )
+
+         sub_convs = sub_convs + f"""### Istruzione: {history_prompt}
+ ### Risposta: {history_response}
+ """
+
+     sub_convs = sub_convs + f"""### Istruzione: {prompt}
+ ### Risposta:"""
+
+     convs = convs + sub_convs
+     return sub_convs if partial else convs, len(sub_convs)
+
+ def common_post_process(original_str):
+     original_str = re.sub(
+         multi_line_pattern, repl_linebreak, original_str
+     )
+     return original_str
+
+ def post_process_stream(bot_response):
+     # sometimes the model spits out text containing
+     # "### Risposta:" and "### Istruzione:" -> in this case, we want to stop generating
+     if "### Risposta:" in bot_response or "### Input:" in bot_response:
+         bot_response = bot_response.replace("### Risposta:", '').replace("### Input:", '').strip()
+         return bot_response, True
+
+     return common_post_process(bot_response), False
+
+ def post_process_batch(bot_response):
+     bot_response = bot_response.split("### Risposta:")[-1].strip()
+     return common_post_process(bot_response)
+
+ def post_processes_batch(bot_responses):
+     return [post_process_batch(r) for r in bot_responses]
+
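+ # Non-streaming path: generate full responses for one or more prompts in a
+ # single call. A single prompt skips padding; batched prompts are left-padded
+ # so they can be generated together. GPU tensors are freed after decoding.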
+ def get_output_batch(
+     model, tokenizer, prompts, generation_config
+ ):
+     if len(prompts) == 1:
+         encoding = tokenizer(prompts, return_tensors="pt")
+         input_ids = encoding["input_ids"].cuda()
+         generated_id = model.generate(
+             input_ids=input_ids,
+             generation_config=generation_config,
+             max_new_tokens=256
+         )
+
+         decoded = tokenizer.batch_decode(generated_id)
+         del input_ids, generated_id
+         torch.cuda.empty_cache()
+         return decoded
+     else:
+         encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda')
+         generated_ids = model.generate(
+             **encodings,
+             generation_config=generation_config,
+             max_new_tokens=256
+         )
+
+         decoded = tokenizer.batch_decode(generated_ids)
+         del encodings, generated_ids
+         torch.cuda.empty_cache()
+         return decoded
+
+
+ # StreamModel is borrowed from basaran project
+ # please find more info about it -> https://github.com/hyperonym/basaran
+ class StreamModel:
+     """StreamModel wraps around a language model to provide stream decoding."""
+
+     def __init__(self, model, tokenizer):
+         super().__init__()
+         self.model = model
+         self.tokenizer = tokenizer
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         self.processor = LogitsProcessorList()
+         self.processor.append(TemperatureLogitsWarper(0.9))
+         self.processor.append(TopPLogitsWarper(0.75))
+
+
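+     # __call__ tokenizes the prompt, pulls tokens from generate(), and yields
+     # progressively longer decoded strings so the UI can update as text streams in.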
+     def __call__(
+         self,
+         prompt,
+         min_tokens=0,
+         max_tokens=16,
+         temperature=1.0,
+         top_p=1.0,
+         n=1,
+         logprobs=0,
+     ):
+         """Create a completion stream for the provided prompt."""
+         input_ids = self.tokenize(prompt)
+         logprobs = max(logprobs, 0)
+
+         # bigger than 1
+         chunk_size = 2
+         chunk_count = 0
+
+         # Generate completion tokens.
+         final_tokens = torch.empty(0)
+
+         for tokens in self.generate(
+             input_ids[None, :].repeat(n, 1),
+             logprobs=logprobs,
+             min_new_tokens=min_tokens,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             top_p=top_p,
+         ):
+             if chunk_count < chunk_size:
+                 chunk_count = chunk_count + 1
+
+             final_tokens = torch.cat((final_tokens, tokens.to("cpu")))
+
+             if chunk_count == chunk_size-1:
+                 chunk_count = 0
+                 yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
+
+         if chunk_count > 0:
+             yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
+
+         del final_tokens, input_ids
+         if self.device == "cuda":
+             torch.cuda.empty_cache()
+
+     def _infer(self, model_fn, **kwargs):
+         with torch.inference_mode():
+             return model_fn(**kwargs)
+
+     def tokenize(self, text):
+         """Tokenize a string into a tensor of token IDs."""
+         batch = self.tokenizer.encode(text, return_tensors="pt")
+         return batch[0].to(self.device)
+
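+     # generate() runs a manual auto-regressive decoding loop instead of calling
+     # model.generate(): it samples one token per step (greedy when top_p or
+     # temperature is <= 0), yields each token as it is produced, and stops once
+     # every sequence has hit eos or max_new_tokens is reached.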
+     def generate(self, input_ids, logprobs=0, **kwargs):
+         """Generate a stream of predicted tokens using the language model."""
+
+         # Store the original batch size and input length.
+         batch_size = input_ids.shape[0]
+         input_length = input_ids.shape[-1]
+
+         # Separate model arguments from generation config.
+         config = self.model.generation_config
+         config = copy.deepcopy(config)
+         kwargs = config.update(**kwargs)
+         kwargs["output_attentions"] = False
+         kwargs["output_hidden_states"] = False
+         kwargs["use_cache"] = True
+
+         # Collect special token IDs.
+         pad_token_id = config.pad_token_id
+         bos_token_id = config.bos_token_id
+         eos_token_id = config.eos_token_id
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         if pad_token_id is None and eos_token_id is not None:
+             pad_token_id = eos_token_id[0]
+
+         # Generate from eos if no input is specified.
+         if input_length == 0:
+             input_ids = input_ids.new_ones((batch_size, 1)).long()
+             if eos_token_id is not None:
+                 input_ids = input_ids * eos_token_id[0]
+             input_length = 1
+
+         # Keep track of which sequences are already finished.
+         unfinished = input_ids.new_ones(batch_size)
+
+         # Start auto-regressive generation.
+         while True:
+             inputs = self.model.prepare_inputs_for_generation(
+                 input_ids, **kwargs
+             ) # noqa: E501
+
+             outputs = self._infer(
+                 self.model,
+                 **inputs,
+                 # return_dict=True,
+                 output_attentions=False,
+                 output_hidden_states=False,
+             )
+
+             # Pre-process the probability distribution of the next tokens.
+             logits = outputs.logits[:, -1, :]
+             with torch.inference_mode():
+                 logits = self.processor(input_ids, logits)
+                 probs = torch.nn.functional.softmax(logits, dim=-1)
+
+             # Select deterministic or stochastic decoding strategy.
+             if (config.top_p is not None and config.top_p <= 0) or (
+                 config.temperature is not None and config.temperature <= 0
+             ):
+                 tokens = torch.argmax(probs, dim=-1)[:, None]
+             else:
+                 tokens = torch.multinomial(probs, num_samples=1)
+
+             tokens = tokens.squeeze(1)
+
+             # Finished sequences should have their next token be a padding.
+             if pad_token_id is not None:
+                 tokens = tokens * unfinished + pad_token_id * (1 - unfinished)
+
+             # Append selected tokens to the inputs.
+             input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)
+
+             # Mark sequences with eos tokens as finished.
+             if eos_token_id is not None:
+                 not_eos = sum(tokens != i for i in eos_token_id)
+                 unfinished = unfinished.mul(not_eos.long())
+
+             # Set status to -1 if exceeded the max length.
+             status = unfinished.clone()
+             if input_ids.shape[-1] - input_length >= config.max_new_tokens:
+                 status = 0 - status
+
+             # Yield predictions and status.
+             yield tokens
+
+             # Stop when finished or exceeded the max length.
+             if status.max() <= 0:
+                 break
+
+ generation_config = get_generation_config(
+     "./generation_config_default.yaml"
+ )
+
+ model, tokenizer = load_model(
+     # base="decapoda-research/llama-13b-hf",
+     # finetuned="mchl-labs/stambecco-13b-plus",
+ )
+
+ stream_model = StreamModel(model, tokenizer)
+
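+ # Streaming chat handler wired to the Gradio UI below. It rejects overly long
+ # inputs, asks the model to summarize the conversation when the history grows
+ # past num_of_characters_to_keep, and then streams the reply token by token,
+ # stopping early if the model starts emitting a new "### Istruzione:" block.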
+ def chat_stream(
+     context,
+     instruction,
+     state_chatbot,
+ ):
+     if len(context) > 1000 or len(instruction) > 300:
+         raise gr.Error("Context or prompt is too long!")
+
+     bot_summarized_response = ''
+     # user input should be appropriately formatted (don't be confused by the function name)
+     instruction_display = instruction
+     instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)
+
+     if conv_length > num_of_characters_to_keep:
+         instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context, partial=True)[0]
+
+         state_chatbot = state_chatbot + [
+             (
+                 None,
+                 "![](https://s2.gifyu.com/images/icons8-loading-circle.gif) Conversazione troppo lunga, sto riassumendo..."
+             )
+         ]
+         yield (state_chatbot, state_chatbot, context)
+
+         bot_summarized_response = get_output_batch(
+             model, tokenizer, [instruction_prompt], generation_config
+         )[0]
+         bot_summarized_response = bot_summarized_response.split("### Risposta:")[-1].strip()
+
+         state_chatbot[-1] = (
+             None,
+             "✅ Riepilogo della conversazione effettuato e impostato come contesto"
+         )
+         print(f"bot_summarized_response: {bot_summarized_response}")
+         yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())
+
+     instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]
+
+     bot_response = stream_model(
+         instruction_prompt,
+         max_tokens=256,
+         temperature=1,
+         top_p=0.9
+     )
+
+     instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
+     state_chatbot = state_chatbot + [(instruction_display, None)]
+     yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())
+
+     prev_index = 0
+     agg_tokens = ""
+     cutoff_idx = 0
+     for tokens in bot_response:
+         tokens = tokens.strip()
+         cur_token = tokens[prev_index:]
+
+         if "#" in cur_token and agg_tokens == "":
+             cutoff_idx = tokens.find("#")
+             agg_tokens = tokens[cutoff_idx:]
+
+         if agg_tokens != "":
+             if len(agg_tokens) < len("### Istruzione:"):
+                 agg_tokens = agg_tokens + cur_token
+             elif len(agg_tokens) >= len("### Istruzione:"):
+                 if tokens.find("### Istruzione:") > -1:
+                     processed_response, _ = post_process_stream(tokens[:tokens.find("### Istruzione:")].strip())
+
+                     state_chatbot[-1] = (
+                         instruction_display,
+                         processed_response
+                     )
+                     yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
+                     break
+                 else:
+                     agg_tokens = ""
+                     cutoff_idx = 0
+
+         if agg_tokens == "":
+             processed_response, to_exit = post_process_stream(tokens)
+             state_chatbot[-1] = (instruction_display, processed_response)
+             yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
+
+             if to_exit:
+                 break
+
+         prev_index = len(tokens)
+
+     yield (
+         state_chatbot,
+         state_chatbot,
+         f"{context} {bot_summarized_response}".strip()
+     )
+
+
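+ # Batch (non-streaming) counterpart of chat_stream: builds one prompt per
+ # conversation, generates all responses in a single call, and appends them to
+ # each chat history. Not wired into the UI below, which uses chat_stream.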
+ def chat_batch(
+     contexts,
+     instructions,
+     state_chatbots,
+ ):
+     state_results = []
+     ctx_results = []
+
+     instruct_prompts = [
+         generate_prompt(instruct, histories, ctx)[0]
+         for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
+     ]
+
+     bot_responses = get_output_batch(
+         model, tokenizer, instruct_prompts, generation_config
+     )
+     bot_responses = post_processes_batch(bot_responses)
+
+     for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
+         new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
+         ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
+         state_results.append(new_state_chatbot)
+
+     return (state_results, state_results, ctx_results)
+
+ def reset_textbox():
+     return gr.Textbox.update(value='')
+
+ def reset_everything(
+     context_txtbox,
+     instruction_txtbox,
+     state_chatbot):
+
+     state_chatbot = []
+
+     return (
+         state_chatbot,
+         state_chatbot,
+         gr.Textbox.update(value=''),
+         gr.Textbox.update(value=''),
+     )
+
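+ # Gradio UI: a collapsible context box, the chat window, an instruction box,
+ # Cancel/Reset buttons, hidden Continue/Summarize helpers, and the example
+ # prompts defined in DEFAULT_EXAMPLES.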
+ with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
+     state_chatbot = gr.State([])
+
+     with gr.Column(elem_id='col_container'):
+         gr.Markdown(f"## {TITLE}\n\n\n{ABSTRACT}")
+
+         with gr.Accordion("Context Setting", open=False):
+             context_txtbox = gr.Textbox(placeholder="Surrounding information to AI", label="Enter Context")
+             hidden_txtbox = gr.Textbox(placeholder="", label="Order", visible=False)
+
+         chatbot = gr.Chatbot(elem_id='chatbot', label="Stambecco")
+         instruction_txtbox = gr.Textbox(placeholder="What do you want to say to AI?", label="Instruction")
+         with gr.Row():
+             cancel_btn = gr.Button(value="Cancel")
+             reset_btn = gr.Button(value="Reset")
+
+         with gr.Accordion("Helper Buttons", open=False):
+             gr.Markdown(f"`Continue` lets the AI complete the previous incomplete answer. `Summarize` lets the AI summarize the conversation so far.")
+             continue_txtbox = gr.Textbox(value=SPECIAL_STRS["continue"], visible=False)
+             summrize_txtbox = gr.Textbox(value=SPECIAL_STRS["summarize"], visible=False)
+
+             continue_btn = gr.Button(value="Continue")
+             summarize_btn = gr.Button(value="Summarize")
+
+         gr.Markdown("#### Examples")
+         for _, (category, examples) in enumerate(DEFAULT_EXAMPLES.items()):
+             with gr.Accordion(category, open=False):
+                 if category == "Identity":
+                     for item in examples:
+                         with gr.Accordion(item["title"], open=False):
+                             gr.Examples(
+                                 examples=item["examples"],
+                                 inputs=[
+                                     hidden_txtbox, context_txtbox, instruction_txtbox
+                                 ],
+                                 label=None
+                             )
+                 else:
+                     for item in examples:
+                         with gr.Accordion(item["title"], open=False):
+                             gr.Examples(
+                                 examples=item["examples"],
+                                 inputs=[
+                                     hidden_txtbox, instruction_txtbox
+                                 ],
+                                 label=None
+                             )
+
+         gr.Markdown(f"{BOTTOM_LINE}")
+
+
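+     # Wire up the events: submitting the instruction (or clicking Continue /
+     # Summarize) streams a reply via chat_stream, Cancel aborts any in-flight
+     # generation, and Reset clears the chat state and both textboxes.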
+     send_event = instruction_txtbox.submit(
+         chat_stream,
+         [context_txtbox, instruction_txtbox, state_chatbot],
+         [state_chatbot, chatbot, context_txtbox],
+     )
+     reset_event = instruction_txtbox.submit(
+         reset_textbox,
+         [],
+         [instruction_txtbox],
+     )
+
+     continue_event = continue_btn.click(
+         chat_stream,
+         [context_txtbox, continue_txtbox, state_chatbot],
+         [state_chatbot, chatbot, context_txtbox],
+     )
+     reset_continue_event = continue_btn.click(
+         reset_textbox,
+         [],
+         [instruction_txtbox],
+     )
+
+     summarize_event = summarize_btn.click(
+         chat_stream,
+         [context_txtbox, summrize_txtbox, state_chatbot],
+         [state_chatbot, chatbot, context_txtbox],
+     )
+     summarize_reset_event = summarize_btn.click(
+         reset_textbox,
+         [],
+         [instruction_txtbox],
+     )
+
+     cancel_btn.click(
+         None, None, None,
+         cancels=[
+             send_event, continue_event, summarize_event
+         ]
+     )
+
+     reset_btn.click(
+         reset_everything,
+         [context_txtbox, instruction_txtbox, state_chatbot],
+         [state_chatbot, chatbot, context_txtbox, instruction_txtbox],
+         cancels=[
+             send_event, continue_event, summarize_event
+         ]
+     )
+
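+ # Queue requests so the streaming (generator) handlers can yield partial
+ # results, limiting generation to one request at a time, then launch the app
+ # on all interfaces with a public share link.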
+ demo.queue(
+     concurrency_count=1,
+     max_size=100,
+ ).launch(
+     max_threads=5,
+     server_name="0.0.0.0",
+     share=True
+ )