KonradSzafer committed
Commit 988981a
1 Parent(s): 5195c5a

config update

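The diffs below consolidate configuration: instead of passing each Config field as a separate keyword argument, every component now receives the Config object itself. A minimal sketch of the resulting wiring, using only names that appear in the changed files (the standalone snippet is illustrative, not part of the commit):

from qa_engine import Config, QAEngine
from discord_bot import DiscordClient

# Config is populated from environment variables (see config/.env.example).
config = Config()

# QAEngine and DiscordClient now read everything they need from config.
qa_engine = QAEngine(config=config)
client = DiscordClient(qa_engine=qa_engine, config=config)
client.run(config.discord_token)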
api/__main__.py CHANGED
@@ -6,17 +6,7 @@ from qa_engine import logger, Config, QAEngine
 
 config = Config()
 app = FastAPI()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    num_relevant_docs=config.num_relevant_docs,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 
 
 @app.get('/')
app.py CHANGED
@@ -8,16 +8,7 @@ from discord_bot import DiscordClient
 
 
 config = Config()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 
 
 def gradio_interface():
@@ -41,11 +32,7 @@ def gradio_interface():
 def discord_bot_inference_thread():
     client = DiscordClient(
         qa_engine=qa_engine,
-        channel_ids=config.discord_channel_ids,
-        num_last_messages=config.num_last_messages,
-        use_names_in_context=config.use_names_in_context,
-        enable_commands=config.enable_commands,
-        debug=config.debug
+        config=config
     )
     client.run(config.discord_token)
 
benchmark/__main__.py CHANGED
@@ -10,16 +10,7 @@ from qa_engine import logger, Config, QAEngine
 QUESTIONS_FILENAME = 'benchmark/questions.json'
 
 config = Config()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 
 
 def main():
config/.env.example CHANGED
@@ -1,7 +1,7 @@
 # QA engine settings
-QUESTION_ANSWERING_MODEL_ID=hf-question-answering-model-ID
-EMBEDDING_MODEL_ID=hf-embedding-model-ID
-INDEX_REPO_ID=hf-index-ID
+QUESTION_ANSWERING_MODEL_ID=mock
+EMBEDDING_MODEL_ID=hkunlp/instructor-large
+INDEX_REPO_ID=KonradSzafer/index-instructor-large-812-m512-all_repos_above_50_stars
 PROMPT_TEMPLATE_NAME=llama
 USE_DOCS_FOR_CONTEXT=True
 NUM_RELEVANT_DOCS=4
@@ -9,6 +9,14 @@ ADD_SOURCES_TO_RESPONSE=True
 USE_MESSAGES_IN_CONTEXT=True
 DEBUG=True
 
+# Model settings
+MIN_NEW_TOKENS=64
+MAX_NEW_TOKENS=800
+TEMPERATURE=0.6
+TOP_K=50
+TOP_P=0.9
+DO_SAMPLE=True
+
 # Discord settings
 DISCORD_TOKEN=your-bot-token
 NUM_LAST_MESSAGES=1
discord_bot/__main__.py CHANGED
@@ -3,24 +3,10 @@ from discord_bot.client import DiscordClient
 
 
 config = Config()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    num_relevant_docs=config.num_relevant_docs,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 client = DiscordClient(
     qa_engine=qa_engine,
-    channel_ids=config.discord_channel_ids,
-    num_last_messages=config.num_last_messages,
-    use_names_in_context=config.use_names_in_context,
-    enable_commands=config.enable_commands,
-    debug=config.debug
+    config=config
 )
 
discord_bot/client/client.py CHANGED
@@ -4,56 +4,40 @@ from urllib.parse import quote
 import discord
 from typing import List
 
-from qa_engine import logger, QAEngine
+from qa_engine import logger, Config, QAEngine
 from discord_bot.client.utils import split_text_into_chunks
 
 
 class DiscordClient(discord.Client):
     """
     Discord Client class, used for interacting with a Discord server.
-
-    Args:
-        qa_service_url (str): The URL of the question answering service.
-        num_last_messages (int, optional): The number of previous messages to use as context for generating answers.
-            Defaults to 5.
-        use_names_in_context (bool, optional): Whether to include user names in the message context. Defaults to True.
-        enable_commands (bool, optional): Whether to enable commands for the bot. Defaults to True.
-
-    Attributes:
-        qa_service_url (str): The URL of the question answering service.
-        num_last_messages (int): The number of previous messages to use as context for generating answers.
-        use_names_in_context (bool): Whether to include user names in the message context.
-        enable_commands (bool): Whether to enable commands for the bot.
-        max_message_len (int): The maximum length of a message.
-        system_prompt (str): The system prompt to be used.
-
     """
     def __init__(
         self,
         qa_engine: QAEngine,
-        channel_ids: list[int] = [],
-        num_last_messages: int = 5,
-        use_names_in_context: bool = True,
-        enable_commands: bool = True,
-        debug: bool = False
-    ):
+        config: Config,
+    ):
         logger.info('Initializing Discord client...')
         intents = discord.Intents.all()
         intents.message_content = True
         super().__init__(intents=intents, command_prefix='!')
 
-        assert num_last_messages >= 1, \
-            'The number of last messages in context should be at least 1'
-
         self.qa_engine: QAEngine = qa_engine
-        self.channel_ids: list[int] = DiscordClient._process_channel_ids(channel_ids)
-        self.num_last_messages: int = num_last_messages
-        self.use_names_in_context: bool = use_names_in_context
-        self.enable_commands: bool = enable_commands
-        self.debug: bool = debug
-        self.min_messgae_len: int = 1800
+        self.channel_ids: list[int] = DiscordClient._process_channel_ids(
+            config.discord_channel_ids
+        )
+        self.num_last_messages: int = config.num_last_messages
+        self.use_names_in_context: bool = config.use_names_in_context
+        self.enable_commands: bool = config.enable_commands
+        self.debug: bool = config.debug
+        self.min_message_len: int = 1800
         self.max_message_len: int = 2000
 
+        assert all([isinstance(id, int) for id in self.channel_ids]), \
+            'All channel ids should be of type int'
+        assert self.num_last_messages >= 1, \
+            'The number of last messages in context should be at least 1'
+
 
     @staticmethod
     def _process_channel_ids(channel_ids) -> list[int]:
@@ -103,7 +87,7 @@ class DiscordClient(discord.Client):
         chunks = split_text_into_chunks(
             text=answer,
             split_characters=['. ', ', ', '\n'],
-            min_size=self.min_messgae_len,
+            min_size=self.min_message_len,
             max_size=self.max_message_len
         )
         for chunk in chunks:
qa_engine/config.py CHANGED
@@ -11,7 +11,7 @@ def get_env(env_name: str, default: Any = None, warn: bool = True) -> str:
     if default is not None:
         if warn:
             logger.warning(
-                f'Environment variable {env_name} not found. ' \
+                f'Environment variable {env_name} not found.' \
                 f'Using the default value: {default}.'
             )
         return default
@@ -34,6 +34,14 @@ class Config:
     use_messages_in_context: bool = eval(get_env('USE_MESSAGES_IN_CONTEXT', 'True'))
     debug: bool = eval(get_env('DEBUG', 'True'))
 
+    # Model config
+    min_new_tokens: int = int(get_env('MIN_NEW_TOKENS', 64))
+    max_new_tokens: int = int(get_env('MAX_NEW_TOKENS', 800))
+    temperature: float = float(get_env('TEMPERATURE', 0.6))
+    top_k: int = int(get_env('TOP_K', 50))
+    top_p: float = float(get_env('TOP_P', 0.95))
+    do_sample: bool = eval(get_env('DO_SAMPLE', 'True'))
+
     # Discord bot config - optional
     discord_token: str = get_env('DISCORD_TOKEN', '-', warn=False)
     discord_channel_ids: list[int] = get_env('DISCORD_CHANNEL_IDS', field(default_factory=list), warn=False)
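The new fields follow the same pattern as the existing ones: each is read from the environment by get_env and cast with int, float, or eval at class-definition time. A small sketch of how this behaves, with illustrative values (the printed results assume the in-code defaults shown above when a variable is unset):

import os

# Values must be set before qa_engine is first imported,
# since the Config fields are evaluated when the class is defined.
os.environ['TEMPERATURE'] = '0.6'
os.environ['DO_SAMPLE'] = 'True'

from qa_engine import Config

config = Config()
print(config.temperature)  # 0.6  (float('0.6'))
print(config.do_sample)    # True (eval('True'))
print(config.top_p)        # 0.95, the in-code default when TOP_P is unset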
qa_engine/logger.py CHANGED
@@ -1,88 +1,14 @@
 import logging
-import os
-import io
-import json
-from google.cloud import bigquery
-from google.oauth2 import service_account
-from google.api_core.exceptions import GoogleAPIError
-
-job_config = bigquery.LoadJobConfig(
-    schema=[
-        bigquery.SchemaField("timestamp", "TIMESTAMP", mode="REQUIRED"),
-        bigquery.SchemaField("log_entry", "STRING", mode="REQUIRED"),
-    ],
-    write_disposition="WRITE_APPEND",
-)
-
-
-class BigQueryLoggingHandler(logging.Handler):
-    def __init__(self):
-        super().__init__()
-        try:
-            project_id = os.getenv("BIGQUERY_PROJECT_ID")
-            dataset_id = os.getenv("BIGQUERY_DATASET_ID")
-            table_id = os.getenv("BIGQUERY_TABLE_ID")
-            print(f"project_id: {project_id}")
-            print(f"dataset_id: {dataset_id}")
-            print(f"table_id: {table_id}")
-            service_account_info = json.loads(
-                os.getenv("GOOGLE_SERVICE_ACCOUNT_JSON")
-                .replace('"', "")
-                .replace("'", '"')
-            )
-            print(f"service_account_info: {service_account_info}")
-            print(f"service_account_info type: {type(service_account_info)}")
-            print(f"service_account_info keys: {service_account_info.keys()}")
-            credentials = service_account.Credentials.from_service_account_info(
-                service_account_info
-            )
-            self.client = bigquery.Client(credentials=credentials, project=project_id)
-            self.table_ref = self.client.dataset(dataset_id).table(table_id)
-        except Exception as e:
-            print(f"Error: {e}")
-            self.handleError(e)
-
-    def emit(self, record):
-        try:
-            recordstr = f"{self.format(record)}"
-            body = io.BytesIO(recordstr.encode("utf-8"))
-            job = self.client.load_table_from_file(
-                body, self.table_ref, job_config=job_config
-            )
-            job.result()
-        except GoogleAPIError as e:
-            self.handleError(e)
-        except Exception as e:
-            self.handleError(e)
-
-    def handleError(self, record):
-        """
-        Handle errors associated with logging.
-        This method prevents logging-related exceptions from propagating.
-        Optionally, implement more sophisticated error handling here.
-        """
-        if isinstance(record, logging.LogRecord):
-            super().handleError(record)
-        else:
-            print(f"Logging error: {record}")
 
 
 logger = logging.getLogger(__name__)
 
-
 def setup_logger() -> None:
     """
     Logger setup.
     """
     logger.setLevel(logging.DEBUG)
-
-    stream_formatter = logging.Formatter(
-        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-    )
-    stream_handler = logging.StreamHandler()
-    stream_handler.setFormatter(stream_formatter)
-    logger.addHandler(stream_handler)
-
-    bq_handler = BigQueryLoggingHandler()
-    bq_handler.setFormatter(stream_formatter)
-    logger.addHandler(bq_handler)
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
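After this change the logger is a plain stream logger with no BigQuery dependency. A minimal usage sketch (the call site is an assumption; the commit only changes logger.py itself):

from qa_engine.logger import logger, setup_logger

setup_logger()  # attaches a single StreamHandler with the timestamped format above
logger.info('starting')  # e.g. '2024-01-01 12:00:00,000 - qa_engine.logger - INFO - starting'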
 
 
 
 
 
 
 
qa_engine/mocks.py CHANGED
@@ -10,7 +10,7 @@ class MockLocalBinaryModel(LLM):
     """
 
     model_path: str = None
-    llm: str = 'Mocked Response'
+    llm: str = 'Warsaw'
 
     def __init__(self):
         super().__init__()
qa_engine/qa_engine.py CHANGED
@@ -16,7 +16,7 @@ from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings
 from langchain.vectorstores import FAISS
 from sentence_transformers import CrossEncoder
 
-from qa_engine import logger
+from qa_engine import logger, Config
 from qa_engine.response import Response
 from qa_engine.mocks import MockLocalBinaryModel
 
@@ -25,16 +25,16 @@ class LocalBinaryModel(LLM):
     model_id: str = None
     llm: None = None
 
-    def __init__(self, model_id: str = None):
+    def __init__(self, config: Config):
         super().__init__()
         # pip install llama_cpp_python==0.1.39
         from llama_cpp import Llama
 
-        model_path = f'qa_engine/{model_id}'
-        if not os.path.exists(model_path):
-            raise ValueError(f'{model_path} does not exist')
-        self.model_id = model_id
-        self.llm = Llama(model_path=model_path, n_ctx=4096)
+        self.model_id = config.question_answering_model_id
+        self.model_path = f'qa_engine/{self.model_id}'
+        if not os.path.exists(self.model_path):
+            raise ValueError(f'{self.model_path} does not exist')
+        self.llm = Llama(model_path=self.model_path, n_ctx=4096)
 
     def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
         output = self.llm(
@@ -58,13 +58,19 @@ class TransformersPipelineModel(LLM):
     model_id: str = None
     pipeline: str = None
 
-    def __init__(self, model_id: str = None):
+    def __init__(self, config: Config):
         super().__init__()
-        self.model_id = model_id
-
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.model_id = config.question_answering_model_id
+        self.min_new_tokens = config.min_new_tokens
+        self.max_new_tokens = config.max_new_tokens
+        self.temperature = config.temperature
+        self.top_k = config.top_k
+        self.top_p = config.top_p
+        self.do_sample = config.do_sample
+
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            model_id,
+            self.model_id,
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
             load_in_8bit=False,
@@ -79,10 +85,12 @@ class TransformersPipelineModel(LLM):
             device_map='auto',
             eos_token_id=tokenizer.eos_token_id,
             pad_token_id=tokenizer.eos_token_id,
-            min_new_tokens=64,
-            max_new_tokens=800,
-            temperature=0.1,
-            do_sample=True,
+            min_new_tokens=self.min_new_tokens,
+            max_new_tokens=self.max_new_tokens,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            do_sample=self.do_sample,
         )
 
     def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
@@ -103,7 +111,7 @@ class APIServedModel(LLM):
     model_url: str = None
     debug: bool = None
 
-    def __init__(self, model_url: str = None, debug: bool = None):
+    def __init__(self, model_url: str, debug: bool = False):
         super().__init__()
         if model_url[-1] == '/':
             raise ValueError('URL should not end with a slash - "/"')
@@ -132,66 +140,36 @@ class APIServedModel(LLM):
         return 'api_model'
 
 
-
 class QAEngine():
     """
     QAEngine class, used for generating answers to questions.
-
-    Args:
-        llm_model_id (str): The ID of the LLM model to be used.
-        embedding_model_id (str): The ID of the embedding model to be used.
-        index_repo_id (str): The ID of the index repository to be used.
-        run_locally (bool, optional): Whether to run the models locally or on the Hugging Face hub. Defaults to True.
-        use_docs_for_context (bool, optional): Whether to use relevant documents as context for generating answers.
-            Defaults to True.
-        use_messages_for_context (bool, optional): Whether to use previous messages as context for generating answers.
-            Defaults to True.
-        debug (bool, optional): Whether to log debug information. Defaults to False.
-
-    Attributes:
-        use_docs_for_context (bool): Whether to use relevant documents as context for generating answers.
-        use_messages_for_context (bool): Whether to use previous messages as context for generating answers.
-        debug (bool): Whether to log debug information.
-        llm_model (Union[LocalBinaryModel, HuggingFacePipeline, HuggingFaceHub]): The LLM model to be used.
-        embedding_model (Union[HuggingFaceInstructEmbeddings, HuggingFaceHubEmbeddings]): The embedding model to be used.
-        prompt_template (PromptTemplate): The prompt template to be used.
-        llm_chain (LLMChain): The LLM chain to be used.
-        knowledge_index (FAISS): The FAISS index to be used.
-
     """
-    def __init__(
-        self,
-        llm_model_id: str,
-        embedding_model_id: str,
-        index_repo_id: str,
-        prompt_template: str,
-        use_docs_for_context: bool = True,
-        num_relevant_docs: int = 3,
-        add_sources_to_response: bool = True,
-        use_messages_for_context: bool = True,
-        first_stage_docs: int = 50,
-        debug: bool = False
-    ):
+    def __init__(self, config: Config):
         super().__init__()
-        self.prompt_template = prompt_template
-        self.use_docs_for_context = use_docs_for_context
-        self.num_relevant_docs = num_relevant_docs
-        self.add_sources_to_response = add_sources_to_response
-        self.use_messages_for_context = use_messages_for_context
-        self.first_stage_docs = first_stage_docs
-        self.debug = debug
+        self.config = config
+        self.question_answering_model_id=config.question_answering_model_id
+        self.embedding_model_id=config.embedding_model_id
+        self.index_repo_id=config.index_repo_id
+        self.prompt_template=config.prompt_template
+        self.use_docs_for_context=config.use_docs_for_context
+        self.num_relevant_docs=config.num_relevant_docs
+        self.add_sources_to_response=config.add_sources_to_response
+        self.use_messages_for_context=config.use_messages_in_context
+        self.debug=config.debug
+
+        self.first_stage_docs: int = 50
 
         prompt = PromptTemplate(
-            template=prompt_template,
+            template=self.prompt_template,
            input_variables=['question', 'context']
         )
-        self.llm_model = QAEngine._get_model(llm_model_id)
+        self.llm_model = self._get_model()
         self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)
 
         if self.use_docs_for_context:
-            logger.info(f'Downloading {index_repo_id}')
+            logger.info(f'Downloading {self.index_repo_id}')
             snapshot_download(
-                repo_id=index_repo_id,
+                repo_id=self.index_repo_id,
                 allow_patterns=['*.faiss', '*.pkl'],
                 repo_type='dataset',
                 local_dir='indexes/run/'
@@ -200,7 +178,7 @@ class QAEngine():
             embed_instruction = 'Represent the Hugging Face library documentation'
             query_instruction = 'Query the most relevant piece of information from the Hugging Face documentation'
             embedding_model = HuggingFaceInstructEmbeddings(
-                model_name=embedding_model_id,
+                model_name=self.embedding_model_id,
                 embed_instruction=embed_instruction,
                 query_instruction=query_instruction
             )
@@ -209,27 +187,22 @@ class QAEngine():
         self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
 
 
-    @staticmethod
-    def _get_model(llm_model_id: str):
-        if 'local_models/' in llm_model_id:
+    def _get_model(self):
+        if 'local_models/' in self.question_answering_model_id:
             logger.info('using local binary model')
-            return LocalBinaryModel(
-                model_id=llm_model_id
-            )
-        elif 'api_models/' in llm_model_id:
+            return LocalBinaryModel(self.config)
+        elif 'api_models/' in self.question_answering_model_id:
             logger.info('using api served model')
             return APIServedModel(
-                model_url=llm_model_id.replace('api_models/', ''),
+                model_url=self.question_answering_model_id.replace('api_models/', ''),
                 debug=self.debug
             )
-        elif llm_model_id == 'mock':
+        elif self.question_answering_model_id == 'mock':
             logger.info('using mock model')
             return MockLocalBinaryModel()
         else:
             logger.info('using transformers pipeline model')
-            return TransformersPipelineModel(
-                model_id=llm_model_id
-            )
+            return TransformersPipelineModel(self.config)
 
 
     @staticmethod
@@ -245,7 +218,8 @@ class QAEngine():
         Preprocess the answer by removing unnecessary sequences and stop sequences.
         '''
         SEQUENCES_TO_REMOVE = [
-            'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'
+            'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]',
+            '<context>', '<\context>', '<question>', '<\question>',
         ]
         SEQUENCES_TO_STOP = [
             'User:', 'You:', 'Question:'
@@ -296,9 +270,8 @@ class QAEngine():
                )
            ]
            relevant_docs = relevant_docs[:self.num_relevant_docs]
-            context += '\nEXTRACTED DOCUMENTS:\n'
-            for i, (doc) in enumerate(relevant_docs):
-                context += f'\n\n<DOCUMENT_{i}>\n {doc.page_content} \n</DOCUMENT_{i}>'
+            context += '\nExtracted documents:\n'
+            context += ''.join([doc.page_content for doc in relevant_docs])
            metadata = [doc.metadata for doc in relevant_docs]
            response.set_sources(sources=[str(m['source']) for m in metadata])
 
@@ -314,7 +287,6 @@ class QAEngine():
            sep = '\n' + '-' * 100
            logger.info(f'question len: {len(question)} {sep}')
            logger.info(f'question: {question} {sep}')
-            logger.info(f'question processed: {question} {sep}')
            logger.info(f'answer len: {len(response.get_answer())} {sep}')
            logger.info(f'answer original: {answer} {sep}')
            logger.info(f'answer postprocessed: {response.get_answer()} {sep}')
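With this refactor, the backend is selected solely from config.question_answering_model_id inside _get_model. A short summary of the dispatch, with illustrative ids (only the prefixes and the literal 'mock' come from the code above):

# QUESTION_ANSWERING_MODEL_ID                 -> backend returned by QAEngine._get_model()
# 'local_models/<some-binary>'                -> LocalBinaryModel (llama.cpp file under qa_engine/)
# 'api_models/http://host:port'               -> APIServedModel(model_url='http://host:port')
# 'mock'                                      -> MockLocalBinaryModel (see qa_engine/mocks.py)
# anything else, e.g. a Hugging Face repo id  -> TransformersPipelineModel (transformers pipeline)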
requirements.txt CHANGED
@@ -26,4 +26,3 @@ InstructorEmbedding==1.0.0
 faiss_cpu==1.7.3
 uvicorn==0.22.0
 pytest==7.3.1
-google-cloud-bigquery==3.17.2