ipvikas committed on
Commit
f48a3e2
1 Parent(s): 67c53ce

Update app.py

Files changed (1)
  1. app.py +19 -1
app.py CHANGED
@@ -289,8 +289,26 @@ model.eval()
 bot_name = "WeASK"

 ###removed
+from transformers import MBartForConditionalGeneration, MBart50Tokenizer

+def download_model():
+    model_name = "facebook/mbart-large-50-many-to-many-mmt"
+    model = MBartForConditionalGeneration.from_pretrained(model_name)
+    tokenizer = MBart50Tokenizer.from_pretrained(model_name)
+    return model, tokenizer
+
+model, tokenizer = download_model()
+
+
+
+################################
 def get_response(input_text):
+    model_inputs = tokenizer(input_text, return_tensors="pt")
+    generated_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
+    translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+
+
+
     #print("Let's chat! (type 'quit' to exit)")
     #while True:
     # sentence = "do you use credit cards?"
@@ -303,7 +321,7 @@ def get_response(input_text):
     #if sentence== "quit":
     #break

-    sentence= tokenize(input_text)
+    sentence= tokenize(translation)
     X = bag_of_words(sentence, all_words)
     X = X.reshape(1, X.shape[0])
     X = torch.from_numpy(X).to(device)
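For reference, below is a minimal sketch of the mBART-50 many-to-many translation step this commit adds, written with non-clashing names (mbart_model, mbart_tokenizer, translate_to_english are placeholders, not names from the commit, which re-binds model and tokenizer directly). The source language "hi_IN" and the sample Hindi input are illustrative assumptions only; the commit itself never sets tokenizer.src_lang, which defaults to "en_XX".

# Sketch only: mirrors the translation step added in this commit, under the assumptions above.
from transformers import MBartForConditionalGeneration, MBart50Tokenizer

model_name = "facebook/mbart-large-50-many-to-many-mmt"
mbart_model = MBartForConditionalGeneration.from_pretrained(model_name)
mbart_tokenizer = MBart50Tokenizer.from_pretrained(model_name)

def translate_to_english(text, src_lang="hi_IN"):  # src_lang is an assumed example value
    mbart_tokenizer.src_lang = src_lang            # language of the incoming user text
    inputs = mbart_tokenizer(text, return_tensors="pt")
    generated = mbart_model.generate(
        **inputs,
        forced_bos_token_id=mbart_tokenizer.lang_code_to_id["en_XX"],  # force English output
    )
    # batch_decode returns a list of strings; take the first entry for a single input
    return mbart_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

english_text = translate_to_english("क्या आप क्रेडिट कार्ड स्वीकार करते हैं?")  # sample Hindi input
# english_text can then be passed to tokenize() / bag_of_words() as in get_response above

Note that batch_decode returns a list of strings, so downstream code typically takes translation[0] before handing the text to tokenize(); using distinct names for the mBART objects also avoids shadowing the chatbot's own model, which get_response still calls for intent classification.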