File size: 1,607 Bytes
38e8a10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import gradio as gr
from transformers import BlenderbotTokenizer
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration, BlenderbotConfig
from transformers import BlenderbotTokenizerFast 
import contextlib

# Checkpoint used for both the tokenizer and the model. Earlier revisions of
# this script loaded "facebook/blenderbot-400M-distill" through the Auto*
# classes; change `mname` to switch checkpoints.
mname = "facebook/blenderbot-3B"

# Loaded once at import time so every call to predict() reuses the same
# tokenizer and model instances.
tokenizer = BlenderbotTokenizerFast.from_pretrained(mname)
model = BlenderbotForConditionalGeneration.from_pretrained(mname)

# -----------new chat-----------
# The separating space was missing, which made the message read
# "facebook/blenderbot-3Bmodel loaded".
print(f"{mname} model loaded")
def predict(input, history=None):
    """Generate the bot's next reply and return the updated conversation.

    Args:
        input: The user's latest message. The name shadows the builtin but is
            kept so that existing keyword callers keep working.
        history: Flat list of earlier turns in the order they were appended
            (user message, bot reply, user message, ...). ``None`` or an
            empty sequence starts a new conversation.

    Returns:
        A ``(pairs, history)`` tuple. ``pairs`` is a list of
        ``(user_message, bot_reply)`` tuples built from consecutive history
        entries; ``history`` is a new flat list that includes this exchange.
    """
    # Work on a copy: a mutable default (or the caller's own list) must not
    # accumulate turns across unrelated calls.
    history = list(history) if history else []
    history.append(input)

    # Only the three most recent turns are sent to the model. A plain negative
    # slice is used because `history[len(history) - 3:]` wraps around to a
    # negative start index for short histories and can drop a turn.
    context = "</s> <s>".join(str(turn) for turn in history[-3:])

    encoded = tokenizer(
        [context], return_tensors="pt", max_length=512, truncation=True
    )
    reply_ids = model.generate(
        **encoded, max_length=512, pad_token_id=tokenizer.eos_token_id
    )
    reply = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    history.append(reply)

    # Pair entries (0, 1), (2, 3), ... so each tuple is one exchange.
    pairs = list(zip(history[::2], history[1::2]))
    return pairs, history
	
# Expose predict() through a Gradio interface: a text box plus a state value
# go in, and a chatbot view plus the updated state come out.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
)
demo.launch()