Update app.py
Browse files
app.py
CHANGED
@@ -289,8 +289,26 @@ model.eval()
|
|
289 |
bot_name = "WeASK"
|
290 |
|
291 |
###removed
|
|
|
292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
def get_response(input_text):
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
#print("Let's chat! (type 'quit' to exit)")
|
295 |
#while True:
|
296 |
# sentence = "do you use credit cards?"
|
@@ -303,7 +321,7 @@ def get_response(input_text):
|
|
303 |
#if sentence== "quit":
|
304 |
#break
|
305 |
|
306 |
-
sentence= tokenize(
|
307 |
X = bag_of_words(sentence, all_words)
|
308 |
X = X.reshape(1, X.shape[0])
|
309 |
X = torch.from_numpy(X).to(device)
|
|
|
289 |
bot_name = "WeASK"
|
290 |
|
291 |
###removed
|
292 |
+
from transformers import MBartForConditionalGeneration, MBart50Tokenizer
|
293 |
|
294 |
+
def download_model():
|
295 |
+
model_name = "facebook/mbart-large-50-many-to-many-mmt"
|
296 |
+
model = MBartForConditionalGeneration.from_pretrained(model_name)
|
297 |
+
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
|
298 |
+
return model, tokenizer
|
299 |
+
|
300 |
+
model, tokenizer = download_model()
|
301 |
+
|
302 |
+
|
303 |
+
|
304 |
+
################################
|
305 |
def get_response(input_text):
|
306 |
+
model_inputs = tokenizer(input_text, return_tensors="pt")
|
307 |
+
generated_tokens = model.generate(**model_inputs,forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
|
308 |
+
translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
#print("Let's chat! (type 'quit' to exit)")
|
313 |
#while True:
|
314 |
# sentence = "do you use credit cards?"
|
|
|
321 |
#if sentence== "quit":
|
322 |
#break
|
323 |
|
324 |
+
sentence= tokenize(translation)
|
325 |
X = bag_of_words(sentence, all_words)
|
326 |
X = X.reshape(1, X.shape[0])
|
327 |
X = torch.from_numpy(X).to(device)
|