abhinavsarkar committed on
Commit
e6c245e
•
1 Parent(s): ee63417

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +256 -0
app.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pymupdf4llm
3
+ import pandas as pd
4
+ from groq import Groq
5
+ import json
6
+ import tempfile
7
+ from sklearn.model_selection import train_test_split
8
+
9
# Seed the session state on the first run only: a key that an earlier rerun
# already wrote is left untouched, so generated data survives reruns.
_SESSION_DEFAULTS = (
    ('train_df', None),
    ('val_df', None),
    ('generated', False),
    ('previous_upload_state', False),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
18
+
19
def reset_session_state():
    """Clear the generated datasets and the "generated" flag in session state.

    Sets ``train_df`` and ``val_df`` back to ``None`` and ``generated`` back
    to ``False``; ``previous_upload_state`` is not touched.
    """
    cleared_values = (
        ('train_df', None),
        ('val_df', None),
        ('generated', False),
    )
    for state_key, cleared in cleared_values:
        st.session_state[state_key] = cleared
24
+
25
def parse_pdf(uploaded_file) -> str:
    """Convert an uploaded PDF into Markdown text.

    The upload's bytes are written to a temporary file because
    ``pymupdf4llm.to_markdown`` is given a filesystem path here. The
    temporary file is always removed afterwards, including when parsing fails.

    Args:
        uploaded_file: Object exposing ``getvalue()`` that returns the PDF
            bytes (the value returned by ``st.file_uploader`` in ``main``).

    Returns:
        The Markdown text produced by ``pymupdf4llm.to_markdown``.

    Raises:
        Whatever ``pymupdf4llm.to_markdown`` raises for an unreadable file;
        the temporary file is still deleted in that case.
    """
    import os  # local import: `os` is not among this file's top-level imports

    # delete=False lets us close the handle (flushing the data to disk) before
    # the parser reopens the file by name; cleanup is then done explicitly.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        tmp_path = tmp_file.name

    try:
        return pymupdf4llm.to_markdown(tmp_path)
    finally:
        # Previously the file was never removed, leaking one temp file per call.
        os.unlink(tmp_path)
33
+
34
def _split_qa_block(block: str):
    """Split one ``Q: ... A: ...`` block into a ``(question, answer)`` tuple.

    Returns ``None`` when the block does not start with ``Q:`` (after
    stripping surrounding whitespace) or contains no ``A:`` marker.
    """
    block = block.strip()
    if not block.startswith('Q:'):
        return None

    lines = block.splitlines()
    # Prefer an "A:" that opens a line: that is the layout the prompt asks
    # for, and it keeps an "A:" inside the question text from being mistaken
    # for the answer marker.
    for index, line in enumerate(lines):
        if line.lstrip().startswith('A:'):
            question = '\n'.join(lines[:index])
            answer = '\n'.join(lines[index:]).lstrip()[len('A:'):]
            break
    else:
        # Fall back to an inline marker ("Q: ... A: ..." on a single line).
        # partition() splits on the first marker only, so a later "A:" inside
        # the answer no longer breaks the split.
        question, marker, answer = block.partition('A:')
        if not marker:
            return None

    # Drop only the leading "Q:" prefix; any "Q:" inside the text is kept.
    return question[len('Q:'):].strip(), answer.strip()


def generate_qa_pairs(text: str, api_key: str, model: str, num_pairs: int, context: str) -> pd.DataFrame:
    """Ask a Groq-hosted chat model for question-answer pairs about ``text``.

    Args:
        text: Source text the questions should be based on.
        api_key: Groq API key used to construct the client.
        model: Model identifier passed to the chat-completions call.
        num_pairs: Number of pairs requested in the prompt. The model's reply
            is parsed as-is, so the number actually returned may differ.
        context: Value stored in the ``Context`` column of every row; it is
            not sent to the model.

    Returns:
        A DataFrame with ``Question``, ``Answer`` and ``Context`` columns, one
        row per parsed pair. An empty DataFrame is returned when the request
        fails (the error is shown with ``st.error``) or nothing could be parsed.
    """
    client = Groq(api_key=api_key)

    prompt = f"""
    Given the following text, generate {num_pairs} question-answer pairs:

    {text}

    Format each pair as:
    Q: [Question]
    A: [Answer]

    Ensure the questions are diverse and cover different aspects of the text.
    """

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that generates question-answer pairs based on given text."},
                {"role": "user", "content": prompt}
            ]
        )
        qa_text = response.choices[0].message.content
    except Exception as e:
        # UI boundary: report the failure in the page instead of crashing it.
        st.error(f"Error generating QA pairs: {str(e)}")
        return pd.DataFrame()

    qa_pairs = []
    # Pairs are expected to be separated by blank lines. `or ''` guards
    # against a reply with no text content (assumption: content may be None).
    for block in (qa_text or '').split('\n\n'):
        parsed = _split_qa_block(block)
        if parsed is None:
            continue
        question, answer = parsed
        qa_pairs.append({
            'Question': question,
            'Answer': answer,
            'Context': context
        })

    return pd.DataFrame(qa_pairs)
77
+
78
def create_jsonl_content(df: pd.DataFrame, system_content: str) -> str:
    """Serialise QA rows as JSON Lines text in chat-message form.

    Each row of ``df`` becomes one JSON object with a ``messages`` list of
    three entries: the system message (``system_content``), the user message
    (the row's ``Question``) and the assistant message (the row's ``Answer``).
    Objects are joined with newlines; non-ASCII characters are written as-is.
    """
    def to_json_line(row) -> str:
        conversation = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": row['Question']},
            {"role": "assistant", "content": row['Answer']},
        ]
        return json.dumps({"messages": conversation}, ensure_ascii=False)

    return '\n'.join(to_json_line(row) for _, row in df.iterrows())
91
+
92
def process_and_split_data(text: str, api_key: str, model: str, num_pairs: int, context: str, train_size: float) -> bool:
    """Generate QA pairs, split them, and store the result in session state.

    Args:
        text: Source text forwarded to ``generate_qa_pairs``.
        api_key: Groq API key forwarded to ``generate_qa_pairs``.
        model: Model identifier forwarded to ``generate_qa_pairs``.
        num_pairs: Number of pairs to request.
        context: Value for the ``Context`` column of each pair.
        train_size: Training share as a percentage (e.g. ``80``); it is
            divided by 100 before being passed to ``train_test_split``.

    Returns:
        ``True`` when at least one pair was generated and ``train_df``,
        ``val_df`` and ``generated`` were written to ``st.session_state``;
        ``False`` when no pairs were produced (session state is left as-is).
    """
    df = generate_qa_pairs(text, api_key, model, num_pairs, context)

    if df.empty:
        return False

    if len(df) < 2:
        # train_test_split raises ValueError when one side of the split would
        # be empty, which is what happens with a single row. Keep that row for
        # training and use an empty frame with the same columns for validation.
        train_df, val_df = df, df.iloc[0:0]
    else:
        train_df, val_df = train_test_split(
            df,
            train_size=train_size / 100,
            random_state=42  # fixed seed: the same pairs always split the same way
        )

    # Store in session state so the results survive Streamlit reruns
    st.session_state.train_df = train_df
    st.session_state.val_df = val_df
    st.session_state.generated = True
    return True
110
+
111
def _render_download(title, df, file_stem, output_format, system_content):
    """Render the heading and download button for one dataset split."""
    st.markdown(f"##### {title}")
    if output_format == "CSV":
        st.download_button(
            label=f"Download {title} (CSV)",
            data=df.to_csv(index=False),
            file_name=f"{file_stem}_qa_pairs.csv",
            mime="text/csv",
            key=f"{file_stem}_csv"
        )
    else:  # JSONL format
        st.download_button(
            label=f"Download {title} (JSONL)",
            data=create_jsonl_content(df, system_content),
            file_name=f"{file_stem}_qa_pairs.jsonl",
            mime="application/jsonl",
            key=f"{file_stem}_jsonl"
        )


def main():
    """Render the dataset-generator page.

    Reads the configuration from the sidebar (API key, model, number of
    pairs, context text, train/validation split, output format), parses the
    uploaded PDF, generates QA pairs on request, and shows the resulting
    training/validation sets together with download buttons and statistics.
    Generated data lives in ``st.session_state`` so it survives reruns.
    """
    st.title("LLM Dataset Generator")
    st.write("Upload a PDF file and generate training & validation sets of question-answer pairs of your data using LLM.")

    # Sidebar configurations
    st.sidebar.header("Configuration")

    api_key = st.sidebar.text_input("Enter Groq API Key", type="password")

    model = st.sidebar.selectbox(
        "Select Model",
        ["llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768", "gemma2-9b-it"]
    )

    num_pairs = st.sidebar.number_input(
        "Number of QA Pairs",
        min_value=1,
        max_value=10000,
        value=5
    )

    context = st.sidebar.text_area(
        "Custom Context",
        value="Write a response that appropriately completes the request.",
        help="This text will be added to the Context column for each QA pair.",
        placeholder="Add custom context here."
    )

    # Dataset split configuration
    st.sidebar.header("Dataset Split")
    train_size = st.sidebar.slider(
        "Training Set Size (%)",
        min_value=50,
        max_value=90,
        value=80,
        step=5
    )

    # Output format configuration
    st.sidebar.header("Output Format")
    output_format = st.sidebar.selectbox(
        "Select Output Format",
        ["CSV", "JSONL"]
    )

    # Always bind system_content so later code never depends on which branch
    # ran; the value is only used when the JSONL format is selected.
    system_content = "You are a helpful assistant that provides accurate and informative answers."
    if output_format == "JSONL":
        system_content = st.sidebar.text_area(
            "System Message",
            value=system_content,
            help="This message will be used as the system content in the JSONL format."
        )

    # Main area
    uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")

    # Track WHICH file is uploaded, not merely whether one is present, so that
    # replacing the file (or removing it) discards results generated from the
    # previous one instead of leaving them on screen.
    upload_token = None
    if uploaded_file is not None:
        upload_token = (uploaded_file.name, uploaded_file.size)
    if upload_token != st.session_state.previous_upload_state:
        reset_session_state()
        st.session_state.previous_upload_state = upload_token

    if uploaded_file is None:
        return

    if not api_key:
        st.warning("Please enter your Groq API key in the sidebar.")
        return

    try:
        text = parse_pdf(uploaded_file)
    except Exception as e:
        # UI boundary: an unreadable PDF should produce a message, not a traceback.
        st.error(f"Error processing PDF: {str(e)}")
        return
    st.success("PDF processed successfully!")

    if st.button("Generate QA Pairs"):
        with st.spinner("Generating QA pairs..."):
            success = process_and_split_data(text, api_key, model, num_pairs, context, train_size)
        if success:
            st.success("QA pairs generated successfully!")
        else:
            # Previously a run that yielded no pairs gave no feedback at all.
            st.warning("No QA pairs were generated. Please try again or choose a different model.")

    train_df = st.session_state.train_df
    val_df = st.session_state.val_df

    # Display results only once data has been generated
    if not st.session_state.generated or train_df is None or val_df is None:
        return

    # Display the dataframes
    st.subheader("Training Set")
    st.dataframe(train_df)

    st.subheader("Validation Set")
    st.dataframe(val_df)

    # Create download section
    st.subheader("Download Generated Datasets")
    col1, col2 = st.columns(2)

    with col1:
        _render_download("Training Set", train_df, "train", output_format, system_content)

    with col2:
        _render_download("Validation Set", val_df, "val", output_format, system_content)

    # Display statistics. Percentages come from the stored split rather than
    # the slider, whose value may have changed since the data was generated.
    train_count = len(train_df)
    val_count = len(val_df)
    total_count = train_count + val_count
    train_pct = 100 * train_count / total_count if total_count else 0
    val_pct = 100 * val_count / total_count if total_count else 0

    st.subheader("Statistics")
    st.write(f"Total QA pairs: {total_count}")
    st.write(f"Training set size: {train_count} ({train_pct:.0f}%)")
    st.write(f"Validation set size: {val_count} ({val_pct:.0f}%)")
    st.write(f"Average question length: {train_df['Question'].str.len().mean():.1f} characters")
    st.write(f"Average answer length: {train_df['Answer'].str.len().mean():.1f} characters")
249
+
250
if __name__ == "__main__":
    # Configure the page before main() renders any widgets.
    # page_icon was a mis-decoded byte sequence; restored to the books emoji.
    st.set_page_config(
        page_title="LLM Dataset Generator",
        page_icon="📚",
        layout="wide"
    )
    main()