Norod78 committed
Commit 1ee42bd
Parent: 82ae8bf

Update app.py

Files changed (1):
  1. app.py +3 -2
app.py CHANGED
@@ -4,6 +4,7 @@ os.system("pip install git+https://github.com/huggingface/transformers")
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
+import torch
 
 tok = AutoTokenizer.from_pretrained("distilgpt2")
 model = AutoModelForCausalLM.from_pretrained("distilgpt2")
@@ -16,10 +17,10 @@ early_stop_pattern = tok.eos_token
 print(f'Early stop pattern = \"{early_stop_pattern}\"')
 
 def generate(text = ""):
-    streamer = TextIteratorStreamer(tok)
+    streamer = TextIteratorStreamer(tok, timeout=10.)
     if len(text) == 0:
         text = " "
-    inputs = tok([text], return_tensors="pt")
+    inputs = tok([text], return_tensors="pt").to(device)
     generation_kwargs = dict(inputs, streamer=streamer, repetition_penalty=2.0, do_sample=True, top_k=40, top_p=0.97, max_new_tokens=128)
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
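
Note: the shown hunks end at thread.start(), and the `device` referenced by the new `.to(device)` call is not defined in either hunk, so it is presumably set elsewhere in app.py (which would also explain the added `import torch`). Below is a minimal sketch, not the actual file, of how a generate() built around TextIteratorStreamer is typically completed; the `device` selection, loading the model onto it, and the loop that consumes the streamer while model.generate runs in the background thread are all assumptions here.

# Minimal sketch under the assumptions stated above; not the actual app.py.
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Assumed definition of `device`; the real app.py would define it outside the shown hunks.
device = "cuda" if torch.cuda.is_available() else "cpu"

tok = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)

def generate(text=" "):
    # timeout=10. makes the streamer's iterator raise queue.Empty if no new
    # text arrives within 10 seconds, instead of blocking the request forever.
    streamer = TextIteratorStreamer(tok, timeout=10.)
    inputs = tok([text], return_tensors="pt").to(device)
    generation_kwargs = dict(inputs, streamer=streamer, repetition_penalty=2.0,
                             do_sample=True, top_k=40, top_p=0.97, max_new_tokens=128)
    # model.generate runs in a background thread and pushes decoded text into the streamer.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    # Assumed continuation: accumulate chunks and yield partial output for a streaming UI.
    output = ""
    for new_text in streamer:
        output += new_text
        yield output

The timeout added in this commit mainly guards the consuming loop: if the generation thread stalls or dies, iteration stops with an exception after 10 seconds rather than hanging the Gradio request indefinitely.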