rlancemartin committed • Commit 424d53d • Parent(s): 4d476f6
Create app

Files changed:
- README.md +64 -4
- app.py +483 -0
- img/diagnostic.jpg +0 -0
- requirements.txt +21 -0
- text_utils.py +120 -0
README.md
CHANGED
@@ -1,6 +1,68 @@
+# `Auto-evaluator` :brain: :memo:
+
+This is a lightweight evaluation tool for question-answering using `Langchain` to:
+
+- Ask the user to input a set of documents of interest
+
+- Apply an LLM (`GPT-3.5-turbo`) to auto-generate `question`-`answer` pairs from these docs
+
+- Generate a question-answering chain with a specified set of UI-chosen configurations
+
+- Use the chain to generate a response to each `question`
+
+- Use an LLM (`GPT-3.5-turbo`) to score the response relative to the `answer`
+
+- Explore scoring across various chain configurations
+
+**Run as Streamlit app**
+
+`pip install -r requirements.txt`
+
+`streamlit run app.py`
+
+**Inputs**
+
+`num_eval_questions` - Number of questions to auto-generate (if the user does not supply an eval set)
+
+`split_method` - Method for text splitting
+
+`chunk_chars` - Chunk size for text splitting
+
+`overlap` - Chunk overlap for text splitting
+
+`embeddings` - Embedding method for chunks
+
+`retriever_type` - Chunk retrieval method
+
+`num_neighbors` - Neighbors for retrieval
+
+`model` - LLM for summarization of retrieved chunks
+
+`grade_prompt` - Prompt choice for model self-grading
+
+**Blog**
+
+https://blog.langchain.dev/auto-eval-of-question-answering-tasks/
+
+**UI**
+
+![image](https://user-images.githubusercontent.com/122662504/233218347-de10cf41-6230-47a7-aa9e-8ab01673b87a.png)
+
+**Hosted app**
+
+See:
+https://github.com/langchain-ai/auto-evaluator
+
+And:
+https://autoevaluator.langchain.com/
+
+**Disclaimer**
+
+```You will need an OpenAI API key with access to `GPT-4` and an Anthropic API key to take advantage of all of the default dashboard model settings. However, additional models (e.g., from Hugging Face) can be easily added to the app.```
+
 ---
 title: Auto Evaluator
-emoji:
+emoji: :brain: :memo:
 colorFrom: blue
 colorTo: yellow
 sdk: streamlit
@@ -8,6 +70,4 @@ sdk_version: 1.19.0
 app_file: app.py
 pinned: false
 license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
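The workflow the new README describes reduces to three chains: auto-generate `question`/`answer` pairs from the docs, answer each question with a retrieval-QA chain built from the chosen settings, and grade the chain's answer against the generated reference answer. Below is a minimal sketch of that loop, assuming the same langchain 0.0.141-era APIs that app.py imports; the input file name and the chunking/retrieval settings are illustrative, not taken from the repo.

```python
# Minimal sketch of the README workflow (illustrative, not part of the Space).
# Assumes OPENAI_API_KEY is set; "my_doc.txt" is a hypothetical input file.
from langchain.chains import QAGenerationChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.evaluation.qa import QAEvalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

docs_text = open("my_doc.txt").read()

# 1. Auto-generate question-answer pairs from a slice of the docs
qa_gen = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
eval_set = qa_gen.run(docs_text[:3000])  # list of {"question": ..., "answer": ...} dicts

# 2. Build a retrieval-QA chain over the split docs (chunk_chars=1000, overlap=100, k=3)
splits = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_text(docs_text)
retriever = FAISS.from_texts(splits, OpenAIEmbeddings()).as_retriever(k=3)
qa_chain = RetrievalQA.from_chain_type(ChatOpenAI(temperature=0), chain_type="stuff",
                                       retriever=retriever, input_key="question")

# 3. Answer each generated question and grade the answer against the reference answer
predictions = [qa_chain(pair) for pair in eval_set]
grader = QAEvalChain.from_llm(llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0))
grades = grader.evaluate(eval_set, predictions, question_key="question", prediction_key="result")
```

app.py below wraps this loop in a Streamlit UI, exposing the chunking, retriever, embedding, model, and grading-prompt choices listed under **Inputs** and logging each run as an experiment.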
app.py
ADDED
@@ -0,0 +1,483 @@
import os
import json
import time
from typing import List
import faiss
import pypdf
import random
import itertools
import text_utils
import pandas as pd
import altair as alt
import streamlit as st
from io import StringIO
from llama_index import Document
from langchain.llms import Anthropic
from langchain import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from llama_index import LangchainEmbedding
from langchain.chat_models import ChatOpenAI
from langchain.retrievers import SVMRetriever
from langchain.chains import QAGenerationChain
from langchain.retrievers import TFIDFRetriever
from langchain.evaluation.qa import QAEvalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from gpt_index import LLMPredictor, ServiceContext, GPTFaissIndex
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from text_utils import GRADE_DOCS_PROMPT, GRADE_ANSWER_PROMPT, GRADE_DOCS_PROMPT_FAST, GRADE_ANSWER_PROMPT_FAST, GRADE_ANSWER_PROMPT_BIAS_CHECK, GRADE_ANSWER_PROMPT_OPENAI

# Keep dataframe in memory to accumulate experimental results
if "existing_df" not in st.session_state:
    summary = pd.DataFrame(columns=['chunk_chars',
                                    'overlap',
                                    'split',
                                    'model',
                                    'retriever',
                                    'embedding',
                                    'num_neighbors',
                                    'Latency',
                                    'Retrieval score',
                                    'Answer score'])
    st.session_state.existing_df = summary
else:
    summary = st.session_state.existing_df


@st.cache_data
def load_docs(files: List) -> str:
    """
    Load docs from files
    @param files: list of files to load
    @return: string of all docs concatenated
    """

    st.info("`Reading doc ...`")
    all_text = ""
    for file_path in files:
        file_extension = os.path.splitext(file_path.name)[1]
        if file_extension == ".pdf":
            pdf_reader = pypdf.PdfReader(file_path)
            file_content = ""
            for page in pdf_reader.pages:
                file_content += page.extract_text()
            file_content = text_utils.clean_pdf_text(file_content)
            all_text += file_content
        elif file_extension == ".txt":
            stringio = StringIO(file_path.getvalue().decode("utf-8"))
            file_content = stringio.read()
            all_text += file_content
        else:
            st.warning('Please provide txt or pdf.', icon="⚠️")
    return all_text


@st.cache_data
def generate_eval(text: str, num_questions: int, chunk: int):
    """
    Generate eval set
    @param text: text to generate eval set from
    @param num_questions: number of questions to generate
    @param chunk: chunk size to draw question from in the doc
    @return: eval set as JSON list
    """
    st.info("`Generating eval set ...`")
    n = len(text)
    starting_indices = [random.randint(0, n - chunk) for _ in range(num_questions)]
    sub_sequences = [text[i:i + chunk] for i in starting_indices]
    chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
        except Exception:
            st.warning('Error generating question %s.' % str(i + 1), icon="⚠️")
    eval_set_full = list(itertools.chain.from_iterable(eval_set))
    return eval_set_full


@st.cache_resource
def split_texts(text, chunk_size: int, overlap, split_method: str):
    """
    Split text into chunks
    @param text: text to split
    @param chunk_size: chunk size in characters
    @param overlap: overlap between chunks in characters
    @param split_method: text splitter to use
    @return: list of str splits
    """
    st.info("`Splitting doc ...`")
    if split_method == "RecursiveTextSplitter":
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                       chunk_overlap=overlap)
    elif split_method == "CharacterTextSplitter":
        text_splitter = CharacterTextSplitter(separator=" ",
                                              chunk_size=chunk_size,
                                              chunk_overlap=overlap)
    else:
        st.warning("`Split method not recognized. Using RecursiveCharacterTextSplitter`", icon="⚠️")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                       chunk_overlap=overlap)

    split_text = text_splitter.split_text(text)
    return split_text


@st.cache_resource
def make_llm(model_version: str):
    """
    Make LLM from model version
    @param model_version: model_version
    @return: LLM
    """
    if (model_version == "gpt-3.5-turbo") or (model_version == "gpt-4"):
        chosen_model = ChatOpenAI(model_name=model_version, temperature=0)
    elif model_version == "anthropic":
        chosen_model = Anthropic(temperature=0)
    elif model_version == "flan-t5-xl":
        chosen_model = HuggingFaceHub(repo_id="google/flan-t5-xl", model_kwargs={"temperature": 0, "max_length": 64})
    else:
        st.warning("`Model version not recognized. Using gpt-3.5-turbo`", icon="⚠️")
        chosen_model = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
    return chosen_model


@st.cache_resource
def make_retriever(splits, retriever_type, embedding_type, num_neighbors, _llm):
    """
    Make document retriever
    @param splits: list of str splits
    @param retriever_type: retriever type
    @param embedding_type: embedding type
    @param num_neighbors: number of neighbors for retrieval
    @param _llm: model
    @return: retriever
    """
    st.info("`Making retriever ...`")
    # Set embeddings
    if embedding_type == "OpenAI":
        embedding = OpenAIEmbeddings()
    elif embedding_type == "HuggingFace":
        embedding = HuggingFaceEmbeddings()
    else:
        st.warning("`Embedding type not recognized. Using OpenAI`", icon="⚠️")
        embedding = OpenAIEmbeddings()

    # Select retriever
    if retriever_type == "similarity-search":
        try:
            vector_store = FAISS.from_texts(splits, embedding)
        except ValueError:
            st.warning("`Error using OpenAI embeddings (disallowed TikToken token in the text). Using HuggingFace.`",
                       icon="⚠️")
            vector_store = FAISS.from_texts(splits, HuggingFaceEmbeddings())
        retriever_obj = vector_store.as_retriever(k=num_neighbors)
    elif retriever_type == "SVM":
        retriever_obj = SVMRetriever.from_texts(splits, embedding)
    elif retriever_type == "TF-IDF":
        retriever_obj = TFIDFRetriever.from_texts(splits)
    elif retriever_type == "Llama-Index":
        documents = [Document(t, LangchainEmbedding(embedding)) for t in splits]
        llm_predictor = LLMPredictor(_llm)
        context = ServiceContext.from_defaults(chunk_size_limit=512, llm_predictor=llm_predictor)
        d = 1536
        faiss_index = faiss.IndexFlatL2(d)
        retriever_obj = GPTFaissIndex.from_documents(documents, faiss_index=faiss_index, service_context=context)
    else:
        st.warning("`Retriever type not recognized. Using SVM`", icon="⚠️")
        retriever_obj = SVMRetriever.from_texts(splits, embedding)
    return retriever_obj


def make_chain(llm, retriever, retriever_type: str) -> RetrievalQA:
    """
    Make chain
    @param llm: model
    @param retriever: retriever
    @param retriever_type: retriever type
    @return: chain (or return retriever for Llama-Index)
    """
    st.info("`Making chain ...`")
    if retriever_type == "Llama-Index":
        qa = retriever
    else:
        qa = RetrievalQA.from_chain_type(llm,
                                         chain_type="stuff",
                                         retriever=retriever,
                                         input_key="question")
    return qa


def grade_model_answer(predicted_dataset: List, predictions: List, grade_answer_prompt: str) -> List:
    """
    Grades the distilled answer based on ground truth and model predictions.
    @param predicted_dataset: A list of dictionaries containing ground truth questions and answers.
    @param predictions: A list of dictionaries containing model predictions for the questions.
    @param grade_answer_prompt: The prompt level for the grading. Either "Fast" or "Full".
    @return: A list of scores for the distilled answers.
    """
    # Grade the distilled answer
    st.info("`Grading model answer ...`")
    # Set the grading prompt based on the grade_answer_prompt parameter
    if grade_answer_prompt == "Fast":
        prompt = GRADE_ANSWER_PROMPT_FAST
    elif grade_answer_prompt == "Descriptive w/ bias check":
        prompt = GRADE_ANSWER_PROMPT_BIAS_CHECK
    elif grade_answer_prompt == "OpenAI grading prompt":
        prompt = GRADE_ANSWER_PROMPT_OPENAI
    else:
        prompt = GRADE_ANSWER_PROMPT

    # Create an evaluation chain
    eval_chain = QAEvalChain.from_llm(
        llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
        prompt=prompt
    )

    # Evaluate the predictions and ground truth using the evaluation chain
    graded_outputs = eval_chain.evaluate(
        predicted_dataset,
        predictions,
        question_key="question",
        prediction_key="result"
    )

    return graded_outputs


def grade_model_retrieval(gt_dataset: List, predictions: List, grade_docs_prompt: str):
    """
    Grades the relevance of retrieved documents based on ground truth and model predictions.
    @param gt_dataset: list of dictionaries containing ground truth questions and answers.
    @param predictions: list of dictionaries containing model predictions for the questions
    @param grade_docs_prompt: prompt level for the grading. Either "Fast" or "Full"
    @return: list of scores for the retrieved documents.
    """
    # Grade the docs retrieval
    st.info("`Grading relevance of retrieved docs ...`")

    # Set the grading prompt based on the grade_docs_prompt parameter
    prompt = GRADE_DOCS_PROMPT_FAST if grade_docs_prompt == "Fast" else GRADE_DOCS_PROMPT

    # Create an evaluation chain
    eval_chain = QAEvalChain.from_llm(
        llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
        prompt=prompt
    )

    # Evaluate the predictions and ground truth using the evaluation chain
    graded_outputs = eval_chain.evaluate(
        gt_dataset,
        predictions,
        question_key="question",
        prediction_key="result"
    )
    return graded_outputs


def run_evaluation(chain, retriever, eval_set, grade_prompt, retriever_type, num_neighbors):
    """
    Runs evaluation on a model's performance on a given evaluation dataset.
    @param chain: Model chain used for answering questions
    @param retriever: Document retriever used for retrieving relevant documents
    @param eval_set: List of dictionaries containing questions and corresponding ground truth answers
    @param grade_prompt: String prompt used for grading model's performance
    @param retriever_type: String specifying the type of retriever used
    @param num_neighbors: Number of neighbors to retrieve using the retriever
    @return: A tuple of four items:
    - answers_grade: A dictionary containing scores for the model's answers.
    - retrieval_grade: A dictionary containing scores for the model's document retrieval.
    - latencies_list: A list of latencies in seconds for each question answered.
    - predictions_list: A list of dictionaries containing the model's predicted answers and relevant documents for each question.
    """
    st.info("`Running evaluation ...`")
    predictions_list = []
    retrieved_docs = []
    gt_dataset = []
    latencies_list = []

    for data in eval_set:

        # Get answer and log latency
        start_time = time.time()
        if retriever_type != "Llama-Index":
            predictions_list.append(chain(data))
        elif retriever_type == "Llama-Index":
            answer = chain.query(data["question"], similarity_top_k=num_neighbors, response_mode="tree_summarize",
                                 use_async=True)
            predictions_list.append({"question": data["question"], "answer": data["answer"], "result": answer.response})
        gt_dataset.append(data)
        end_time = time.time()
        elapsed_time = end_time - start_time
        latencies_list.append(elapsed_time)

        # Retrieve docs
        retrieved_doc_text = ""
        if retriever_type == "Llama-Index":
            for i, doc in enumerate(answer.source_nodes):
                retrieved_doc_text += "Doc %s: " % str(i + 1) + doc.node.text + " "

        else:
            docs = retriever.get_relevant_documents(data["question"])
            for i, doc in enumerate(docs):
                retrieved_doc_text += "Doc %s: " % str(i + 1) + doc.page_content + " "

        retrieved = {"question": data["question"], "answer": data["answer"], "result": retrieved_doc_text}
        retrieved_docs.append(retrieved)

    # Grade
    answers_grade = grade_model_answer(gt_dataset, predictions_list, grade_prompt)
    retrieval_grade = grade_model_retrieval(gt_dataset, retrieved_docs, grade_prompt)
    return answers_grade, retrieval_grade, latencies_list, predictions_list


# Auth
st.sidebar.image("img/diagnostic.jpg")

with st.sidebar.form("user_input"):

    oai_api_key = st.sidebar.text_input("`OpenAI API Key:`", type="password")
    os.environ["OPENAI_API_KEY"] = oai_api_key
    ant_api_key = st.sidebar.text_input("`(Optional) Anthropic API Key:`", type="password")
    os.environ["ANTHROPIC_API_KEY"] = ant_api_key

    num_eval_questions = st.select_slider("`Number of eval questions`",
                                          options=[1, 5, 10, 15, 20], value=5)

    chunk_chars = st.select_slider("`Choose chunk size for splitting`",
                                   options=[500, 750, 1000, 1500, 2000], value=1000)

    overlap = st.select_slider("`Choose overlap for splitting`",
                               options=[0, 50, 100, 150, 200], value=100)

    split_method = st.radio("`Split method`",
                            ("RecursiveTextSplitter",
                             "CharacterTextSplitter"),
                            index=0)

    model = st.radio("`Choose model`",
                     ("gpt-3.5-turbo",
                      "gpt-4",
                      "anthropic",
                      "flan-t5-xl"),
                     index=0)

    retriever_type = st.radio("`Choose retriever`",
                              ("TF-IDF",
                               "SVM",
                               "Llama-Index",
                               "similarity-search"),
                              index=3)

    num_neighbors = st.select_slider("`Choose # chunks to retrieve`",
                                     options=[3, 4, 5, 6, 7, 8])

    embeddings = st.radio("`Choose embeddings`",
                          ("HuggingFace",
                           "OpenAI"),
                          index=1)

    grade_prompt = st.radio("`Grading style prompt`",
                            ("Fast",
                             "Descriptive",
                             "Descriptive w/ bias check",
                             "OpenAI grading prompt"),
                            index=0)

    submitted = st.form_submit_button("Submit evaluation")

st.sidebar.write("`By:` [@RLanceMartin](https://twitter.com/RLanceMartin)")

# App
st.header("`Auto-evaluator`")
st.info(
    "`I am an evaluation tool for question-answering. Given documents, I will auto-generate a question-answer eval "
    "set and evaluate using the selected chain settings. Experiments with different configurations are logged. "
    "Optionally, provide your own eval set (as a JSON, see docs/karpathy-pod-eval.json for an example).`")

with st.form(key='file_inputs'):
    uploaded_file = st.file_uploader("`Please upload a file to evaluate (.txt or .pdf):` ",
                                     type=['pdf', 'txt'],
                                     accept_multiple_files=True)

    uploaded_eval_set = st.file_uploader("`[Optional] Please upload eval set (.json):` ",
                                         type=['json'],
                                         accept_multiple_files=False)

    submitted = st.form_submit_button("Submit files")

if uploaded_file:

    # Load docs
    text = load_docs(uploaded_file)
    # Generate num_eval_questions questions, each from context of 3k chars randomly selected
    if not uploaded_eval_set:
        eval_set = generate_eval(text, num_eval_questions, 3000)
    else:
        eval_set = json.loads(uploaded_eval_set.read())
    # Split text
    splits = split_texts(text, chunk_chars, overlap, split_method)
    # Make LLM
    llm = make_llm(model)
    # Make vector DB
    retriever = make_retriever(splits, retriever_type, embeddings, num_neighbors, llm)
    # Make chain
    qa_chain = make_chain(llm, retriever, retriever_type)
    # Grade model
    graded_answers, graded_retrieval, latency, predictions = run_evaluation(qa_chain, retriever, eval_set, grade_prompt,
                                                                            retriever_type, num_neighbors)

    # Assemble outputs
    d = pd.DataFrame(predictions)
    d['answer score'] = [g['text'] for g in graded_answers]
    d['docs score'] = [g['text'] for g in graded_retrieval]
    d['latency'] = latency

    # Summary statistics
    mean_latency = d['latency'].mean()
    correct_answer_count = len([text for text in d['answer score'] if "INCORRECT" not in text])
    correct_docs_count = len([text for text in d['docs score'] if "Context is relevant: True" in text])
    percentage_answer = (correct_answer_count / len(graded_answers)) * 100
    percentage_docs = (correct_docs_count / len(graded_retrieval)) * 100

    st.subheader("`Run Results`")
    st.info(
        "`I will grade the chain based on: 1/ the relevance of the retrieved documents relative to the question and 2/ "
        "the summarized answer relative to the ground truth answer. You can see (and change) the prompts used for "
        "grading in text_utils`")
    st.dataframe(data=d, use_container_width=True)

    # Accumulate results
    st.subheader("`Aggregate Results`")
    st.info(
        "`Retrieval and answer scores are the percentage of retrieved documents deemed relevant by the LLM grader ("
        "relative to the question) and the percentage of summarized answers deemed relevant (relative to the ground "
        "truth answer), respectively. The size of the point corresponds to the latency (in seconds) of retrieval + "
        "answer summarization (larger circle = slower).`")
    new_row = pd.DataFrame({'chunk_chars': [chunk_chars],
                            'overlap': [overlap],
                            'split': [split_method],
                            'model': [model],
                            'retriever': [retriever_type],
                            'embedding': [embeddings],
                            'num_neighbors': [num_neighbors],
                            'Latency': [mean_latency],
                            'Retrieval score': [percentage_docs],
                            'Answer score': [percentage_answer]})
    summary = pd.concat([summary, new_row], ignore_index=True)
    st.dataframe(data=summary, use_container_width=True)
    st.session_state.existing_df = summary

    # Dataframe for visualization
    show = summary.reset_index().copy()
    show.columns = ['expt number', 'chunk_chars', 'overlap',
                    'split', 'model', 'retriever', 'embedding', 'num_neighbors', 'Latency', 'Retrieval score',
                    'Answer score']
    show['expt number'] = show['expt number'].apply(lambda x: "Expt #: " + str(x + 1))
    c = alt.Chart(show).mark_circle().encode(x='Retrieval score',
                                             y='Answer score',
                                             size=alt.Size('Latency'),
                                             color='expt number',
                                             tooltip=['expt number', 'Retrieval score', 'Latency', 'Answer score'])
    st.altair_chart(c, use_container_width=True, theme="streamlit")
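For reference, the grading helpers above exchange plain lists of dicts, and the aggregate scores at the bottom of app.py are simple string checks over the grader's free-text output. A small self-contained sketch of those shapes follows; the example values are made up for illustration.

```python
# Illustrative data shapes for grade_model_answer / grade_model_retrieval (values made up).
gt_dataset = [{"question": "What does the tool evaluate?",
               "answer": "Question answering over user-supplied documents."}]
predictions = [{"question": "What does the tool evaluate?",
                "answer": "Question answering over user-supplied documents.",
                "result": "It evaluates QA chains over uploaded documents."}]

# QAEvalChain.evaluate returns one dict per example; app.py reads the grade from its "text" field.
graded_answers = [{"text": "GRADE: CORRECT\nThe student answer matches the true answer."}]
graded_retrieval = [{"text": "Context is relevant: True. The retrieved chunk supports the answer."}]

# Same aggregation app.py performs for the "Aggregate Results" table.
correct_answers = len([g["text"] for g in graded_answers if "INCORRECT" not in g["text"]])
relevant_docs = len([g["text"] for g in graded_retrieval if "Context is relevant: True" in g["text"]])
answer_score = 100 * correct_answers / len(graded_answers)      # 100.0
retrieval_score = 100 * relevant_docs / len(graded_retrieval)   # 100.0
```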
img/diagnostic.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,21 @@
pandas==1.4.3
fastapi==0.85.2
langchain==0.0.141
python-multipart==0.0.6
uvicorn==0.18.3
openai==0.27.0
tiktoken==0.3.1
faiss-cpu==1.7.3
huggingface-hub==0.12.0
anthropic==0.2.6
llama-cpp-python==0.1.32
pypdf==3.7.1
filetype==1.2.0
altair==4.2.2
tokenizers==0.13.3
sentence-transformers==2.2.2
scikit-learn==1.2.1
llama-index==0.4.35.post1
streamlit==1.21.0
gpt-index==0.5.16
faiss-cpu==1.7.3
text_utils.py
ADDED
@@ -0,0 +1,120 @@
import re
from langchain.prompts import PromptTemplate


def clean_pdf_text(text: str) -> str:
    """Cleans text extracted from a PDF file."""
    # TODO: Remove References/Bibliography section.
    return remove_citations(text)


def remove_citations(text: str) -> str:
    """Removes in-text citations from a string."""
    # (Author, Year)
    text = re.sub(r'\([A-Za-z0-9,.\s]+\s\d{4}\)', '', text)
    # [1], [2], [3-5], [3, 33, 49, 51]
    text = re.sub(r'\[[0-9,-]+(,\s[0-9,-]+)*\]', '', text)
    return text


template = """You are a teacher grading a quiz.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.

Example Format:
QUESTION: question here
STUDENT ANSWER: student's answer here
TRUE ANSWER: true answer here
GRADE: CORRECT or INCORRECT here

Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin!

QUESTION: {query}
STUDENT ANSWER: {result}
TRUE ANSWER: {answer}
GRADE:

And explain why the STUDENT ANSWER is correct or incorrect.
"""

GRADE_ANSWER_PROMPT = PromptTemplate(input_variables=["query", "result", "answer"], template=template)

template = """You are a teacher grading a quiz.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.
You are also asked to identify potential sources of bias in the question and in the true answer.

Example Format:
QUESTION: question here
STUDENT ANSWER: student's answer here
TRUE ANSWER: true answer here
GRADE: CORRECT or INCORRECT here

Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin!

QUESTION: {query}
STUDENT ANSWER: {result}
TRUE ANSWER: {answer}
GRADE:

And explain why the STUDENT ANSWER is correct or incorrect, identify potential sources of bias in the QUESTION, and identify potential sources of bias in the TRUE ANSWER.
"""

GRADE_ANSWER_PROMPT_BIAS_CHECK = PromptTemplate(input_variables=["query", "result", "answer"], template=template)

template = """You are assessing a submitted student answer to a question relative to the true answer based on the provided criteria:

***
QUESTION: {query}
***
STUDENT ANSWER: {result}
***
TRUE ANSWER: {answer}
***
Criteria:
relevance: Is the submission referring to a real quote from the text?"
conciseness: Is the answer concise and to the point?"
correct: Is the answer correct?"
***
Does the submission meet the criterion? First, write out in a step by step manner your reasoning about the criterion to be sure that your conclusion is correct. Avoid simply stating the correct answers at the outset. Then print the "CORRECT" or "INCORRECT" (without quotes or punctuation) on its own line corresponding to the correct answer.
Reasoning:
"""

GRADE_ANSWER_PROMPT_OPENAI = PromptTemplate(input_variables=["query", "result", "answer"], template=template)

template = """You are a teacher grading a quiz.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.

Example Format:
QUESTION: question here
STUDENT ANSWER: student's answer here
TRUE ANSWER: true answer here
GRADE: CORRECT or INCORRECT here

Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin!

QUESTION: {query}
STUDENT ANSWER: {result}
TRUE ANSWER: {answer}
GRADE:"""

GRADE_ANSWER_PROMPT_FAST = PromptTemplate(input_variables=["query", "result", "answer"], template=template)

template = """
Given the question: \n
{query}
Decide if the following retrieved context is relevant: \n
{result}
Answer in the following format: \n
"Context is relevant: True or False." \n
And explain why it supports or does not support the correct answer: {answer}"""

GRADE_DOCS_PROMPT = PromptTemplate(input_variables=["query", "result", "answer"], template=template)

template = """
Given the question: \n
{query}
Decide if the following retrieved context is relevant to the {answer}: \n
{result}
Answer in the following format: \n
"Context is relevant: True or False." \n """

GRADE_DOCS_PROMPT_FAST = PromptTemplate(input_variables=["query", "result", "answer"], template=template)
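A quick, hypothetical illustration of what `clean_pdf_text` / `remove_citations` strip from extracted PDF text (the input sentence below is made up):

```python
# Hypothetical usage of text_utils.clean_pdf_text; the sentence is illustrative only.
from text_utils import clean_pdf_text

raw = "Chain-of-thought prompting improves reasoning (Wei et al., 2022) and is widely used [3, 33, 49]."
print(clean_pdf_text(raw))
# Chain-of-thought prompting improves reasoning  and is widely used .
```

Both parenthetical author-year citations and bracketed numeric citations are removed; the surrounding whitespace is left untouched, which is why the output keeps a double space and a stray space before the period.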