Spaces:
Sleeping
Sleeping
File size: 9,135 Bytes
e6c245e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 |
import json
import os
import tempfile

import pandas as pd
import pymupdf4llm
import streamlit as st
from groq import Groq
from sklearn.model_selection import train_test_split
# Seed every session-state slot with its default on first run; reruns keep
# whatever values are already stored.
for _key, _default in (
    ('train_df', None),
    ('val_df', None),
    ('generated', False),
    ('previous_upload_state', False),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
def reset_session_state():
    """Drop any generated datasets and clear the 'generated' flag."""
    for key in ('train_df', 'val_df'):
        st.session_state[key] = None
    st.session_state['generated'] = False
def parse_pdf(uploaded_file) -> str:
    """Extract the text of an uploaded PDF as Markdown.

    Args:
        uploaded_file: Streamlit UploadedFile-like object exposing getvalue().

    Returns:
        The PDF content rendered to Markdown by pymupdf4llm.
    """
    # delete=False is required because pymupdf4llm reopens the file by name
    # (the handle must be closed first on Windows). The original code never
    # removed the file, leaking one temp file per upload; unlink it ourselves.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        tmp_path = tmp_file.name
    try:
        return pymupdf4llm.to_markdown(tmp_path)
    finally:
        os.unlink(tmp_path)
def generate_qa_pairs(text: str, api_key: str, model: str, num_pairs: int, context: str) -> pd.DataFrame:
client = Groq(api_key=api_key)
prompt = f"""
Given the following text, generate {num_pairs} question-answer pairs:
{text}
Format each pair as:
Q: [Question]
A: [Answer]
Ensure the questions are diverse and cover different aspects of the text.
"""
try:
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful assistant that generates question-answer pairs based on given text."},
{"role": "user", "content": prompt}
]
)
qa_text = response.choices[0].message.content
qa_pairs = []
for pair in qa_text.split('\n\n'):
if pair.startswith('Q:') and 'A:' in pair:
question, answer = pair.split('A:')
question = question.replace('Q:', '').strip()
answer = answer.strip()
qa_pairs.append({
'Question': question,
'Answer': answer,
'Context': context
})
return pd.DataFrame(qa_pairs)
except Exception as e:
st.error(f"Error generating QA pairs: {str(e)}")
return pd.DataFrame()
def create_jsonl_content(df: pd.DataFrame, system_content: str) -> str:
    """Render a QA DataFrame as chat-format JSONL, one conversation per line.

    Each line is a JSON object {"messages": [system, user, assistant]} where
    the user turn is the Question and the assistant turn is the Answer.
    Non-ASCII text is emitted verbatim (ensure_ascii=False).
    """
    def _row_to_line(row) -> str:
        conversation = {
            "messages": [
                {"role": "system", "content": system_content},
                {"role": "user", "content": row['Question']},
                {"role": "assistant", "content": row['Answer']},
            ]
        }
        return json.dumps(conversation, ensure_ascii=False)

    return '\n'.join(_row_to_line(row) for _, row in df.iterrows())
def process_and_split_data(text: str, api_key: str, model: str, num_pairs: int, context: str, train_size: float):
    """Generate QA pairs, split train/validation, and cache both in session state.

    train_size is a percentage (e.g. 80 for 80/20). Returns True on success,
    False when generation yielded no pairs.
    """
    qa_df = generate_qa_pairs(text, api_key, model, num_pairs, context)
    if qa_df.empty:
        return False

    # Fixed seed keeps the split reproducible across Streamlit reruns.
    train_part, val_part = train_test_split(
        qa_df,
        train_size=train_size / 100,
        random_state=42
    )

    st.session_state.train_df = train_part
    st.session_state.val_df = val_part
    st.session_state.generated = True
    return True
def main():
    """Render the app: sidebar configuration, PDF upload, QA generation, and
    download/statistics panels. Re-executed top-to-bottom on every Streamlit
    interaction; generated data survives reruns via st.session_state."""
    st.title("LLM Dataset Generator")
    st.write("Upload a PDF file and generate training & validation sets of question-answer pairs of your data using LLM.")

    # Sidebar configurations
    st.sidebar.header("Configuration")
    api_key = st.sidebar.text_input("Enter Groq API Key", type="password")
    model = st.sidebar.selectbox(
        "Select Model",
        ["llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768", "gemma2-9b-it"]
    )
    num_pairs = st.sidebar.number_input(
        "Number of QA Pairs",
        min_value=1,
        max_value=10000,
        value=5
    )
    context = st.sidebar.text_area(
        "Custom Context",
        value="Write a response that appropriately completes the request.",
        help="This text will be added to the Context column for each QA pair.",
        placeholder= "Add custom context here."
    )

    # Dataset split configuration
    st.sidebar.header("Dataset Split")
    train_size = st.sidebar.slider(
        "Training Set Size (%)",
        min_value=50,
        max_value=90,
        value=80,
        step=5
    )

    # Output format configuration
    st.sidebar.header("Output Format")
    output_format = st.sidebar.selectbox(
        "Select Output Format",
        ["CSV", "JSONL"]
    )
    # system_content is only bound when JSONL is selected; the download code
    # below reads it only inside its JSONL branches, so CSV runs never touch it.
    if output_format == "JSONL":
        system_content = st.sidebar.text_area(
            "System Message",
            value="You are a helpful assistant that provides accurate and informative answers.",
            help="This message will be used as the system content in the JSONL format."
        )

    # Main area
    uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")

    # Check if upload state has changed; clearing the uploader wipes any
    # previously generated datasets so stale results are not shown.
    current_upload_state = uploaded_file is not None
    if current_upload_state != st.session_state.previous_upload_state:
        if not current_upload_state:  # File was removed
            reset_session_state()
        st.session_state.previous_upload_state = current_upload_state

    if uploaded_file is not None:
        if not api_key:
            st.warning("Please enter your Groq API key in the sidebar.")
            return

        # NOTE(review): the PDF is re-parsed on every rerun while a file is
        # attached — consider caching if large PDFs feel slow. TODO confirm.
        text = parse_pdf(uploaded_file)
        st.success("PDF processed successfully!")

        if st.button("Generate QA Pairs"):
            with st.spinner("Generating QA pairs..."):
                success = process_and_split_data(text, api_key, model, num_pairs, context, train_size)
                if success:
                    st.success("QA pairs generated successfully!")

    # Display results if data has been generated (persists across reruns
    # until the uploader is cleared, which resets session state above).
    if st.session_state.generated and st.session_state.train_df is not None and st.session_state.val_df is not None:
        # Display the dataframes
        st.subheader("Training Set")
        st.dataframe(st.session_state.train_df)
        st.subheader("Validation Set")
        st.dataframe(st.session_state.val_df)

        # Create download section
        st.subheader("Download Generated Datasets")
        col1, col2 = st.columns(2)

        with col1:
            st.markdown("##### Training Set")
            if output_format == "CSV":
                train_csv = st.session_state.train_df.to_csv(index=False)
                st.download_button(
                    label="Download Training Set (CSV)",
                    data=train_csv,
                    file_name="train_qa_pairs.csv",
                    mime="text/csv",
                    key="train_csv"
                )
            else:  # JSONL format
                train_jsonl = create_jsonl_content(st.session_state.train_df, system_content)
                st.download_button(
                    label="Download Training Set (JSONL)",
                    data=train_jsonl,
                    file_name="train_qa_pairs.jsonl",
                    mime="application/jsonl",
                    key="train_jsonl"
                )

        with col2:
            st.markdown("##### Validation Set")
            if output_format == "CSV":
                val_csv = st.session_state.val_df.to_csv(index=False)
                st.download_button(
                    label="Download Validation Set (CSV)",
                    data=val_csv,
                    file_name="val_qa_pairs.csv",
                    mime="text/csv",
                    key="val_csv"
                )
            else:  # JSONL format
                val_jsonl = create_jsonl_content(st.session_state.val_df, system_content)
                st.download_button(
                    label="Download Validation Set (JSONL)",
                    data=val_jsonl,
                    file_name="val_qa_pairs.jsonl",
                    mime="application/jsonl",
                    key="val_jsonl"
                )

        # Display statistics (length averages are computed on the training
        # set only)
        st.subheader("Statistics")
        st.write(f"Total QA pairs: {len(st.session_state.train_df) + len(st.session_state.val_df)}")
        st.write(f"Training set size: {len(st.session_state.train_df)} ({train_size}%)")
        st.write(f"Validation set size: {len(st.session_state.val_df)} ({100-train_size}%)")
        st.write(f"Average question length: {st.session_state.train_df['Question'].str.len().mean():.1f} characters")
        st.write(f"Average answer length: {st.session_state.train_df['Answer'].str.len().mean():.1f} characters")
if __name__ == "__main__":
    # set_page_config must run before any other page-rendering call.
    st.set_page_config(page_title="LLM Dataset Generator", page_icon="📚", layout="wide")
    main()