vasu0508 committed on
Commit
a863b73
1 Parent(s): e321903

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -28
app.py CHANGED
@@ -58,33 +58,54 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
58
 
59
  import gradio as gr
60
 
61
def generate_text(text):
    """Return the chatbot's reply to *text*, in the language *text* was written in.

    Pipeline: detect the input language, translate the message to English,
    generate a reply with the causal language model, then translate the
    reply back into the detected language.

    Fixes over the previous version: the old ``while True`` loop never read
    new input, so it re-processed the same *text* forever unless it
    contained "bye", and the "bye" branch returned an unassigned variable
    (NameError). Each call now handles exactly one turn.
    """
    detected_language = Detector(text, quiet=True).language.code
    to_english = Translator(from_lang=detected_language, to_lang="en")
    translated_input = to_english.translate(text)

    # A farewell ends the conversation instead of producing a model reply.
    if "bye" in text.lower():
        return "Bye Bye!"

    # Encode the (English) prompt and append the end-of-string token.
    input_ids = tokenizer.encode(
        translated_input + tokenizer.eos_token, return_tensors="pt"
    )
    # Generate a reply; decode only the tokens produced after the prompt.
    reply_ids = model.generate(
        input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_p=0.9,
        top_k=50,
        temperature=0.7,
        num_beams=5,
        no_repeat_ngram_size=2,
    )
    output = tokenizer.decode(
        reply_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True
    )

    # Translate the English reply back into the user's language.
    from_english = Translator(from_lang="en", to_lang=detected_language)
    return from_english.translate(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Build and launch the Gradio UI for the multilingual chatbot.
output_box = gr.Textbox()
demo = gr.Interface(
    generate_text,
    "textbox",
    output_box,
    title="Meena",
    description="Meena- A Multilingual Chatbot",
)
demo.launch(debug=False)
# !gradio deploy
 
 
 
 
58
 
59
  import gradio as gr
60
 
61
with gr.Blocks() as meena:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    # Model-side conversation state: the token ids of the dialogue so far.
    # Module-global so it survives across Gradio callback invocations.
    # (Replaces the previous set/get/set2/get2 getter-setter boilerplate.)
    chat_history_ids = None

    def generate_text(text, chat_history):
        """Handle one chat turn: append (user, bot) to *chat_history*, clear the box.

        The user's message is translated to English, fed to the model
        together with the accumulated token history, and the reply is
        translated back into the language detected for the input.

        Returns ``("", chat_history)`` so the textbox is emptied and the
        Chatbot component is refreshed.
        """
        global chat_history_ids

        # A cleared UI history means a fresh conversation: drop token state.
        if not chat_history:
            chat_history_ids = None
        # Turn index is simply the number of prior exchanges — no separate
        # global step counter needed.
        step = len(chat_history)

        detected_language = Detector(text, quiet=True).language.code
        translated_input = Translator(
            from_lang=detected_language, to_lang="en"
        ).translate(text)

        # Encode the new user turn with an end-of-string marker.
        input_ids = tokenizer.encode(
            translated_input + tokenizer.eos_token, return_tensors="pt"
        )
        # Prepend the accumulated model history, if any.
        bot_input_ids = (
            torch.cat([chat_history_ids, input_ids], dim=-1)
            if chat_history_ids is not None
            else input_ids
        )
        chat_history_ids = model.generate(
            bot_input_ids,
            max_length=1000,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=0.7,
            num_beams=5,
            no_repeat_ngram_size=2,
        )
        # Decode only the tokens generated after the prompt.
        output = tokenizer.decode(
            chat_history_ids[:, bot_input_ids.shape[-1]:][0],
            skip_special_tokens=True,
        )
        translated_output = Translator(
            from_lang="en", to_lang=detected_language
        ).translate(output)

        chat_history.append((text, translated_output))

        # Every six exchanges, reset the model-side history so the context
        # window does not grow without bound. (The old code stored the
        # integer -1 as chat_history_ids here, which would have crashed
        # torch.cat if it were ever concatenated; None is the safe reset.)
        if step % 6 == 5:
            chat_history_ids = None

        return "", chat_history

    msg.submit(generate_text, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

meena.queue().launch()