Timing0311 commited on
Commit
d655623
1 Parent(s): 81b7cc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -1,9 +1,9 @@
1
- from transformers import MT5ForConditionalGeneration, AutoTokenizer
2
  import gradio as gr
3
 
4
  trans_mdl = MT5ForConditionalGeneration.from_pretrained("K024/mt5-zh-ja-en-trimmed")
5
  trans_tokenizer = AutoTokenizer.from_pretrained("K024/mt5-zh-ja-en-trimmed")
6
-
7
 
8
  def translation_job(data):
9
  job = data[0]
@@ -16,9 +16,8 @@ def translation_job(data):
16
  job_map = dict(zip(job_key, job_value))
17
 
18
  input = job_map[job] + text
19
- enc = trans_tokenizer(input, return_tensor="pt")
20
- tokens = trans_mdl.generate(**enc)
21
- response = trans_tokenizer.batch_decode(tokens)
22
  return response
23
 
24
 
 
1
+ from transformers import MT5ForConditionalGeneration, AutoTokenizer, Text2TextGenerationPipeline
2
  import gradio as gr
3
 
4
  trans_mdl = MT5ForConditionalGeneration.from_pretrained("K024/mt5-zh-ja-en-trimmed")
5
  trans_tokenizer = AutoTokenizer.from_pretrained("K024/mt5-zh-ja-en-trimmed")
6
+ trans_pipe = Text2TextGenerationPipeline(model=trans_mdl, tokenizer=trans_tokenizer)
7
 
8
  def translation_job(data):
9
  job = data[0]
 
16
  job_map = dict(zip(job_key, job_value))
17
 
18
  input = job_map[job] + text
19
+ print(input)
20
+ response = trans_tokenizer.batch_decode(input, max_length=100, num_beams=4)
 
21
  return response
22
 
23