Spaces:
Running
Running
Updated model
#2
by
Neha13
- opened
model2.py
CHANGED
@@ -1,134 +1,134 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
from langchain.text_splitter import CharacterTextSplitter
|
4 |
-
from langchain_huggingface import HuggingFaceEmbeddings
|
5 |
-
from langchain_community.vectorstores import FAISS
|
6 |
-
from langchain.memory import ConversationBufferMemory
|
7 |
-
from langchain.prompts import PromptTemplate
|
8 |
-
from langchain.chains import RetrievalQA
|
9 |
-
from langchain_community.document_loaders import PyPDFLoader
|
10 |
-
from langchain_groq import ChatGroq
|
11 |
-
|
12 |
-
|
13 |
-
class ChatbotModel:
|
14 |
-
def __init__(self):
|
15 |
-
# Initialize the environment variable for the GROQ API Key
|
16 |
-
os.environ["GROQ_API_KEY"] = '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'
|
17 |
-
|
18 |
-
# Load documents from PDFs
|
19 |
-
pdf_folder_path = "acpc_data"
|
20 |
-
documents = []
|
21 |
-
for file in os.listdir(pdf_folder_path):
|
22 |
-
if file.endswith('.pdf'):
|
23 |
-
pdf_path = os.path.join(pdf_folder_path, file)
|
24 |
-
loader = PyPDFLoader(pdf_path)
|
25 |
-
documents.extend(loader.load())
|
26 |
-
|
27 |
-
# Initialize embeddings
|
28 |
-
self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
|
29 |
-
model_kwargs={'device': 'cpu'},
|
30 |
-
encode_kwargs={'normalize_embeddings': True})
|
31 |
-
|
32 |
-
# Split documents into chunks
|
33 |
-
self.text_splitter = CharacterTextSplitter(
|
34 |
-
separator="\n",
|
35 |
-
chunk_size=1200,
|
36 |
-
chunk_overlap=500,
|
37 |
-
length_function=len)
|
38 |
-
self.text_chunks = self.text_splitter.split_documents(documents)
|
39 |
-
|
40 |
-
# Create FAISS vector store
|
41 |
-
self.db1 = FAISS.from_documents(self.text_chunks, self.embeddings)
|
42 |
-
|
43 |
-
# Initialize memory for conversation
|
44 |
-
self.memory = ConversationBufferMemory(memory_key="history", input_key="question")
|
45 |
-
|
46 |
-
# Initialize the chat model
|
47 |
-
self.llm = ChatGroq(
|
48 |
-
model="llama-3.1-8b-instant",
|
49 |
-
temperature=0.655,
|
50 |
-
max_tokens=None,
|
51 |
-
timeout=None,
|
52 |
-
max_retries=2,
|
53 |
-
)
|
54 |
-
|
55 |
-
# Create the QA chain prompt template
|
56 |
-
self.template = """You are a smart and helpful assistant for the ACPC counseling process. You guide students and solve their queries related to ACPC, MYSY scholarship, admission, etc. You will be given the student's query and the history of the chat, and you need to answer the query to the best of your knowledge. If the query is completely different from the context then tell the student that you are not made to answer this query in a polite language. If the student has included any type of content related to violence, sex, drugs or used abusive language then tell the student that you can not answer that query and request them not to use such content.
|
57 |
-
|
58 |
-
Also make sure to reply in the same language as used by the student in the current query.
|
59 |
-
|
60 |
-
NOTE that your answer should be accurate. Explain the answer such that a student with no idea about the ACPC can understand well.
|
61 |
-
|
62 |
-
For example,
|
63 |
-
|
64 |
-
Example 1
|
65 |
-
|
66 |
-
Chat history:
|
67 |
-
|
68 |
-
|
69 |
-
Question:
|
70 |
-
What is the maximum size of passport size photo allowed?
|
71 |
-
|
72 |
-
Answer:
|
73 |
-
The maximum size of passport size photo allowed is 200 KB.
|
74 |
-
|
75 |
-
{context}
|
76 |
-
|
77 |
-
------
|
78 |
-
Chat history :
|
79 |
-
{history}
|
80 |
-
|
81 |
-
------
|
82 |
-
|
83 |
-
|
84 |
-
Question: {question}
|
85 |
-
Answer:
|
86 |
-
"""
|
87 |
-
|
88 |
-
self.QA_CHAIN_PROMPT = PromptTemplate(input_variables=["history", "context", "question"],
|
89 |
-
template=self.template)
|
90 |
-
self.qa_chain = RetrievalQA.from_chain_type(self.llm,
|
91 |
-
retriever=self.db1.as_retriever(),
|
92 |
-
chain_type='stuff',
|
93 |
-
verbose=True,
|
94 |
-
chain_type_kwargs={"verbose": True, "prompt": self.QA_CHAIN_PROMPT,
|
95 |
-
"memory": self.memory})
|
96 |
-
|
97 |
-
def save(self, path):
|
98 |
-
# Save only the necessary parameters
|
99 |
-
torch.save({
|
100 |
-
'text_chunks': self.text_chunks,
|
101 |
-
'embeddings_model_name': self.embeddings.model_name, # Save the model name
|
102 |
-
'faiss_index': self.db1.index # Save FAISS index if needed
|
103 |
-
}, path)
|
104 |
-
|
105 |
-
def get_response(self, user_input):
|
106 |
-
# Call the QA chain with the user's input
|
107 |
-
result = self.qa_chain({"query": user_input})
|
108 |
-
return result["result"]
|
109 |
-
|
110 |
-
@classmethod
|
111 |
-
def load(cls, path):
|
112 |
-
# Load the model state
|
113 |
-
state = torch.load(path)
|
114 |
-
chatbot_model = cls()
|
115 |
-
# Restore other components
|
116 |
-
chatbot_model.text_chunks = state['text_chunks']
|
117 |
-
|
118 |
-
# Recreate embeddings using the saved model name
|
119 |
-
chatbot_model.embeddings = HuggingFaceEmbeddings(model_name=state['embeddings_model_name'],
|
120 |
-
model_kwargs={'device': 'cpu'},
|
121 |
-
encode_kwargs={'normalize_embeddings': True})
|
122 |
-
|
123 |
-
# Recreate FAISS index if necessary
|
124 |
-
chatbot_model.db1 = FAISS.from_documents(chatbot_model.text_chunks, chatbot_model.embeddings)
|
125 |
-
return chatbot_model
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
# Test saving the model
|
131 |
-
if __name__ == "__main__":
|
132 |
-
chatbot = ChatbotModel()
|
133 |
-
chatbot.save("model.pt")
|
134 |
-
print("Model saved successfully.")
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from langchain.text_splitter import CharacterTextSplitter
|
4 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
5 |
+
from langchain_community.vectorstores import FAISS
|
6 |
+
from langchain.memory import ConversationBufferMemory
|
7 |
+
from langchain.prompts import PromptTemplate
|
8 |
+
from langchain.chains import RetrievalQA
|
9 |
+
from langchain_community.document_loaders import PyPDFLoader
|
10 |
+
from langchain_groq import ChatGroq
|
11 |
+
|
12 |
+
|
13 |
+
class ChatbotModel:
|
14 |
+
def __init__(self):
|
15 |
+
# Initialize the environment variable for the GROQ API Key
|
16 |
+
os.environ["GROQ_API_KEY"] = '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'
|
17 |
+
|
18 |
+
# Load documents from PDFs
|
19 |
+
pdf_folder_path = "acpc_data"
|
20 |
+
documents = []
|
21 |
+
for file in os.listdir(pdf_folder_path):
|
22 |
+
if file.endswith('.pdf'):
|
23 |
+
pdf_path = os.path.join(pdf_folder_path, file)
|
24 |
+
loader = PyPDFLoader(pdf_path)
|
25 |
+
documents.extend(loader.load())
|
26 |
+
|
27 |
+
# Initialize embeddings
|
28 |
+
self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
|
29 |
+
model_kwargs={'device': 'cpu'},
|
30 |
+
encode_kwargs={'normalize_embeddings': True})
|
31 |
+
|
32 |
+
# Split documents into chunks
|
33 |
+
self.text_splitter = CharacterTextSplitter(
|
34 |
+
separator="\n",
|
35 |
+
chunk_size=1200,
|
36 |
+
chunk_overlap=500,
|
37 |
+
length_function=len)
|
38 |
+
self.text_chunks = self.text_splitter.split_documents(documents)
|
39 |
+
|
40 |
+
# Create FAISS vector store
|
41 |
+
self.db1 = FAISS.from_documents(self.text_chunks, self.embeddings)
|
42 |
+
|
43 |
+
# Initialize memory for conversation
|
44 |
+
self.memory = ConversationBufferMemory(memory_key="history", input_key="question")
|
45 |
+
|
46 |
+
# Initialize the chat model
|
47 |
+
self.llm = ChatGroq(
|
48 |
+
model="llama-3.1-8b-instant",
|
49 |
+
temperature=0.655,
|
50 |
+
max_tokens=None,
|
51 |
+
timeout=None,
|
52 |
+
max_retries=2,
|
53 |
+
)
|
54 |
+
|
55 |
+
# Create the QA chain prompt template
|
56 |
+
self.template = """You are a smart and helpful assistant for the ACPC counseling process. You guide students and solve their queries related to ACPC, MYSY scholarship, admission, etc. You will be given the student's query and the history of the chat, and you need to answer the query to the best of your knowledge. If the query is completely different from the context then tell the student that you are not made to answer this query in a polite language. If the student has included any type of content related to violence, sex, drugs or used abusive language then tell the student that you can not answer that query and request them not to use such content.
|
57 |
+
|
58 |
+
Also make sure to reply in the same language as used by the student in the current query.
|
59 |
+
|
60 |
+
NOTE that your answer should be accurate. Explain the answer such that a student with no idea about the ACPC can understand well.
|
61 |
+
|
62 |
+
For example,
|
63 |
+
|
64 |
+
Example 1
|
65 |
+
|
66 |
+
Chat history:
|
67 |
+
|
68 |
+
|
69 |
+
Question:
|
70 |
+
What is the maximum size of passport size photo allowed?
|
71 |
+
|
72 |
+
Answer:
|
73 |
+
The maximum size of passport size photo allowed is 200 KB.
|
74 |
+
|
75 |
+
{context}
|
76 |
+
|
77 |
+
------
|
78 |
+
Chat history :
|
79 |
+
{history}
|
80 |
+
|
81 |
+
------
|
82 |
+
|
83 |
+
|
84 |
+
Question: {question}
|
85 |
+
Answer:
|
86 |
+
"""
|
87 |
+
|
88 |
+
self.QA_CHAIN_PROMPT = PromptTemplate(input_variables=["history", "context", "question"],
|
89 |
+
template=self.template)
|
90 |
+
self.qa_chain = RetrievalQA.from_chain_type(self.llm,
|
91 |
+
retriever=self.db1.as_retriever(),
|
92 |
+
chain_type='stuff',
|
93 |
+
verbose=True,
|
94 |
+
chain_type_kwargs={"verbose": True, "prompt": self.QA_CHAIN_PROMPT,
|
95 |
+
"memory": self.memory})
|
96 |
+
|
97 |
+
def save(self, path):
|
98 |
+
# Save only the necessary parameters
|
99 |
+
torch.save({
|
100 |
+
'text_chunks': self.text_chunks,
|
101 |
+
'embeddings_model_name': self.embeddings.model_name, # Save the model name
|
102 |
+
'faiss_index': self.db1.index # Save FAISS index if needed
|
103 |
+
}, path)
|
104 |
+
|
105 |
+
def get_response(self, user_input):
|
106 |
+
# Call the QA chain with the user's input
|
107 |
+
result = self.qa_chain({"query": user_input})
|
108 |
+
return result["result"]
|
109 |
+
|
110 |
+
@classmethod
|
111 |
+
def load(cls, path):
|
112 |
+
# Load the model state
|
113 |
+
state = torch.load(path)
|
114 |
+
chatbot_model = cls()
|
115 |
+
# Restore other components
|
116 |
+
chatbot_model.text_chunks = state['text_chunks']
|
117 |
+
|
118 |
+
# Recreate embeddings using the saved model name
|
119 |
+
chatbot_model.embeddings = HuggingFaceEmbeddings(model_name=state['embeddings_model_name'],
|
120 |
+
model_kwargs={'device': 'cpu'},
|
121 |
+
encode_kwargs={'normalize_embeddings': True})
|
122 |
+
|
123 |
+
# Recreate FAISS index if necessary
|
124 |
+
chatbot_model.db1 = FAISS.from_documents(chatbot_model.text_chunks, chatbot_model.embeddings)
|
125 |
+
return chatbot_model
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
# Test saving the model
|
131 |
+
if __name__ == "__main__":
|
132 |
+
chatbot = ChatbotModel()
|
133 |
+
chatbot.save("model.pt")
|
134 |
+
print("Model saved successfully.")
|