Spaces:
Runtime error
Runtime error
KonradSzafer
committed on
Commit
•
b7068fd
1
Parent(s):
cf57696
question and answer postprocessing
Browse files- benchmark/__main__.py +1 -0
- qa_engine/qa_engine.py +30 -1
benchmark/__main__.py
CHANGED
@@ -33,6 +33,7 @@ def main():
|
|
33 |
|
34 |
wandb.init(
|
35 |
project='HF-Docs-QA',
|
|
|
36 |
name=f'{config.question_answering_model_id} - {config.embedding_model_id} - {config.index_repo_id}',
|
37 |
mode='run', # run/disabled
|
38 |
config=filtered_config
|
|
|
33 |
|
34 |
wandb.init(
|
35 |
project='HF-Docs-QA',
|
36 |
+
entity='hf-qa-bot',
|
37 |
name=f'{config.question_answering_model_id} - {config.embedding_model_id} - {config.index_repo_id}',
|
38 |
mode='run', # run/disabled
|
39 |
config=filtered_config
|
qa_engine/qa_engine.py
CHANGED
@@ -228,6 +228,33 @@ class QAEngine():
|
|
228 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
229 |
|
230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
def get_response(self, question: str, messages_context: str = '') -> Response:
|
232 |
"""
|
233 |
Generate an answer to the specified question.
|
@@ -271,7 +298,9 @@ class QAEngine():
|
|
271 |
response.set_sources(sources=[str(m['source']) for m in metadata])
|
272 |
|
273 |
logger.info('Running LLM chain')
|
274 |
-
|
|
|
|
|
275 |
response.set_answer(answer)
|
276 |
logger.info('Received answer')
|
277 |
|
|
|
228 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
229 |
|
230 |
|
231 |
+
@staticmethod
|
232 |
+
def _preprocess_question(question: str) -> str:
|
233 |
+
if question[-1] != '?':
|
234 |
+
question += '?'
|
235 |
+
return question
|
236 |
+
|
237 |
+
|
238 |
+
@staticmethod
|
239 |
+
def _postprocess_answer(answer: str) -> str:
|
240 |
+
'''
|
241 |
+
Postprocess the answer by removing unnecessary sequences and stop sequences.
|
242 |
+
'''
|
243 |
+
REMOVE_SEQUENCES = [
|
244 |
+
'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'
|
245 |
+
]
|
246 |
+
STOP_SEQUENCES = [
|
247 |
+
'\nUser:', '\nYou:'
|
248 |
+
]
|
249 |
+
for seq in REMOVE_SEQUENCES:
|
250 |
+
answer = answer.replace(seq, '')
|
251 |
+
for seq in STOP_SEQUENCES:
|
252 |
+
if seq in answer:
|
253 |
+
answer = answer[:answer.index(seq)]
|
254 |
+
answer = answer.strip()
|
255 |
+
return answer
|
256 |
+
|
257 |
+
|
258 |
def get_response(self, question: str, messages_context: str = '') -> Response:
|
259 |
"""
|
260 |
Generate an answer to the specified question.
|
|
|
298 |
response.set_sources(sources=[str(m['source']) for m in metadata])
|
299 |
|
300 |
logger.info('Running LLM chain')
|
301 |
+
question_processed = QAEngine._preprocess_question(question)
|
302 |
+
answer = self.llm_chain.run(question=question_processed, context=context)
|
303 |
+
answer = QAEngine._postprocess_answer(answer)
|
304 |
response.set_answer(answer)
|
305 |
logger.info('Received answer')
|
306 |
|