Commit 988981a: config update
KonradSzafer committed
1 parent: 5195c5a
Files changed:
- api/__main__.py              +1  -11
- app.py                       +2  -15
- benchmark/__main__.py        +1  -10
- config/.env.example          +11 -3
- discord_bot/__main__.py      +2  -16
- discord_bot/client/client.py +17 -33
- qa_engine/config.py          +9  -1
- qa_engine/logger.py          +4  -78
- qa_engine/mocks.py           +1  -1
- qa_engine/qa_engine.py       +54 -82
- requirements.txt             +0  -1
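Every entry point in this commit switches from passing individual settings to passing a single Config object. A minimal sketch of the new construction pattern, using only names that appear in the diffs below:

from qa_engine import Config, QAEngine
from discord_bot import DiscordClient

config = Config()                                    # values come from the environment (see config/.env.example)
qa_engine = QAEngine(config=config)                  # replaces QAEngine(llm_model_id=..., embedding_model_id=..., ...)
client = DiscordClient(qa_engine=qa_engine, config=config)
client.run(config.discord_token)
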
api/__main__.py
CHANGED
@@ -6,17 +6,7 @@ from qa_engine import logger, Config, QAEngine
 
 config = Config()
 app = FastAPI()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    num_relevant_docs=config.num_relevant_docs,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 
 
 @app.get('/')

app.py
CHANGED
@@ -8,16 +8,7 @@ from discord_bot import DiscordClient
 
 
 config = Config()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 
 
 def gradio_interface():
@@ -41,11 +32,7 @@ def gradio_interface():
 def discord_bot_inference_thread():
     client = DiscordClient(
         qa_engine=qa_engine,
-        …
-        num_last_messages=config.num_last_messages,
-        use_names_in_context=config.use_names_in_context,
-        enable_commands=config.enable_commands,
-        debug=config.debug
+        config=config
     )
     client.run(config.discord_token)
 

benchmark/__main__.py
CHANGED
@@ -10,16 +10,7 @@ from qa_engine import logger, Config, QAEngine
 QUESTIONS_FILENAME = 'benchmark/questions.json'
 
 config = Config()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 
 
 def main():

config/.env.example
CHANGED
@@ -1,7 +1,7 @@
 # QA engine settings
-QUESTION_ANSWERING_MODEL_ID=
-EMBEDDING_MODEL_ID=
-INDEX_REPO_ID=
+QUESTION_ANSWERING_MODEL_ID=mock
+EMBEDDING_MODEL_ID=hkunlp/instructor-large
+INDEX_REPO_ID=KonradSzafer/index-instructor-large-812-m512-all_repos_above_50_stars
 PROMPT_TEMPLATE_NAME=llama
 USE_DOCS_FOR_CONTEXT=True
 NUM_RELEVANT_DOCS=4
@@ -9,6 +9,14 @@ ADD_SOURCES_TO_RESPONSE=True
 USE_MESSAGES_IN_CONTEXT=True
 DEBUG=True
 
+# Model settings
+MIN_NEW_TOKENS=64
+MAX_NEW_TOKENS=800
+TEMPERATURE=0.6
+TOP_K=50
+TOP_P=0.9
+DO_SAMPLE=True
+
 # Discord settings
 DISCORD_TOKEN=your-bot-token
 NUM_LAST_MESSAGES=1

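The new model settings are standard text-generation parameters; the updated TransformersPipelineModel in qa_engine/qa_engine.py forwards them to the transformers pipeline. A rough sketch of the equivalent direct call, with a placeholder model ID:

from transformers import pipeline

generator = pipeline(
    'text-generation',
    model='some-org/some-model',   # placeholder; QUESTION_ANSWERING_MODEL_ID selects the real model
    min_new_tokens=64,
    max_new_tokens=800,
    temperature=0.6,
    top_k=50,
    top_p=0.9,
    do_sample=True,
)
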
discord_bot/__main__.py
CHANGED
@@ -3,24 +3,10 @@ from discord_bot.client import DiscordClient
 
 
 config = Config()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    num_relevant_docs=config.num_relevant_docs,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 client = DiscordClient(
     qa_engine=qa_engine,
-    …
-    num_last_messages=config.num_last_messages,
-    use_names_in_context=config.use_names_in_context,
-    enable_commands=config.enable_commands,
-    debug=config.debug
+    config=config
 )
 
 

discord_bot/client/client.py
CHANGED
@@ -4,56 +4,40 @@ from urllib.parse import quote
 import discord
 from typing import List
 
-from qa_engine import logger, QAEngine
+from qa_engine import logger, Config, QAEngine
 from discord_bot.client.utils import split_text_into_chunks
 
 
 class DiscordClient(discord.Client):
     """
     Discord Client class, used for interacting with a Discord server.
-
-    Args:
-        qa_service_url (str): The URL of the question answering service.
-        num_last_messages (int, optional): The number of previous messages to use as context for generating answers.
-            Defaults to 5.
-        use_names_in_context (bool, optional): Whether to include user names in the message context. Defaults to True.
-        enable_commands (bool, optional): Whether to enable commands for the bot. Defaults to True.
-
-    Attributes:
-        qa_service_url (str): The URL of the question answering service.
-        num_last_messages (int): The number of previous messages to use as context for generating answers.
-        use_names_in_context (bool): Whether to include user names in the message context.
-        enable_commands (bool): Whether to enable commands for the bot.
-        max_message_len (int): The maximum length of a message.
-        system_prompt (str): The system prompt to be used.
-
     """
     def __init__(
         self,
        qa_engine: QAEngine,
-        …
-        …
-        use_names_in_context: bool = True,
-        enable_commands: bool = True,
-        debug: bool = False
-    ):
+        config: Config,
+    ):
         logger.info('Initializing Discord client...')
         intents = discord.Intents.all()
         intents.message_content = True
         super().__init__(intents=intents, command_prefix='!')
 
-        assert num_last_messages >= 1, \
-            'The number of last messages in context should be at least 1'
-
         self.qa_engine: QAEngine = qa_engine
-        self.channel_ids: list[int] = DiscordClient._process_channel_ids(
-            …
-        …
-        self.…
-        self.…
-        self.…
+        self.channel_ids: list[int] = DiscordClient._process_channel_ids(
+            config.discord_channel_ids
+        )
+        self.num_last_messages: int = config.num_last_messages
+        self.use_names_in_context: bool = config.use_names_in_context
+        self.enable_commands: bool = config.enable_commands
+        self.debug: bool = config.debug
+        self.min_message_len: int = 1800
         self.max_message_len: int = 2000
 
+        assert all([isinstance(id, int) for id in self.channel_ids]), \
+            'All channel ids should be of type int'
+        assert self.num_last_messages >= 1, \
+            'The number of last messages in context should be at least 1'
+
 
     @staticmethod
     def _process_channel_ids(channel_ids) -> list[int]:
@@ -103,7 +87,7 @@ class DiscordClient(discord.Client):
         chunks = split_text_into_chunks(
             text=answer,
             split_characters=['. ', ', ', '\n'],
-            min_size=self.…
+            min_size=self.min_message_len,
             max_size=self.max_message_len
         )
         for chunk in chunks:

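The client now reads all Discord-specific settings from Config and validates them up front: every entry in DISCORD_CHANNEL_IDS must resolve to an int and NUM_LAST_MESSAGES must be at least 1, otherwise the new assertions raise AssertionError. A minimal construction sketch, mirroring discord_bot/__main__.py above:

config = Config()
qa_engine = QAEngine(config=config)
client = DiscordClient(qa_engine=qa_engine, config=config)   # channel ids and message window come from config
client.run(config.discord_token)
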
qa_engine/config.py
CHANGED
@@ -11,7 +11,7 @@ def get_env(env_name: str, default: Any = None, warn: bool = True) -> str:
     if default is not None:
         if warn:
             logger.warning(
-                f'Environment variable {env_name} not found.
+                f'Environment variable {env_name} not found.' \
                 f'Using the default value: {default}.'
             )
         return default
@@ -34,6 +34,14 @@ class Config:
     use_messages_in_context: bool = eval(get_env('USE_MESSAGES_IN_CONTEXT', 'True'))
     debug: bool = eval(get_env('DEBUG', 'True'))
 
+    # Model config
+    min_new_tokens: int = int(get_env('MIN_NEW_TOKENS', 64))
+    max_new_tokens: int = int(get_env('MAX_NEW_TOKENS', 800))
+    temperature: float = float(get_env('TEMPERATURE', 0.6))
+    top_k: int = int(get_env('TOP_K', 50))
+    top_p: float = float(get_env('TOP_P', 0.95))
+    do_sample: bool = eval(get_env('DO_SAMPLE', 'True'))
+
     # Discord bot config - optional
     discord_token: str = get_env('DISCORD_TOKEN', '-', warn=False)
     discord_channel_ids: list[int] = get_env('DISCORD_CHANNEL_IDS', field(default_factory=list), warn=False)

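A minimal usage sketch of the extended Config (assumed behaviour: when a variable is unset, the get_env defaults above apply and a warning is logged for each missing one):

from qa_engine import Config

config = Config()
print(config.min_new_tokens, config.max_new_tokens)    # 64 800 with the defaults above
print(config.temperature, config.top_k, config.top_p)  # 0.6 50 0.95
print(config.do_sample)                                 # True

Note that the boolean fields are parsed with eval() and the numeric ones with int()/float(), so the .env strings must be valid Python literals such as True or False.
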
qa_engine/logger.py
CHANGED
@@ -1,88 +1,14 @@
 import logging
-import os
-import io
-import json
-from google.cloud import bigquery
-from google.oauth2 import service_account
-from google.api_core.exceptions import GoogleAPIError
-
-job_config = bigquery.LoadJobConfig(
-    schema=[
-        bigquery.SchemaField("timestamp", "TIMESTAMP", mode="REQUIRED"),
-        bigquery.SchemaField("log_entry", "STRING", mode="REQUIRED"),
-    ],
-    write_disposition="WRITE_APPEND",
-)
-
-
-class BigQueryLoggingHandler(logging.Handler):
-    def __init__(self):
-        super().__init__()
-        try:
-            project_id = os.getenv("BIGQUERY_PROJECT_ID")
-            dataset_id = os.getenv("BIGQUERY_DATASET_ID")
-            table_id = os.getenv("BIGQUERY_TABLE_ID")
-            print(f"project_id: {project_id}")
-            print(f"dataset_id: {dataset_id}")
-            print(f"table_id: {table_id}")
-            service_account_info = json.loads(
-                os.getenv("GOOGLE_SERVICE_ACCOUNT_JSON")
-                .replace('"', "")
-                .replace("'", '"')
-            )
-            print(f"service_account_info: {service_account_info}")
-            print(f"service_account_info type: {type(service_account_info)}")
-            print(f"service_account_info keys: {service_account_info.keys()}")
-            credentials = service_account.Credentials.from_service_account_info(
-                service_account_info
-            )
-            self.client = bigquery.Client(credentials=credentials, project=project_id)
-            self.table_ref = self.client.dataset(dataset_id).table(table_id)
-        except Exception as e:
-            print(f"Error: {e}")
-            self.handleError(e)
-
-    def emit(self, record):
-        try:
-            recordstr = f"{self.format(record)}"
-            body = io.BytesIO(recordstr.encode("utf-8"))
-            job = self.client.load_table_from_file(
-                body, self.table_ref, job_config=job_config
-            )
-            job.result()
-        except GoogleAPIError as e:
-            self.handleError(e)
-        except Exception as e:
-            self.handleError(e)
-
-    def handleError(self, record):
-        """
-        Handle errors associated with logging.
-        This method prevents logging-related exceptions from propagating.
-        Optionally, implement more sophisticated error handling here.
-        """
-        if isinstance(record, logging.LogRecord):
-            super().handleError(record)
-        else:
-            print(f"Logging error: {record}")
 
 
 logger = logging.getLogger(__name__)
 
-
 def setup_logger() -> None:
     """
     Logger setup.
     """
     logger.setLevel(logging.DEBUG)
-    …
-    …
-    …
-    )
-    stream_handler = logging.StreamHandler()
-    stream_handler.setFormatter(stream_formatter)
-    logger.addHandler(stream_handler)
-
-    bq_handler = BigQueryLoggingHandler()
-    bq_handler.setFormatter(stream_formatter)
-    logger.addHandler(bq_handler)
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)

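With the BigQuery handler removed, logging goes to stdout only. A minimal usage sketch (module path as shown above; the printed line is illustrative):

from qa_engine.logger import logger, setup_logger

setup_logger()                       # DEBUG level, single StreamHandler with the formatter above
logger.info('QA engine starting')    # e.g. 2024-01-01 12:00:00,000 - qa_engine.logger - INFO - QA engine starting
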
qa_engine/mocks.py
CHANGED
@@ -10,7 +10,7 @@ class MockLocalBinaryModel(LLM):
     """
 
     model_path: str = None
-    llm: str = '…
+    llm: str = 'Warsaw'
 
     def __init__(self):
         super().__init__()

qa_engine/qa_engine.py
CHANGED
@@ -16,7 +16,7 @@ from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings
 from langchain.vectorstores import FAISS
 from sentence_transformers import CrossEncoder
 
-from qa_engine import logger
+from qa_engine import logger, Config
 from qa_engine.response import Response
 from qa_engine.mocks import MockLocalBinaryModel
 
@@ -25,16 +25,16 @@ class LocalBinaryModel(LLM):
     model_id: str = None
     llm: None = None
 
-    def __init__(self, …
+    def __init__(self, config: Config):
         super().__init__()
         # pip install llama_cpp_python==0.1.39
         from llama_cpp import Llama
 
-        …
-        …
-        …
-        …
-        self.llm = Llama(model_path=model_path, n_ctx=4096)
+        self.model_id = config.question_answering_model_id
+        self.model_path = f'qa_engine/{self.model_id}'
+        if not os.path.exists(self.model_path):
+            raise ValueError(f'{self.model_path} does not exist')
+        self.llm = Llama(model_path=self.model_path, n_ctx=4096)
 
     def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
         output = self.llm(
@@ -58,13 +58,19 @@ class TransformersPipelineModel(LLM):
     model_id: str = None
     pipeline: str = None
 
-    def __init__(self, …
+    def __init__(self, config: Config):
         super().__init__()
-        self.model_id = …
-        …
-        …
+        self.model_id = config.question_answering_model_id
+        self.min_new_tokens = config.min_new_tokens
+        self.max_new_tokens = config.max_new_tokens
+        self.temperature = config.temperature
+        self.top_k = config.top_k
+        self.top_p = config.top_p
+        self.do_sample = config.do_sample
+
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            model_id,
+            self.model_id,
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
             load_in_8bit=False,
@@ -79,10 +85,12 @@ class TransformersPipelineModel(LLM):
             device_map='auto',
             eos_token_id=tokenizer.eos_token_id,
             pad_token_id=tokenizer.eos_token_id,
-            min_new_tokens=…
-            max_new_tokens=…
-            temperature=…
-            …
+            min_new_tokens=self.min_new_tokens,
+            max_new_tokens=self.max_new_tokens,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            do_sample=self.do_sample,
         )
 
     def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
@@ -103,7 +111,7 @@ class APIServedModel(LLM):
     model_url: str = None
     debug: bool = None
 
-    def __init__(self, model_url: str…
+    def __init__(self, model_url: str, debug: bool = False):
         super().__init__()
         if model_url[-1] == '/':
             raise ValueError('URL should not end with a slash - "/"')
@@ -132,66 +140,36 @@
         return 'api_model'
 
 
-
 class QAEngine():
     """
     QAEngine class, used for generating answers to questions.
-
-    Args:
-        llm_model_id (str): The ID of the LLM model to be used.
-        embedding_model_id (str): The ID of the embedding model to be used.
-        index_repo_id (str): The ID of the index repository to be used.
-        run_locally (bool, optional): Whether to run the models locally or on the Hugging Face hub. Defaults to True.
-        use_docs_for_context (bool, optional): Whether to use relevant documents as context for generating answers.
-            Defaults to True.
-        use_messages_for_context (bool, optional): Whether to use previous messages as context for generating answers.
-            Defaults to True.
-        debug (bool, optional): Whether to log debug information. Defaults to False.
-
-    Attributes:
-        use_docs_for_context (bool): Whether to use relevant documents as context for generating answers.
-        use_messages_for_context (bool): Whether to use previous messages as context for generating answers.
-        debug (bool): Whether to log debug information.
-        llm_model (Union[LocalBinaryModel, HuggingFacePipeline, HuggingFaceHub]): The LLM model to be used.
-        embedding_model (Union[HuggingFaceInstructEmbeddings, HuggingFaceHubEmbeddings]): The embedding model to be used.
-        prompt_template (PromptTemplate): The prompt template to be used.
-        llm_chain (LLMChain): The LLM chain to be used.
-        knowledge_index (FAISS): The FAISS index to be used.
-
     """
-    def __init__(
-        self,
-        llm_model_id: str,
-        embedding_model_id: str,
-        index_repo_id: str,
-        prompt_template: str,
-        use_docs_for_context: bool = True,
-        num_relevant_docs: int = 3,
-        add_sources_to_response: bool = True,
-        use_messages_for_context: bool = True,
-        first_stage_docs: int = 50,
-        debug: bool = False
-    ):
+    def __init__(self, config: Config):
         super().__init__()
-        self.…
-        self.…
-        self.…
-        self.…
-        self.…
-        self.…
-        self.…
+        self.config = config
+        self.question_answering_model_id=config.question_answering_model_id
+        self.embedding_model_id=config.embedding_model_id
+        self.index_repo_id=config.index_repo_id
+        self.prompt_template=config.prompt_template
+        self.use_docs_for_context=config.use_docs_for_context
+        self.num_relevant_docs=config.num_relevant_docs
+        self.add_sources_to_response=config.add_sources_to_response
+        self.use_messages_for_context=config.use_messages_in_context
+        self.debug=config.debug
+
+        self.first_stage_docs: int = 50
 
         prompt = PromptTemplate(
-            template=prompt_template,
+            template=self.prompt_template,
             input_variables=['question', 'context']
         )
-        self.llm_model = …
+        self.llm_model = self._get_model()
         self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)
 
         if self.use_docs_for_context:
-            logger.info(f'Downloading {index_repo_id}')
+            logger.info(f'Downloading {self.index_repo_id}')
             snapshot_download(
-                repo_id=index_repo_id,
+                repo_id=self.index_repo_id,
                 allow_patterns=['*.faiss', '*.pkl'],
                 repo_type='dataset',
                 local_dir='indexes/run/'
@@ -200,7 +178,7 @@ class QAEngine():
             embed_instruction = 'Represent the Hugging Face library documentation'
             query_instruction = 'Query the most relevant piece of information from the Hugging Face documentation'
             embedding_model = HuggingFaceInstructEmbeddings(
-                model_name=embedding_model_id,
+                model_name=self.embedding_model_id,
                 embed_instruction=embed_instruction,
                 query_instruction=query_instruction
             )
@@ -209,27 +187,22 @@ class QAEngine():
         self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
 
 
-    …
-    …
-        if 'local_models/' in llm_model_id:
+    def _get_model(self):
+        if 'local_models/' in self.question_answering_model_id:
             logger.info('using local binary model')
-            return LocalBinaryModel(
-                …
-            )
-        elif 'api_models/' in llm_model_id:
+            return LocalBinaryModel(self.config)
+        elif 'api_models/' in self.question_answering_model_id:
             logger.info('using api served model')
             return APIServedModel(
-                model_url=…
+                model_url=self.question_answering_model_id.replace('api_models/', ''),
                 debug=self.debug
             )
-        elif …
+        elif self.question_answering_model_id == 'mock':
             logger.info('using mock model')
             return MockLocalBinaryModel()
         else:
             logger.info('using transformers pipeline model')
-            return TransformersPipelineModel(
-                model_id=llm_model_id
-            )
+            return TransformersPipelineModel(self.config)
 
 
     @staticmethod
@@ -245,7 +218,8 @@ class QAEngine():
         Preprocess the answer by removing unnecessary sequences and stop sequences.
         '''
         SEQUENCES_TO_REMOVE = [
-            'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'
+            'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]',
+            '<context>', '<\context>', '<question>', '<\question>',
         ]
         SEQUENCES_TO_STOP = [
             'User:', 'You:', 'Question:'
@@ -296,9 +270,8 @@ class QAEngine():
             )
         ]
         relevant_docs = relevant_docs[:self.num_relevant_docs]
-        context += '\…
-        …
-            context += f'\n\n<DOCUMENT_{i}>\n {doc.page_content} \n</DOCUMENT_{i}>'
+        context += '\nExtracted documents:\n'
+        context += ''.join([doc.page_content for doc in relevant_docs])
         metadata = [doc.metadata for doc in relevant_docs]
         response.set_sources(sources=[str(m['source']) for m in metadata])
 
@@ -314,7 +287,6 @@ class QAEngine():
         sep = '\n' + '-' * 100
         logger.info(f'question len: {len(question)} {sep}')
         logger.info(f'question: {question} {sep}')
-        logger.info(f'question processed: {question} {sep}')
         logger.info(f'answer len: {len(response.get_answer())} {sep}')
         logger.info(f'answer original: {answer} {sep}')
         logger.info(f'answer postprocessed: {response.get_answer()} {sep}')

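QAEngine._get_model now dispatches on the QUESTION_ANSWERING_MODEL_ID prefix alone. A summary sketch (the IDs in the comments are illustrative; only 'mock' appears in this commit's .env.example):

from qa_engine import Config, QAEngine

# 'local_models/<file>'  -> LocalBinaryModel (llama.cpp binary expected under qa_engine/)
# 'api_models/<url>'     -> APIServedModel (prefix stripped to obtain the URL)
# 'mock'                 -> MockLocalBinaryModel
# anything else          -> TransformersPipelineModel (treated as a Hugging Face model ID)
config = Config()
qa_engine = QAEngine(config=config)   # _get_model() picks the backend from the ID
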
requirements.txt
CHANGED
@@ -26,4 +26,3 @@ InstructorEmbedding==1.0.0
 faiss_cpu==1.7.3
 uvicorn==0.22.0
 pytest==7.3.1
-google-cloud-bigquery==3.17.2