Mandar Patil commited on
Commit
38e8a10
1 Parent(s): 910fca9

initial commit

Browse files
Files changed (2) hide show
  1. app.py +33 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import BlenderbotTokenizer
5
+ from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration, BlenderbotConfig
6
+ from transformers import BlenderbotTokenizerFast
7
+ import contextlib
8
+
9
+ #tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
10
+ #model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
11
+ #tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-3B")
12
+ mname = "facebook/blenderbot-3B"
13
+ #configuration = BlenderbotConfig.from_pretrained(mname)
14
+ tokenizer = BlenderbotTokenizerFast.from_pretrained(mname)
15
+ model = BlenderbotForConditionalGeneration.from_pretrained(mname)
16
+ #tokenizer = BlenderbotTokenizer.from_pretrained(mname)
17
+ #-----------new chat-----------
18
+ print(mname + 'model loaded')
19
+ def predict(input,history=[]):
20
+
21
+ history.append(input)
22
+
23
+ listToStr= '</s> <s>'.join([str(elem)for elem in history[len(history)-3:]])
24
+ #print('listToStr -->',str(listToStr))
25
+ input_ids = tokenizer([(listToStr)], return_tensors="pt",max_length=512,truncation=True)
26
+ next_reply_ids = model.generate(**input_ids,max_length=512, pad_token_id=tokenizer.eos_token_id)
27
+ response = tokenizer.batch_decode(next_reply_ids, skip_special_tokens=True)[0]
28
+ history.append(response)
29
+ response = [(history[i], history[i+1]) for i in range(0, len(history)-1, 2)] # convert to tuples of list
30
+ return response, history
31
+
32
+ demo = gr.Interface(fn=predict, inputs=["text",'state'], outputs=["chatbot",'state'])
33
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ gradio
3
+ transformers