update
app.py
CHANGED
@@ -1,13 +1,14 @@
 from langchain.chains import ConversationChain
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
+from myLLM import AutoModelLanguageModel
 
-# Define the LangChain chat agent
 model_name = "microsoft/DialoGPT-medium"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
-
-agent = ConversationChain(model=model, tokenizer=tokenizer)
+llm = AutoModelLanguageModel(model_name)
+
+agent = ConversationChain(llm=llm)
+
 
 # Define the Gradio interface
 def chatbot_interface(input_text):
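The hunk ends at the `chatbot_interface` definition, so the rest of app.py is not shown in this commit. A minimal sketch of how the agent is presumably wired into Gradio, assuming the usual `gr.Interface` pattern; the function body and the launch call below are assumptions for illustration, not part of the diff:

def chatbot_interface(input_text):
    # ConversationChain keeps its own conversation buffer, so each call
    # only needs the new user message; the chain prepends the history.
    return agent.predict(input=input_text)

demo = gr.Interface(fn=chatbot_interface, inputs="text", outputs="text")
demo.launch()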
myLLM.py
ADDED
@@ -0,0 +1,21 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from langchain.chains import LanguageModel
+
+class AutoModelLanguageModel(LanguageModel):
+    def __init__(self, model_name_or_path):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+
+    def generate_prompt(self, input_text, history):
+        # Encode the new message and each prior turn, terminated by the EOS token
+        inputs = self.tokenizer.encode(input_text + self.tokenizer.eos_token, return_tensors="pt")
+        history = [self.tokenizer.encode(h + self.tokenizer.eos_token, return_tensors="pt") for h in history]
+        prompt = torch.cat(history + [inputs], dim=-1)
+        return prompt
+
+    def generate_response(self, prompt, max_length):
+        # DialoGPT's tokenizer ships without a pad token, so pad with EOS
+        output = self.model.generate(prompt, max_length=max_length, pad_token_id=self.tokenizer.eos_token_id)
+        response = self.tokenizer.decode(output[0], skip_special_tokens=True)
+        return response
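A caveat on the class above: `langchain.chains` does not export a `LanguageModel` base class, and `ConversationChain` validates its `llm` argument, so this wrapper would fail both at import time and at chain construction. A minimal sketch of the same idea against LangChain's classic `langchain.llms.base.LLM` interface; the `_llm_type` name, the `max_length` cap, and the prompt-trimming logic are assumptions for illustration, not part of the commit:

from typing import Any, List, Optional

import torch
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModelForCausalLM

class AutoModelLanguageModel(LLM):
    # Declared as fields so LangChain's pydantic validation accepts them
    tokenizer: Any = None
    model: Any = None

    def __init__(self, model_name_or_path: str, **kwargs):
        super().__init__(**kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

    @property
    def _llm_type(self) -> str:
        return "auto_model"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        # ConversationChain passes one fully formatted prompt string (its
        # memory already contains the history), so no manual torch.cat
        # over past turns is needed here.
        inputs = self.tokenizer.encode(prompt + self.tokenizer.eos_token, return_tensors="pt")
        with torch.no_grad():
            output = self.model.generate(
                inputs,
                max_length=1024,  # arbitrary cap for the sketch
                pad_token_id=self.tokenizer.eos_token_id,  # DialoGPT has no pad token
            )
        # Decode only the newly generated tokens, not the echoed prompt
        return self.tokenizer.decode(output[0][inputs.shape[-1]:], skip_special_tokens=True)

With a wrapper like this, app.py's `llm = AutoModelLanguageModel(model_name)` and `agent = ConversationChain(llm=llm)` work unchanged, since ConversationChain only requires an object implementing the LLM interface.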