Ventsislav Muchinov committed on
Commit
3bc6adc
1 Parent(s): 0c0065d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -47,11 +47,17 @@ def generate(
47
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
48
  input_ids = input_ids.to(model.device)
49
 
 
 
 
 
 
50
  streamer = TextIteratorStreamer(tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True)
51
  generate_kwargs = dict(
52
  {"input_ids": input_ids},
53
  streamer=streamer,
54
  max_new_tokens=max_new_tokens,
 
55
  do_sample=True,
56
  top_p=top_p,
57
  top_k=top_k,
 
47
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
48
  input_ids = input_ids.to(model.device)
49
 
50
+ terminators = [
51
+ tokenizer.eos_token_id,
52
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
53
+ ]
54
+
55
  streamer = TextIteratorStreamer(tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True)
56
  generate_kwargs = dict(
57
  {"input_ids": input_ids},
58
  streamer=streamer,
59
  max_new_tokens=max_new_tokens,
60
+ eos_token_id=terminators,
61
  do_sample=True,
62
  top_p=top_p,
63
  top_k=top_k,