Neha13 committed
Commit acbb4ac
Parent: 0531b90

Updated model: removed the example student's name from the prompt template's sample chat history in model2.py

Files changed (1)
model2.py  +134 -134
model2.py CHANGED
@@ -1,134 +1,134 @@
- import os
- import torch
- from langchain.text_splitter import CharacterTextSplitter
- from langchain_huggingface import HuggingFaceEmbeddings
- from langchain_community.vectorstores import FAISS
- from langchain.memory import ConversationBufferMemory
- from langchain.prompts import PromptTemplate
- from langchain.chains import RetrievalQA
- from langchain_community.document_loaders import PyPDFLoader
- from langchain_groq import ChatGroq
-
-
- class ChatbotModel:
-     def __init__(self):
-         # Initialize the environment variable for the GROQ API Key
-         os.environ["GROQ_API_KEY"] = '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'
-
-         # Load documents from PDFs
-         pdf_folder_path = "acpc_data"
-         documents = []
-         for file in os.listdir(pdf_folder_path):
-             if file.endswith('.pdf'):
-                 pdf_path = os.path.join(pdf_folder_path, file)
-                 loader = PyPDFLoader(pdf_path)
-                 documents.extend(loader.load())
-
-         # Initialize embeddings
-         self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
-                                                 model_kwargs={'device': 'cpu'},
-                                                 encode_kwargs={'normalize_embeddings': True})
-
-         # Split documents into chunks
-         self.text_splitter = CharacterTextSplitter(
-             separator="\n",
-             chunk_size=1200,
-             chunk_overlap=500,
-             length_function=len)
-         self.text_chunks = self.text_splitter.split_documents(documents)
-
-         # Create FAISS vector store
-         self.db1 = FAISS.from_documents(self.text_chunks, self.embeddings)
-
-         # Initialize memory for conversation
-         self.memory = ConversationBufferMemory(memory_key="history", input_key="question")
-
-         # Initialize the chat model
-         self.llm = ChatGroq(
-             model="llama-3.1-8b-instant",
-             temperature=0.655,
-             max_tokens=None,
-             timeout=None,
-             max_retries=2,
-         )
-
-         # Create the QA chain prompt template
-         self.template = """You are a smart and helpful assistant for the ACPC counseling process. You guide students and solve their queries related to ACPC, MYSY scholarship, admission, etc. You will be given the student's query and the history of the chat, and you need to answer the query to the best of your knowledge. If the query is completely different from the context then tell the student that you are not made to answer this query in a polite language. If the student has included any type of content related to violence, sex, drugs or used abusive language then tell the student that you can not answer that query and request them not to use such content.
-
- Also make sure to reply in the same language as used by the student in the current query.
-
- NOTE that your answer should be accurate. Explain the answer such that a student with no idea about the ACPC can understand well.
-
- For example,
-
- Example 1
-
- Chat history:
- The student named Priya says hello.
-
- Question:
- What is the maximum size of passport size photo allowed?
-
- Answer:
- The maximum size of passport size photo allowed is 200 KB.
-
- {context}
-
- ------
- Chat history :
- {history}
-
- ------
-
-
- Question: {question}
- Answer:
- """
-
-         self.QA_CHAIN_PROMPT = PromptTemplate(input_variables=["history", "context", "question"],
-                                               template=self.template)
-         self.qa_chain = RetrievalQA.from_chain_type(self.llm,
-                                                     retriever=self.db1.as_retriever(),
-                                                     chain_type='stuff',
-                                                     verbose=True,
-                                                     chain_type_kwargs={"verbose": True, "prompt": self.QA_CHAIN_PROMPT,
-                                                                        "memory": self.memory})
-
-     def save(self, path):
-         # Save only the necessary parameters
-         torch.save({
-             'text_chunks': self.text_chunks,
-             'embeddings_model_name': self.embeddings.model_name,  # Save the model name
-             'faiss_index': self.db1.index  # Save FAISS index if needed
-         }, path)
-
-     def get_response(self, user_input):
-         # Call the QA chain with the user's input
-         result = self.qa_chain({"query": user_input})
-         return result["result"]
-
-     @classmethod
-     def load(cls, path):
-         # Load the model state
-         state = torch.load(path)
-         chatbot_model = cls()
-         # Restore other components
-         chatbot_model.text_chunks = state['text_chunks']
-
-         # Recreate embeddings using the saved model name
-         chatbot_model.embeddings = HuggingFaceEmbeddings(model_name=state['embeddings_model_name'],
-                                                          model_kwargs={'device': 'cpu'},
-                                                          encode_kwargs={'normalize_embeddings': True})
-
-         # Recreate FAISS index if necessary
-         chatbot_model.db1 = FAISS.from_documents(chatbot_model.text_chunks, chatbot_model.embeddings)
-         return chatbot_model
-
-
-
-
- # Test saving the model
- if __name__ == "__main__":
-     chatbot = ChatbotModel()
-     chatbot.save("model.pt")
-     print("Model saved successfully.")
 
+ import os
+ import torch
+ from langchain.text_splitter import CharacterTextSplitter
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain.memory import ConversationBufferMemory
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_groq import ChatGroq
+
+
+ class ChatbotModel:
+     def __init__(self):
+         # Initialize the environment variable for the GROQ API Key
+         os.environ["GROQ_API_KEY"] = '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'
+
+         # Load documents from PDFs
+         pdf_folder_path = "acpc_data"
+         documents = []
+         for file in os.listdir(pdf_folder_path):
+             if file.endswith('.pdf'):
+                 pdf_path = os.path.join(pdf_folder_path, file)
+                 loader = PyPDFLoader(pdf_path)
+                 documents.extend(loader.load())
+
+         # Initialize embeddings
+         self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
+                                                 model_kwargs={'device': 'cpu'},
+                                                 encode_kwargs={'normalize_embeddings': True})
+
+         # Split documents into chunks
+         self.text_splitter = CharacterTextSplitter(
+             separator="\n",
+             chunk_size=1200,
+             chunk_overlap=500,
+             length_function=len)
+         self.text_chunks = self.text_splitter.split_documents(documents)
+
+         # Create FAISS vector store
+         self.db1 = FAISS.from_documents(self.text_chunks, self.embeddings)
+
+         # Initialize memory for conversation
+         self.memory = ConversationBufferMemory(memory_key="history", input_key="question")
+
+         # Initialize the chat model
+         self.llm = ChatGroq(
+             model="llama-3.1-8b-instant",
+             temperature=0.655,
+             max_tokens=None,
+             timeout=None,
+             max_retries=2,
+         )
+
+         # Create the QA chain prompt template
+         self.template = """You are a smart and helpful assistant for the ACPC counseling process. You guide students and solve their queries related to ACPC, MYSY scholarship, admission, etc. You will be given the student's query and the history of the chat, and you need to answer the query to the best of your knowledge. If the query is completely different from the context then tell the student that you are not made to answer this query in a polite language. If the student has included any type of content related to violence, sex, drugs or used abusive language then tell the student that you can not answer that query and request them not to use such content.
+
+ Also make sure to reply in the same language as used by the student in the current query.
+
+ NOTE that your answer should be accurate. Explain the answer such that a student with no idea about the ACPC can understand well.
+
+ For example,
+
+ Example 1
+
+ Chat history:
+
+
+ Question:
+ What is the maximum size of passport size photo allowed?
+
+ Answer:
+ The maximum size of passport size photo allowed is 200 KB.
+
+ {context}
+
+ ------
+ Chat history :
+ {history}
+
+ ------
+
+
+ Question: {question}
+ Answer:
+ """
+
+         self.QA_CHAIN_PROMPT = PromptTemplate(input_variables=["history", "context", "question"],
+                                               template=self.template)
+         self.qa_chain = RetrievalQA.from_chain_type(self.llm,
+                                                     retriever=self.db1.as_retriever(),
+                                                     chain_type='stuff',
+                                                     verbose=True,
+                                                     chain_type_kwargs={"verbose": True, "prompt": self.QA_CHAIN_PROMPT,
+                                                                        "memory": self.memory})
+
+     def save(self, path):
+         # Save only the necessary parameters
+         torch.save({
+             'text_chunks': self.text_chunks,
+             'embeddings_model_name': self.embeddings.model_name,  # Save the model name
+             'faiss_index': self.db1.index  # Save FAISS index if needed
+         }, path)
+
+     def get_response(self, user_input):
+         # Call the QA chain with the user's input
+         result = self.qa_chain({"query": user_input})
+         return result["result"]
+
+     @classmethod
+     def load(cls, path):
+         # Load the model state
+         state = torch.load(path)
+         chatbot_model = cls()
+         # Restore other components
+         chatbot_model.text_chunks = state['text_chunks']
+
+         # Recreate embeddings using the saved model name
+         chatbot_model.embeddings = HuggingFaceEmbeddings(model_name=state['embeddings_model_name'],
+                                                          model_kwargs={'device': 'cpu'},
+                                                          encode_kwargs={'normalize_embeddings': True})
+
+         # Recreate FAISS index if necessary
+         chatbot_model.db1 = FAISS.from_documents(chatbot_model.text_chunks, chatbot_model.embeddings)
+         return chatbot_model
+
+
+
+
+ # Test saving the model
+ if __name__ == "__main__":
+     chatbot = ChatbotModel()
+     chatbot.save("model.pt")
+     print("Model saved successfully.")
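
For anyone trying the committed file locally, a minimal smoke test might look like the sketch below; it is not part of the commit. It assumes the acpc_data folder of PDFs sits next to model2.py and that the Groq key is supplied through the environment rather than the literal in __init__ (the key hardcoded above is published with this commit and should be treated as compromised and rotated). ChatbotModel, save, and get_response come from the diff itself; the script name and the environment check are assumptions.

# run_model2.py -- hypothetical smoke test for the committed ChatbotModel
import os

from model2 import ChatbotModel

# Assumption: the hardcoded assignment in __init__ is replaced by an
# environment lookup, e.g. `export GROQ_API_KEEY=...` done before launching.
if not os.environ.get("GROQ_API_KEY"):
    raise RuntimeError("Set the GROQ_API_KEY environment variable first.")

chatbot = ChatbotModel()   # loads the PDFs and builds the FAISS index
chatbot.save("model.pt")   # persists chunks, embedding model name, and index

# The question from the prompt template's own few-shot example.
print(chatbot.get_response("What is the maximum size of passport size photo allowed?"))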