kookoobau committed
Commit
07d3b7a
1 Parent(s): d2a29a1
Files changed (2)
  1. app.py +4 -3
  2. myLLM.py +22 -0
app.py CHANGED
@@ -1,13 +1,14 @@
 from langchain.chains import ConversationChain
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
+from myLLM import AutoModelLanguageModel
 
-# Define the LangChain chat agent
 model_name = "microsoft/DialoGPT-medium"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
+llm = AutoModelLanguageModel(model_name)
+
+agent = ConversationChain(llm=llm)
 
-agent = ConversationChain(model=model, tokenizer=tokenizer)
 
 # Define the Gradio interface
 def chatbot_interface(input_text):
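
The diff cuts off at the chatbot_interface header, so neither the function body nor the Gradio launch code is shown. As a rough sketch of how the rest of app.py plausibly continues (the function body and the gr.Interface arguments below are illustrative guesses, not part of the commit):

def chatbot_interface(input_text):
    # ConversationChain.predict() runs one turn; the chain's memory keeps the history
    return agent.predict(input=input_text)

iface = gr.Interface(fn=chatbot_interface, inputs="text", outputs="text")
iface.launch()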
myLLM.py ADDED
@@ -0,0 +1,22 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from langchain.chains import LanguageModel
+
+class AutoModelLanguageModel(LanguageModel):
+    def __init__(self, model_name_or_path):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+
+    def generate_prompt(self, input_text, history):
+        # DialoGPT expects each conversation turn to end with the EOS token
+        inputs = self.tokenizer.encode(input_text + self.tokenizer.eos_token, return_tensors="pt")
+        history = [self.tokenizer.encode(h + self.tokenizer.eos_token, return_tensors="pt") for h in history]
+        prompt = torch.cat(history + [inputs], dim=-1)
+        return prompt
+
+    def generate_response(self, prompt, max_length):
+        # DialoGPT has no pad token, so fall back to the EOS token id
+        output = self.model.generate(prompt, max_length=max_length, pad_token_id=self.tokenizer.eos_token_id)
+        # Decode only the newly generated tokens, not the echoed prompt
+        response = self.tokenizer.decode(output[0, prompt.shape[-1]:], skip_special_tokens=True)
+        return response
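
A caveat on the new file: classic LangChain does not export a LanguageModel base class from langchain.chains, so the import above would fail, and ConversationChain drives its model through LangChain's LLM interface (a single _call method) rather than through generate_prompt/generate_response. Below is a minimal sketch of a wrapper that ConversationChain(llm=...) could actually invoke, assuming classic (0.0.x) LangChain's langchain.llms.base.LLM; the class name DialoGPTLLM and the load_dialogpt_llm helper are illustrative, not part of the commit:

from typing import Any, List, Optional

from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModelForCausalLM


class DialoGPTLLM(LLM):
    # Declared as pydantic fields; Any skips validation of the HF objects
    tokenizer: Any = None
    model: Any = None

    @property
    def _llm_type(self) -> str:
        return "dialogpt"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Encode the chain-rendered prompt, terminated with EOS as DialoGPT expects
        input_ids = self.tokenizer.encode(prompt + self.tokenizer.eos_token, return_tensors="pt")
        output = self.model.generate(input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id)
        # Return only the newly generated turn, not the echoed prompt
        return self.tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)


def load_dialogpt_llm(model_name_or_path: str) -> DialoGPTLLM:
    # Keyword construction sidesteps pydantic's lack of positional arguments
    return DialoGPTLLM(
        tokenizer=AutoTokenizer.from_pretrained(model_name_or_path),
        model=AutoModelForCausalLM.from_pretrained(model_name_or_path),
    )

With this shape, app.py would build the chain as llm = load_dialogpt_llm(model_name) followed by agent = ConversationChain(llm=llm).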