# -*- coding: utf-8 -*-
"""Meena_A_Multilingual_Chatbot (1).ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1-IfUcnDUppyMArHonc_iesEcN2gSKU-j
"""
# Dependencies (uncomment to install in Colab):
#!pip3 install transformers
#!pip install -q translate
#!pip install polyglot
#!pip install Pyicu
#!pip install Morfessor
#!pip install pycld2
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from translate import Translator
from polyglot.detect import Detector
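
# A minimal sketch of the detect-then-translate round trip used throughout
# this app (commented out; the sample sentence is only an illustration):
# sample = "Hola, ¿cómo estás?"                                    # hypothetical input
# lang = Detector(sample, quiet=True).language.code                # e.g. "es"
# to_english = Translator(from_lang=lang, to_lang="en").translate(sample)
# back_again = Translator(from_lang="en", to_lang=lang).translate(to_english)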
# model_name = "microsoft/DialoGPT-small"  # smaller alternative checkpoint
model_name = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
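
# A minimal single-turn smoke test for the loaded model (commented out; the
# prompt string is a hypothetical example, not part of the app's flow):
# prompt_ids = tokenizer.encode("Hello, how are you?" + tokenizer.eos_token, return_tensors="pt")
# reply_ids = model.generate(prompt_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
# print(tokenizer.decode(reply_ids[:, prompt_ids.shape[-1]:][0], skip_special_tokens=True))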
# Earlier command-line chat loop with nucleus sampling and a tuned
# temperature (kept commented out for reference):
# step = -1
# while True:
#     step += 1
#     # take user input
#     text = input(">> You:> ")
#     detected_language = Detector(text, quiet=True).language.code
#     translator = Translator(from_lang=detected_language, to_lang="en")
#     translated_input = translator.translate(text)
#     print(translated_input)
#     if text.lower().find("bye") != -1:
#         print(">> Meena:> Bye Bye!")
#         break
#     # encode the input and append the end-of-string token
#     input_ids = tokenizer.encode(translated_input + tokenizer.eos_token, return_tensors="pt")
#     # concatenate the new user input with the chat history (if there is any)
#     bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1) if step > 0 else input_ids
#     # generate a bot response
#     chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id,
#                                       do_sample=True, top_p=0.9, top_k=50, temperature=0.7,
#                                       num_beams=5, no_repeat_ngram_size=2)
#     # decode and print only the newly generated tokens
#     output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
#     print(output)
#     # translate the reply back into the user's language
#     translator = Translator(from_lang="en", to_lang=detected_language)
#     translated_output = translator.translate(output)
#     print(f">> Meena:> {translated_output}")
#!pip install gradio
import gradio as gr

with gr.Blocks() as meena:
    chatbot = gr.Chatbot(label="Meena - A Multilingual Chatbot")
    msg = gr.Textbox(label="You")
    clear = gr.Button("Clear")
    # Conversation state (token history and turn counter) lives in
    # module-level globals shared across Gradio callbacks.
    def set_history(new_history_ids):
        global chat_history_ids
        chat_history_ids = new_history_ids

    def get_history():
        return chat_history_ids

    def set_step(new_step):
        global step
        step = new_step

    def get_step():
        return step

    def generate_text(text, chat_history):
        # Turn counter: -1 marks a fresh conversation.
        if len(chat_history) == 0:
            step = -1
        else:
            step = get_step()
        step += 1
        set_step(step)
        print(step)  # debug: current turn number
        if step != 0:
            chat_history_ids = get_history()
        # Language detection is unreliable on purely numeric input,
        # so default to English for digits.
        if text.isdigit():
            detected_language = "en"
        else:
            detected_language = Detector(text, quiet=True).language.code
        # translate the user's message into English for DialoGPT
        translator = Translator(from_lang=detected_language, to_lang="en")
        translated_input = translator.translate(text)
        # encode the input and append the end-of-string token
        input_ids = tokenizer.encode(translated_input + tokenizer.eos_token, return_tensors="pt")
        # concatenate the new user input with the chat history (if there is any)
        bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1) if step > 0 else input_ids
        # generate a bot response (beam search combined with nucleus/top-k sampling)
        chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id,
                                          do_sample=True, top_p=0.9, top_k=50, temperature=0.7,
                                          num_beams=5, no_repeat_ngram_size=2)
        print(chat_history_ids)  # debug: full token history
        set_history(chat_history_ids)
        # decode only the newly generated tokens
        output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
        # translate the reply back into the user's language
        translator = Translator(from_lang="en", to_lang=detected_language)
        translated_output = translator.translate(output)
        chat_history.append((text, translated_output))
        # reset the context after six turns so the history doesn't grow unboundedly
        if step == 5:
            set_history(None)
            set_step(-1)
        return "", chat_history
    # Wire the textbox and the clear button to the chat window.
    msg.submit(generate_text, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

meena.queue().launch()
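# Optional: when running in Colab, `launch(share=True)` serves the demo at a
# temporary public URL instead of only on localhost:
# meena.queue().launch(share=True)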