File size: 9,135 Bytes
e6c245e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import json
import os
import tempfile

import pandas as pd
import pymupdf4llm
import streamlit as st
from groq import Groq
from sklearn.model_selection import train_test_split

# Seed session state with defaults on first run; reruns keep existing values.
_SESSION_DEFAULTS = (
    ("train_df", None),
    ("val_df", None),
    ("generated", False),
    ("previous_upload_state", False),
)
for _key, _default in _SESSION_DEFAULTS:
    if _key not in st.session_state:
        st.session_state[_key] = _default

def reset_session_state():
    """Drop any generated datasets and clear the generated flag."""
    for key in ("train_df", "val_df"):
        st.session_state[key] = None
    st.session_state["generated"] = False

def parse_pdf(uploaded_file) -> str:
    """Extract markdown text from an uploaded PDF.

    pymupdf4llm needs a real filesystem path, so the upload is written to a
    temporary file first. The original implementation used ``delete=False``
    and never removed the file, leaking one temp file per parsed PDF; the
    ``finally`` block below guarantees cleanup even if conversion fails.

    Args:
        uploaded_file: Streamlit ``UploadedFile`` (any object with ``getvalue()``).

    Returns:
        The PDF content converted to markdown text.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        tmp_path = tmp_file.name

    try:
        return pymupdf4llm.to_markdown(tmp_path)
    finally:
        os.unlink(tmp_path)  # remove the temp file the original leaked

def _parse_qa_blocks(qa_text: str, context: str) -> list:
    """Parse 'Q: ... / A: ...' blocks from the model response into row dicts.

    Blocks are separated by blank lines. Each row carries the user-supplied
    context. Splitting on 'A:' is capped at one split so an answer that
    itself contains the substring 'A:' no longer raises ValueError (the
    original two-target unpack crashed on such responses).
    """
    qa_pairs = []
    for pair in qa_text.split('\n\n'):
        pair = pair.strip()
        if pair.startswith('Q:') and 'A:' in pair:
            question, answer = pair.split('A:', 1)  # maxsplit=1: bug fix
            qa_pairs.append({
                'Question': question.replace('Q:', '').strip(),
                'Answer': answer.strip(),
                'Context': context,
            })
    return qa_pairs


def generate_qa_pairs(text: str, api_key: str, model: str, num_pairs: int, context: str) -> pd.DataFrame:
    """Ask a Groq-hosted model for QA pairs derived from `text`.

    Args:
        text: Source text to generate questions from.
        api_key: Groq API key.
        model: Groq model identifier.
        num_pairs: Requested number of QA pairs (the model may return fewer/more).
        context: Value copied into the 'Context' column of every row.

    Returns:
        DataFrame with columns Question/Answer/Context; empty on API error
        (the error is surfaced to the UI via st.error rather than raised).
    """
    client = Groq(api_key=api_key)

    prompt = f"""
    Given the following text, generate {num_pairs} question-answer pairs:

    {text}

    Format each pair as:
    Q: [Question]
    A: [Answer]

    Ensure the questions are diverse and cover different aspects of the text.
    """

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that generates question-answer pairs based on given text."},
                {"role": "user", "content": prompt}
            ]
        )

        qa_text = response.choices[0].message.content
        return pd.DataFrame(_parse_qa_blocks(qa_text, context))

    except Exception as e:
        # Best-effort: report in the UI and return an empty frame so the
        # caller's `df.empty` check handles the failure path.
        st.error(f"Error generating QA pairs: {str(e)}")
        return pd.DataFrame()

def create_jsonl_content(df: pd.DataFrame, system_content: str) -> str:
    """Render the QA DataFrame as chat-format JSONL (one record per row).

    Each line is a JSON object with a three-message conversation:
    system prompt, user question, assistant answer.
    """
    def _record(row) -> str:
        conversation = {
            "messages": [
                {"role": "system", "content": system_content},
                {"role": "user", "content": row['Question']},
                {"role": "assistant", "content": row['Answer']},
            ]
        }
        return json.dumps(conversation, ensure_ascii=False)

    return '\n'.join(_record(row) for _, row in df.iterrows())

def process_and_split_data(text: str, api_key: str, model: str, num_pairs: int, context: str, train_size: float):
    """Generate QA pairs, split them train/val, and cache both in session state.

    Returns True when pairs were generated and stored, False when generation
    produced nothing (e.g. the API call failed).
    """
    qa_df = generate_qa_pairs(text, api_key, model, num_pairs, context)
    if qa_df.empty:
        return False

    # train_size arrives as a percentage from the sidebar slider.
    train_part, val_part = train_test_split(
        qa_df,
        train_size=train_size / 100,
        random_state=42,  # fixed seed: same split on every rerun
    )

    st.session_state.train_df = train_part
    st.session_state.val_df = val_part
    st.session_state.generated = True
    return True

def main():
    """Streamlit page body: sidebar config, PDF upload, QA generation,
    dataset preview, downloads, and summary statistics.

    Widget declaration order matters in Streamlit (the script reruns top to
    bottom on every interaction), so statements here should not be reordered.
    """
    st.title("LLM Dataset Generator")
    st.write("Upload a PDF file and generate training & validation sets of question-answer pairs of your data using LLM.")

    # Sidebar configurations
    st.sidebar.header("Configuration")
    
    api_key = st.sidebar.text_input("Enter Groq API Key", type="password")
    
    model = st.sidebar.selectbox(
        "Select Model",
        ["llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768", "gemma2-9b-it"]
    )
    
    num_pairs = st.sidebar.number_input(
        "Number of QA Pairs",
        min_value=1,
        max_value=10000,
        value=5
    )
    
    context = st.sidebar.text_area(
        "Custom Context",
        value="Write a response that appropriately completes the request.",
        help="This text will be added to the Context column for each QA pair.",
        placeholder= "Add custom context here."
    )

    # Dataset split configuration
    st.sidebar.header("Dataset Split")
    train_size = st.sidebar.slider(
        "Training Set Size (%)",
        min_value=50,
        max_value=90,
        value=80,
        step=5
    )

    # Output format configuration
    st.sidebar.header("Output Format")
    output_format = st.sidebar.selectbox(
        "Select Output Format",
        ["CSV", "JSONL"]
    )

    # NOTE: system_content only exists when JSONL is selected; every later
    # use is guarded by the same output_format check, so no NameError.
    if output_format == "JSONL":
        system_content = st.sidebar.text_area(
            "System Message",
            value="You are a helpful assistant that provides accurate and informative answers.",
            help="This message will be used as the system content in the JSONL format."
        )

    # Main area
    uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")

    # Check if upload state has changed: removing the file clears any
    # previously generated datasets from session state.
    current_upload_state = uploaded_file is not None
    if current_upload_state != st.session_state.previous_upload_state:
        if not current_upload_state:  # File was removed
            reset_session_state()
        st.session_state.previous_upload_state = current_upload_state

    if uploaded_file is not None:
        if not api_key:
            st.warning("Please enter your Groq API key in the sidebar.")
            return

        # NOTE(review): parse_pdf re-runs on every Streamlit rerun while a
        # file is uploaded — consider caching if large PDFs become slow.
        text = parse_pdf(uploaded_file)
        st.success("PDF processed successfully!")
        
        if st.button("Generate QA Pairs"):
            with st.spinner("Generating QA pairs..."):
                success = process_and_split_data(text, api_key, model, num_pairs, context, train_size)
                if success:
                    st.success("QA pairs generated successfully!")

    # Display results if data has been generated (persists across reruns
    # because the frames live in session state, not local variables)
    if st.session_state.generated and st.session_state.train_df is not None and st.session_state.val_df is not None:
        # Display the dataframes
        st.subheader("Training Set")
        st.dataframe(st.session_state.train_df)
        
        st.subheader("Validation Set")
        st.dataframe(st.session_state.val_df)
        
        # Create download section: training set on the left, validation on
        # the right; format chosen by the sidebar selectbox above.
        st.subheader("Download Generated Datasets")
        col1, col2 = st.columns(2)
        
        with col1:
            st.markdown("##### Training Set")
            if output_format == "CSV":
                train_csv = st.session_state.train_df.to_csv(index=False)
                st.download_button(
                    label="Download Training Set (CSV)",
                    data=train_csv,
                    file_name="train_qa_pairs.csv",
                    mime="text/csv",
                    key="train_csv"
                )
            else:  # JSONL format
                train_jsonl = create_jsonl_content(st.session_state.train_df, system_content)
                st.download_button(
                    label="Download Training Set (JSONL)",
                    data=train_jsonl,
                    file_name="train_qa_pairs.jsonl",
                    mime="application/jsonl",
                    key="train_jsonl"
                )
        
        with col2:
            st.markdown("##### Validation Set")
            if output_format == "CSV":
                val_csv = st.session_state.val_df.to_csv(index=False)
                st.download_button(
                    label="Download Validation Set (CSV)",
                    data=val_csv,
                    file_name="val_qa_pairs.csv",
                    mime="text/csv",
                    key="val_csv"
                )
            else:  # JSONL format
                val_jsonl = create_jsonl_content(st.session_state.val_df, system_content)
                st.download_button(
                    label="Download Validation Set (JSONL)",
                    data=val_jsonl,
                    file_name="val_qa_pairs.jsonl",
                    mime="application/jsonl",
                    key="val_jsonl"
                )
        
        # Display statistics (length averages use the training set only)
        st.subheader("Statistics")
        st.write(f"Total QA pairs: {len(st.session_state.train_df) + len(st.session_state.val_df)}")
        st.write(f"Training set size: {len(st.session_state.train_df)} ({train_size}%)")
        st.write(f"Validation set size: {len(st.session_state.val_df)} ({100-train_size}%)")
        st.write(f"Average question length: {st.session_state.train_df['Question'].str.len().mean():.1f} characters")
        st.write(f"Average answer length: {st.session_state.train_df['Answer'].str.len().mean():.1f} characters")

if __name__ == "__main__":
    # set_page_config must be the first Streamlit command executed, so it
    # runs here before main() issues any other st.* calls.
    st.set_page_config(
        page_title="LLM Dataset Generator",
        page_icon="📚",
        layout="wide"
    )
    main()