tyoyo commited on
Commit
d9fa225
1 Parent(s): 7a773a3

入力が空文字, ハイパラがintのときのバグ修正 (#1)

Browse files

- fix: 入力が空文字, ハイパラがintのときのバグ修正 (782e1e8de7a654ad4c2fd74277f0c4081b3c6c98)

Files changed (1) hide show
  1. app.py +32 -12
app.py CHANGED
@@ -100,7 +100,17 @@ def generate(
100
  raise ValueError
101
 
102
  history = history_with_input[:-1]
103
- generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k, do_sample, repetition_penalty)
 
 
 
 
 
 
 
 
 
 
104
  try:
105
  first_response = next(generator)
106
  yield history + [(message, first_response)]
@@ -130,7 +140,12 @@ def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
130
  def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
131
  input_token_length = get_input_token_length(message, chat_history, system_prompt)
132
  if input_token_length > MAX_INPUT_TOKEN_LENGTH:
133
- raise gr.Error(f'合計対話長が長すぎます ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})。「🗑️ これまでの出力を消す」ボタンを押してから再実行してください。')
 
 
 
 
 
134
 
135
 
136
  def convert_history_to_str(history: list[tuple[str, str]]) -> str:
@@ -360,6 +375,11 @@ graphvizで、AからB、BからC、CからAに有向エッジが生えている
360
  api_name=False,
361
  queue=False,
362
  ).then(
 
 
 
 
 
363
  fn=display_input,
364
  inputs=[saved_input, chatbot],
365
  outputs=chatbot,
@@ -373,11 +393,6 @@ graphvizで、AからB、BからC、CからAに有向エッジが生えている
373
  fn=output_log,
374
  inputs=[chatbot, uuid_list],
375
  ).then(
376
- fn=check_input_token_length,
377
- inputs=[saved_input, chatbot, system_prompt],
378
- api_name=False,
379
- queue=False,
380
- ).success(
381
  fn=generate,
382
  inputs=[
383
  saved_input,
@@ -412,6 +427,11 @@ graphvizで、AからB、BからC、CからAに有向エッジが生えている
412
  api_name=False,
413
  queue=False,
414
  ).then(
 
 
 
 
 
415
  fn=display_input,
416
  inputs=[saved_input, chatbot],
417
  outputs=chatbot,
@@ -424,11 +444,6 @@ graphvizで、AからB、BからC、CからAに有向エッジが生えている
424
  ).then(
425
  fn=output_log,
426
  inputs=[chatbot, uuid_list],
427
- ).then(
428
- fn=check_input_token_length,
429
- inputs=[saved_input, chatbot, system_prompt],
430
- api_name=False,
431
- queue=False,
432
  ).success(
433
  fn=generate,
434
  inputs=[
@@ -464,6 +479,11 @@ graphvizで、AからB、BからC、CからAに有向エッジが生えている
464
  api_name=False,
465
  queue=False,
466
  ).then(
 
 
 
 
 
467
  fn=display_input,
468
  inputs=[saved_input, chatbot],
469
  outputs=chatbot,
 
100
  raise ValueError
101
 
102
  history = history_with_input[:-1]
103
+ generator = run(
104
+ message,
105
+ history,
106
+ system_prompt,
107
+ max_new_tokens,
108
+ float(temperature),
109
+ float(top_p),
110
+ top_k,
111
+ do_sample,
112
+ float(repetition_penalty),
113
+ )
114
  try:
115
  first_response = next(generator)
116
  yield history + [(message, first_response)]
 
140
  def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
141
  input_token_length = get_input_token_length(message, chat_history, system_prompt)
142
  if input_token_length > MAX_INPUT_TOKEN_LENGTH:
143
+ raise gr.Error(
144
+ f"合計対話長が長すぎます ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})。入力文章を短くするか、「🗑️ これまでの出力を消す」ボタンを押してから再実行してください。"
145
+ )
146
+
147
+ if len(message) <= 0:
148
+ raise gr.Error("入力が空です。1文字以上の文字列を入力してください。")
149
 
150
 
151
  def convert_history_to_str(history: list[tuple[str, str]]) -> str:
 
375
  api_name=False,
376
  queue=False,
377
  ).then(
378
+ fn=check_input_token_length,
379
+ inputs=[saved_input, chatbot, system_prompt],
380
+ api_name=False,
381
+ queue=False,
382
+ ).success(
383
  fn=display_input,
384
  inputs=[saved_input, chatbot],
385
  outputs=chatbot,
 
393
  fn=output_log,
394
  inputs=[chatbot, uuid_list],
395
  ).then(
 
 
 
 
 
396
  fn=generate,
397
  inputs=[
398
  saved_input,
 
427
  api_name=False,
428
  queue=False,
429
  ).then(
430
+ fn=check_input_token_length,
431
+ inputs=[saved_input, chatbot, system_prompt],
432
+ api_name=False,
433
+ queue=False,
434
+ ).success(
435
  fn=display_input,
436
  inputs=[saved_input, chatbot],
437
  outputs=chatbot,
 
444
  ).then(
445
  fn=output_log,
446
  inputs=[chatbot, uuid_list],
 
 
 
 
 
447
  ).success(
448
  fn=generate,
449
  inputs=[
 
479
  api_name=False,
480
  queue=False,
481
  ).then(
482
+ fn=check_input_token_length,
483
+ inputs=[saved_input, chatbot, system_prompt],
484
+ api_name=False,
485
+ queue=False,
486
+ ).success(
487
  fn=display_input,
488
  inputs=[saved_input, chatbot],
489
  outputs=chatbot,