Zwea Htet committed
Commit
e594eb9
1 Parent(s): 5e8fa58

update llama custom

Files changed (1):
  1. models/llamaCustom.py +69 -31
models/llamaCustom.py CHANGED
@@ -58,34 +58,79 @@ def load_model(model_name: str):
     return pipe
 
 
-class CustomLLM(LLM):
-    llm_model_name: str
-    pipeline: Any
+@st.cache_resource
+def load_model(mode_name: str):
+    # llm_model_name = "bigscience/bloom-560m"
+    tokenizer = AutoTokenizer.from_pretrained(mode_name)
+    model = AutoModelForCausalLM.from_pretrained(mode_name, config="T5Config")
 
-    def __init__(self, llm_model_name: str):
-        super().__init__()
-
-        self.llm_model_name = llm_model_name
-        self.pipeline = load_model(mode_name=llm_model_name)
+    pipe = pipeline(
+        task="text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        # device=0, # GPU device number
+        # max_length=512,
+        do_sample=True,
+        top_p=0.95,
+        top_k=50,
+        temperature=0.7,
+    )
 
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    return pipe
+
+
+class OurLLM(CustomLLM):
+    def __init__(self, model_name: str, model_pipeline):
+        self.model_name = model_name
+        self.pipeline = model_pipeline
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
+        return LLMMetadata(
+            context_window=CONTEXT_WINDOW,
+            num_output=NUM_OUTPUT,
+            model_name=self.model_name,
+        )
+
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
         prompt_length = len(prompt)
-        response = self.pipeline(prompt, max_new_tokens=525)[0]["generated_text"]
+        response = self.pipeline(prompt, max_new_tokens=NUM_OUTPUT)[0]["generated_text"]
 
         # only return newly generated tokens
-        return response[prompt_length:]
+        text = response[prompt_length:]
+        return CompletionResponse(text=text)
 
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        return {"name_of_model": self.llm_model_name}
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        raise NotImplementedError()
 
-    @property
-    def _llm_type(self) -> str:
-        return "custom"
+    # def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    #     prompt_length = len(prompt)
+    #     response = self.pipeline(prompt, max_new_tokens=525)[0]["generated_text"]
+
+    #     # only return newly generated tokens
+    #     return response[prompt_length:]
 
+    # @property
+    # def _identifying_params(self) -> Mapping[str, Any]:
+    #     return {"name_of_model": self.model_name}
+
+    # @property
+    # def _llm_type(self) -> str:
+    #     return "custom"
 
 class LlamaCustom:
+    # define llm
+    # llm_predictor = LLMPredictor(llm=OurLLM())
+    # service_context = ServiceContext.from_defaults(
+    #     llm_predictor=llm_predictor, prompt_helper=prompt_helper
+    # )
     def __init__(self, model_name: str) -> None:
+        pipe = load_model(mode_name=model_name)
+        llm = OurLLM(model_name=model_name, model_pipeline=pipe)
+        self.service_context = ServiceContext.from_defaults(
+            llm=llm, prompt_helper=prompt_helper
+        )
         self.vector_index = self.initialize_index(model_name=model_name)
 
     @st.cache_resource
@@ -93,6 +138,7 @@ class LlamaCustom:
         index_name = model_name.split("/")[-1]
 
         file_path = f"./vectorStores/{index_name}"
+
         if os.path.exists(path=file_path):
             # rebuild storage context
             storage_context = StorageContext.from_defaults(persist_dir=file_path)
@@ -105,23 +151,11 @@ class LlamaCustom:
             # index = pickle.loads(file.readlines())
             return index
         else:
-            # define llm
-            prompt_helper = PromptHelper(
-                context_window=CONTEXT_WINDOW,
-                num_output=NUM_OUTPUT,
-                chunk_overlap_ratio=CHUNK_OVERLAP_RATION,
-            )
-
-            llm_predictor = LLMPredictor(llm=CustomLLM(llm_model_name=model_name))
-            service_context = ServiceContext.from_defaults(
-                llm_predictor=llm_predictor, prompt_helper=prompt_helper
-            )
-
             # documents = prepare_data(r"./assets/regItems.json")
             documents = SimpleDirectoryReader(input_dir="./assets/pdf").load_data()
 
             index = GPTVectorStoreIndex.from_documents(
-                documents, service_context=service_context
+                documents, service_context=self.service_context
             )
 
             # local write access
@@ -134,6 +168,10 @@ class LlamaCustom:
 
     def get_response(self, query_str):
         print("query_str: ", query_str)
-        query_engine = self.vector_index.as_query_engine()
+        # query_engine = self.vector_index.as_query_engine()
+        query_engine = self.vector_index.as_query_engine(
+            text_qa_template=text_qa_template, refine_template=refine_template
+        )
        response = query_engine.query(query_str)
+        print("metadata: ", response.metadata)
        return str(response)
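
For reference, a minimal usage sketch of the class as this commit leaves it. It assumes the rest of models/llamaCustom.py defines prompt_helper, text_qa_template, refine_template, CONTEXT_WINDOW, and NUM_OUTPUT (referenced in this diff but not added by it); the model name and query below are illustrative only.

# Usage sketch only (not part of this commit); model name and query are illustrative.
from models.llamaCustom import LlamaCustom

llama = LlamaCustom(model_name="bigscience/bloom-560m")  # any causal LM checkpoint
print(llama.get_response("Summarize the indexed documents."))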