Zwea Htet committed on
Commit
991fc6b
1 Parent(s): 215cfd3
Files changed (1)
  1. models/llamaCustom.py +29 -19
models/llamaCustom.py CHANGED
@@ -36,25 +36,32 @@ NUM_OUTPUT = 525
 # set maximum chunk overlap
 CHUNK_OVERLAP_RATION = 0.2
 
-llm_model_name = "bigscience/bloom-560m"
-tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
-model = AutoModelForCausalLM.from_pretrained(llm_model_name, config="T5Config")
-
-model_pipeline = pipeline(
-    model=model,
-    tokenizer=tokenizer,
-    task="text-generation",
-    # device=0, # GPU device number
-    # max_length=512,
-    do_sample=True,
-    top_p=0.95,
-    top_k=50,
-    temperature=0.7,
-)
+
+@st.cache_resource
+def load_model(mode_name: str):
+    # llm_model_name = "bigscience/bloom-560m"
+    tokenizer = AutoTokenizer.from_pretrained(mode_name)
+    model = AutoModelForCausalLM.from_pretrained(mode_name, config="T5Config")
+
+    pipe = pipeline(
+        task="text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        # device=0, # GPU device number
+        # max_length=512,
+        do_sample=True,
+        top_p=0.95,
+        top_k=50,
+        temperature=0.7,
+    )
+
+    return pipe
 
 
 class CustomLLM(LLM):
-    pipeline = model_pipeline
+    def __init__(self, model_name: str):
+        self.llm_model_name = model_name
+        self.pipeline = load_model(mode_name=model_name)
 
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         prompt_length = len(prompt)
@@ -65,17 +72,19 @@ class CustomLLM(LLM):
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
-        return {"name_of_model": llm_model_name}
+        return {"name_of_model": self.llm_model_name}
 
     @property
     def _llm_type(self) -> str:
         return "custom"
 
+
 class LlamaCustom:
     def __init__(self, model_name: str) -> None:
         self.vector_index = self.initialize_index(model_name=model_name)
 
-    def initialize_index(self, model_name: str):
+    @st.cache_resource
+    def initialize_index(_self, model_name: str):
         index_name = model_name.split("/")[-1]
 
         file_path = f"./vectorStores/{index_name}"
@@ -97,7 +106,8 @@
             num_output=NUM_OUTPUT,
             chunk_overlap_ratio=CHUNK_OVERLAP_RATION,
         )
-        llm_predictor = LLMPredictor(llm=CustomLLM())
+
+        llm_predictor = LLMPredictor(llm=CustomLLM(model_name=model_name))
         service_context = ServiceContext.from_defaults(
             llm_predictor=llm_predictor, prompt_helper=prompt_helper
         )
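The net effect of the commit: the Hugging Face pipeline is no longer built at import time from a hard-coded "bigscience/bloom-560m", but inside a @st.cache_resource function keyed by the model name, with CustomLLM and LlamaCustom receiving that name explicitly (the leading underscore in initialize_index(_self, ...) is the Streamlit convention for excluding a parameter from cache hashing). Below is a minimal usage sketch, not part of the commit, assuming the file is importable as models.llamaCustom and that the app still targets bigscience/bloom-560m:

# Hypothetical usage sketch; the import path and model name are assumptions.
from models.llamaCustom import CustomLLM, LlamaCustom, load_model

model_name = "bigscience/bloom-560m"  # the model the old code hard-coded

# First call builds the tokenizer, model, and text-generation pipeline;
# on later Streamlit reruns with the same argument, @st.cache_resource
# returns the cached pipeline instead of reloading the weights.
pipe = load_model(mode_name=model_name)

# CustomLLM now takes the model name and fetches the (cached) pipeline
# itself instead of reading a module-level global.
llm = CustomLLM(model_name=model_name)

# LlamaCustom loads the vector index from ./vectorStores/<index_name> (or
# builds it) with a service context that wraps this custom LLM.
index_wrapper = LlamaCustom(model_name=model_name)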