Zwea Htet committed
Commit
13c5bb4
1 Parent(s): 19f4fce

updated code for llamaCustom

models/langOpen.py CHANGED
@@ -4,10 +4,14 @@ import openai
 from dotenv import load_dotenv
 from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI
+from langchain.document_loaders import PyPDFLoader
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.vectorstores import FAISS
 
+loader = PyPDFLoader("./assets/pdf/CADWReg.pdf")
+pages = loader.load_and_split()
+
 load_dotenv()
 
 embeddings = OpenAIEmbeddings()
@@ -31,9 +35,7 @@ class LangOpen:
         if os.path.exists(path=path):
             return FAISS.load_local(folder_path=path, embeddings=embeddings)
         else:
-            faiss = FAISS.from_texts(
-                "./assets/updated_calregs.txt", embedding=embeddings
-            )
+            faiss = FAISS.from_documents(pages, embeddings)
            faiss.save_local(path)
             return faiss
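Note on this change: the old code passed a file path string straight to FAISS.from_texts, which embeds the literal string rather than the file contents; loading the PDF with PyPDFLoader first gives FAISS real Document objects to embed. A minimal sketch of the resulting indexing flow, with an illustrative helper name and paths (not part of this commit):

# sketch: build or reuse a FAISS store from a PDF (helper name and paths are illustrative)
import os
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS

def build_or_load_faiss(pdf_path: str, store_path: str) -> FAISS:
    embeddings = OpenAIEmbeddings()
    if os.path.exists(store_path):
        # reuse a previously persisted index
        return FAISS.load_local(folder_path=store_path, embeddings=embeddings)
    pages = PyPDFLoader(pdf_path).load_and_split()   # one Document per page/chunk
    index = FAISS.from_documents(pages, embeddings)  # embeds document contents, not a path string
    index.save_local(store_path)
    return index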
models/llamaCustom.py CHANGED
@@ -6,9 +6,9 @@ from typing import Any, List, Mapping, Optional
 import numpy as np
 import openai
 import pandas as pd
+import streamlit as st
 from dotenv import load_dotenv
-from huggingface_hub import HfFileSystem
-from langchain.llms.base import LLM
+from huggingface_hub import HfFileSystem, Repository
 from llama_index import (
     Document,
     GPTVectorStoreIndex,
@@ -19,12 +19,17 @@ from llama_index import (
     StorageContext,
     load_index_from_storage,
 )
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from llama_index.llms import CompletionResponse, CustomLLM, LLMMetadata
+
+# from langchain.llms.base import LLM
+from llama_index.prompts import Prompt
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline
 
 # from utils.customLLM import CustomLLM
 
 load_dotenv()
 # openai.api_key = os.getenv("OPENAI_API_KEY")
+
 fs = HfFileSystem()
 
 # define prompt helper
@@ -33,62 +38,122 @@ CONTEXT_WINDOW = 2048
 # set number of output tokens
 NUM_OUTPUT = 525
 # set maximum chunk overlap
-CHUNK_OVERLAP_RATION = 0.2
+CHUNK_OVERLAP_RATIO = 0.2
 
 prompt_helper = PromptHelper(
     context_window=CONTEXT_WINDOW,
     num_output=NUM_OUTPUT,
-    chunk_overlap_ratio=CHUNK_OVERLAP_RATION,
+    chunk_overlap_ratio=CHUNK_OVERLAP_RATIO,
 )
 
-llm_model_name = "bigscience/bloom-560m"
-tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
-model = AutoModelForCausalLM.from_pretrained(llm_model_name, config="T5Config")
-
-model_pipeline = pipeline(
-    model=model,
-    tokenizer=tokenizer,
-    task="text-generation",
-    # device=0, # GPU device number
-    # max_length=512,
-    do_sample=True,
-    top_p=0.95,
-    top_k=50,
-    temperature=0.7,
+text_qa_template_str = (
+    "Context information is below.\n"
+    "---------------------\n"
+    "{context_str}\n"
+    "---------------------\n"
+    "Using both the context information and also using your own knowledge, "
+    "answer the question: {query_str}\n"
+    "If the question is relevant, you can answer by providing the name of the chapter, the article and the title to the answer. In addition, you can add the page number of the document when you found the answer.\n"
+    "If the context isn't helpful, you can also answer the question on your own.\n"
 )
+text_qa_template = Prompt(text_qa_template_str)
+
+refine_template_str = (
+    "The original question is as follows: {query_str}\n"
+    "We have provided an existing answer: {existing_answer}\n"
+    "We have the opportunity to refine the existing answer "
+    "(only if needed) with some more context below.\n"
+    "------------\n"
+    "{context_msg}\n"
+    "------------\n"
+    "Using both the new context and your own knowledege, update or repeat the existing answer.\n"
+)
+refine_template = Prompt(refine_template_str)
+
+
+@st.cache_resource
+def load_model(mode_name: str):
+    # llm_model_name = "bigscience/bloom-560m"
+    tokenizer = AutoTokenizer.from_pretrained(mode_name)
+    model = AutoModelForCausalLM.from_pretrained(mode_name, config="T5Config")
+
+    pipe = pipeline(
+        task="text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        # device=0, # GPU device number
+        # max_length=512,
+        do_sample=True,
+        top_p=0.95,
+        top_k=50,
+        temperature=0.7,
+    )
 
+    return pipe
 
-class CustomLLM(LLM):
-    pipeline = model_pipeline
 
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+class OurLLM(CustomLLM):
+    def __init__(self, model_name: str, model_pipeline):
+        self.model_name = model_name
+        self.pipeline = model_pipeline
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
+        return LLMMetadata(
+            context_window=CONTEXT_WINDOW,
+            num_output=NUM_OUTPUT,
+            model_name=self.model_name,
+        )
+
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
         prompt_length = len(prompt)
-        response = self.pipeline(prompt, max_new_tokens=525)[0]["generated_text"]
+        response = self.pipeline(prompt, max_new_tokens=NUM_OUTPUT)[0]["generated_text"]
 
         # only return newly generated tokens
-        return response[prompt_length:]
+        text = response[prompt_length:]
+        return CompletionResponse(text=text)
 
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        return {"name_of_model": self.model_name}
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        raise NotImplementedError()
 
-    @property
-    def _llm_type(self) -> str:
-        return "custom"
+    # def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    #     prompt_length = len(prompt)
+    #     response = self.pipeline(prompt, max_new_tokens=525)[0]["generated_text"]
+
+    #     # only return newly generated tokens
+    #     return response[prompt_length:]
+
+    # @property
+    # def _identifying_params(self) -> Mapping[str, Any]:
+    #     return {"name_of_model": self.model_name}
+
+    # @property
+    # def _llm_type(self) -> str:
+    #     return "custom"
 
 
+@st.cache_resource
 class LlamaCustom:
     # define llm
-    llm_predictor = LLMPredictor(llm=CustomLLM())
-    service_context = ServiceContext.from_defaults(
-        llm_predictor=llm_predictor, prompt_helper=prompt_helper
-    )
-
-    def __init__(self, name: str) -> None:
-        self.vector_index = self.initialize_index(index_name=name)
+    # llm_predictor = LLMPredictor(llm=OurLLM())
+    # service_context = ServiceContext.from_defaults(
+    #     llm_predictor=llm_predictor, prompt_helper=prompt_helper
+    # )
+
+    def __init__(self, model_name: str) -> None:
+        pipe = load_model(mode_name=model_name)
+        llm = OurLLM(model_name=model_name, model_pipeline=pipe)
+        self.service_context = ServiceContext.from_defaults(
+            llm=llm, prompt_helper=prompt_helper
+        )
+        self.vector_index = self.initialize_index(model_name=model_name)
 
-    def initialize_index(self, index_name):
+    def initialize_index(self, model_name: str):
+        index_name = model_name.split("/")[-1]
+
         file_path = f"./vectorStores/{index_name}"
+
         if os.path.exists(path=file_path):
             # rebuild storage context
             storage_context = StorageContext.from_defaults(persist_dir=file_path)
@@ -118,6 +183,10 @@ class LlamaCustom:
 
     def get_response(self, query_str):
         print("query_str: ", query_str)
-        query_engine = self.vector_index.as_query_engine()
+        # query_engine = self.vector_index.as_query_engine()
+        query_engine = self.vector_index.as_query_engine(
+            text_qa_template=text_qa_template, refine_template=refine_template
+        )
         response = query_engine.query(query_str)
+        print("metadata: ", response.metadata)
         return str(response)
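Net effect of this refactor: the LangChain LLM wrapper plus LLMPredictor is replaced by a llama_index CustomLLM subclass that returns CompletionResponse objects, so the index drives the local Hugging Face pipeline directly, and the query engine now applies the custom QA/refine prompts. A minimal usage sketch (the model choice and query string below are illustrative, not from the commit):

from models.llamaCustom import LlamaCustom

# loads the tokenizer/model pipeline via load_model() and builds or restores
# the vector index persisted under ./vectorStores/bloom-560m
model = LlamaCustom(model_name="bigscience/bloom-560m")
print(model.get_response("Which chapter covers water quality reporting?"))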
pages/langchain_demo.py CHANGED
@@ -17,7 +17,7 @@ if "openai_api_key" not in st.session_state:
     st.info("Enter your openai key to access the chatbot.")
 else:
     option = st.selectbox(
-        label="Select your model:", options=("gpt-3.5-turbo", "gpt-4"), index=0
+        label="Select your model:", options=("gpt-3.5-turbo", "gpt-4")
     )
 
     with st.spinner(f"Initializing {option} ..."):
pages/llama_custom_demo.py CHANGED
@@ -1,11 +1,10 @@
 import os
-import time
 
 import openai
 import streamlit as st
 
 from models.llamaCustom import LlamaCustom
-from utils.chatbox import *
+from utils.chatbox import chatbox
 
 st.set_page_config(page_title="Llama", page_icon="🦙")
 
@@ -17,7 +16,11 @@ if "messages" not in st.session_state:
 if "openai_api_key" not in st.session_state:
     st.info("Enter your openai key to access the chatbot.")
 else:
+    option = st.selectbox(
+        label="Select your model:", options=("bigscience/bloom-560m",)
+    )
+
     with st.spinner("Initializing vector index"):
-        model = LlamaCustom(name="llamaCustom")
+        model = LlamaCustom(model_name=option)
 
     chatbox("llama_custom", model)
utils/chatbox.py CHANGED
@@ -40,7 +40,6 @@ def display_bot_msg(model_name: str, bot_response: str):
         {"model_name": model_name, "role": "assistant", "content": full_response}
     )
 
-# @st.cache_data
 def chatbox(model_name: str, model: None):
     # Display chat messages from history on app rerun
     for message in st.session_state.messages:
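For reference, a rough sketch of how chatbox() ties the pages above to the model, i.e. how a new user turn would reach model.get_response(); the body below is assumed for illustration and is not part of this commit:

import streamlit as st

def chatbox(model_name: str, model=None):
    # replay prior turns for this model on app rerun
    for message in st.session_state.messages:
        if message["model_name"] == model_name:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    # append the new user turn, then let the model answer
    if prompt := st.chat_input("Ask a question"):
        st.session_state.messages.append(
            {"model_name": model_name, "role": "user", "content": prompt}
        )
        with st.chat_message("user"):
            st.markdown(prompt)
        display_bot_msg(model_name, model.get_response(prompt))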