davanstrien HF staff committed on
Commit
aac03a6
1 Parent(s): 86d577a

fixes for nltk and transformers updates (#6)

Browse files

- fixes for nltk and transformers updates (d5f8a9a2c51772f299cc8c6fb8a90ca77732c184)
- Update app.py (7798d43d37bfc02159a6ad1b6a805d8cd8c2aea9)

Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -6,7 +6,7 @@ import platform
6
  import torch
7
  import nltk
8
 
9
- nltk.download("punkt")
10
 
11
  REMOVED_TARGET_LANGUAGES = {"Ligurian", "Lombard", "Sicilian"}
12
 
@@ -55,7 +55,7 @@ def translate(text: str, src_lang: str, tgt_lang: str):
55
  )
56
  translated_chunk = model.generate(
57
  input_ids=torch.tensor([input_tokens]).to(device),
58
- forced_bos_token_id=tokenizer.lang_code_to_id[code_mapping[tgt_lang]],
59
  max_length=len(input_tokens) + 50,
60
  num_return_sequences=1,
61
  )
@@ -93,4 +93,4 @@ with gr.Blocks() as demo:
93
  inputs=[input_text, src_lang, target_lang],
94
  outputs=output,
95
  )
96
- demo.launch()
 
6
  import torch
7
  import nltk
8
 
9
+ nltk.download("punkt_tab")
10
 
11
  REMOVED_TARGET_LANGUAGES = {"Ligurian", "Lombard", "Sicilian"}
12
 
 
55
  )
56
  translated_chunk = model.generate(
57
  input_ids=torch.tensor([input_tokens]).to(device),
58
+ forced_bos_token_id=tokenizer.convert_tokens_to_ids(code_mapping[tgt_lang]),
59
  max_length=len(input_tokens) + 50,
60
  num_return_sequences=1,
61
  )
 
93
  inputs=[input_text, src_lang, target_lang],
94
  outputs=output,
95
  )
96
+ demo.launch()