# app.py - created by Maarten Van Segbroeck in commit 4bd2901 ("Create app.py", verified), 14.4 kB.
import gradio as gr
import requests
import os
import markdownify
import fitz # PyMuPDF
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pandas as pd
import random
from gretel_client import Gretel
from gretel_client.config import GretelClientConfigurationError
# Working directory for every artifact this app writes: the downloaded example
# PDF, one Markdown file per processed PDF, and one Markdown file per chunk.
# Created at import time; an already-existing directory is not an error.
output_dir = 'processed_pdfs'
os.makedirs(name=output_dir, mode=0o777, exist_ok=True)
def pdf_to_text(pdf_path):
    """Extract the text of every page of a local PDF file, in page order.

    Args:
        pdf_path: Path to a PDF file that PyMuPDF (``fitz``) can open.

    Returns:
        The concatenation of ``page.get_text()`` for all pages, with no
        separator added between pages.
    """
    # The context manager closes the document (and its file handle) even if
    # text extraction raises; joining avoids quadratic string concatenation.
    with fitz.open(pdf_path) as pdf_document:
        return ''.join(page.get_text() for page in pdf_document)
def split_text_into_chunks(text, chunk_size=25, chunk_overlap=5, min_chunk_chars=50):
    """Split text into token-bounded chunks and drop the ones that are too short.

    Args:
        text: The text to split.
        chunk_size: Maximum chunk length, measured with a tiktoken encoder.
        chunk_overlap: Maximum number of tokens shared by consecutive chunks.
        min_chunk_chars: Chunks with fewer characters than this are discarded.

    Returns:
        The surviving chunks, in the order the splitter produced them.
    """
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    kept = []
    for piece in splitter.split_text(text):
        if len(piece) < min_chunk_chars:
            continue
        kept.append(piece)
    return kept
def save_chunks(file_id, chunks, output_dir):
    """Write each chunk to its own Markdown file.

    Files are named ``<file_id>_chunk_<n>.md`` with ``n`` starting at 1, and
    are written as UTF-8. Without an explicit encoding, non-ASCII text
    extracted from PDFs would depend on the platform's default encoding.

    Args:
        file_id: Identifier used as the file-name prefix.
        chunks: Iterable of chunk strings.
        output_dir: Existing directory to write the files into.
    """
    for index, chunk in enumerate(chunks, start=1):
        chunk_path = os.path.join(output_dir, f"{file_id}_chunk_{index}.md")
        with open(chunk_path, 'w', encoding='utf-8') as file:
            file.write(chunk)
def read_chunks_from_files(output_dir):
    """Load chunk files named ``<file_id>_chunk_<n>.md`` back into memory.

    Only names matching that exact pattern count as chunks. This means a
    full-document ``<file_id>.md`` in the same directory is not mistaken for a
    chunk, even when the file id contains "chunk". Chunks are returned in
    numeric order of ``n``: ``os.listdir`` order is arbitrary, and a plain
    string sort would put ``_chunk_10`` before ``_chunk_2``.

    Args:
        output_dir: Directory containing the chunk files.

    Returns:
        Dict mapping each file id to its list of chunk strings, read as UTF-8.
    """
    import re
    from collections import defaultdict

    # Greedy file_id so ids that themselves contain "_chunk_" are kept intact.
    chunk_name = re.compile(r'^(?P<file_id>.+)_chunk_(?P<index>\d+)\.md$')
    numbered = defaultdict(list)
    for filename in os.listdir(output_dir):
        match = chunk_name.match(filename)
        if match is None:
            continue
        chunk_path = os.path.join(output_dir, filename)
        with open(chunk_path, 'r', encoding='utf-8') as file:
            numbered[match['file_id']].append((int(match['index']), file.read()))
    return {
        file_id: [chunk for _, chunk in sorted(pairs, key=lambda pair: pair[0])]
        for file_id, pairs in numbered.items()
    }
def _resolve_pdf_paths(uploaded_files, use_example):
    """Return the local paths of the PDFs to process (empty list when there are none)."""
    if use_example:
        example_file_url = "https://gretel-datasets.s3.us-west-2.amazonaws.com/rag/GDPR_2016.pdf"
        pdf_path = os.path.join(output_dir, example_file_url.split('/')[-1])
        if not os.path.exists(pdf_path):
            response = requests.get(example_file_url, timeout=60)
            # Fail loudly rather than caching an HTTP error page under the PDF
            # name: the exists() check above would then never re-download it.
            response.raise_for_status()
            with open(pdf_path, 'wb') as file:
                file.write(response.content)
        return [pdf_path]
    if uploaded_files:
        # gr.File may yield path strings or file-like objects exposing the path
        # as `.name`; accept both. An absolute path passes through os.path.join
        # unchanged.
        return [os.path.join(output_dir, getattr(f, 'name', f)) for f in uploaded_files]
    return []


def process_pdfs(uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, direction):
    """Convert the selected PDFs to Markdown, chunk them, and pick a chunk to display.

    The example PDF (downloaded once into ``output_dir``) takes precedence over
    uploaded files. For each PDF:
    - extract its text;
    - convert the text with ``markdownify``;
    - write the result to ``<output_dir>/<file_id>.md`` (UTF-8);
    - split it with ``split_text_into_chunks`` and store the chunks with ``save_chunks``.

    Args:
        uploaded_files: Value of the ``gr.File`` component (path strings or
            objects with a ``.name`` path), or None/empty.
        use_example: If true, process the example GDPR PDF instead of uploads.
        chunk_size: Forwarded to ``split_text_into_chunks``.
        chunk_overlap: Forwarded to ``split_text_into_chunks``.
        min_chunk_chars: Forwarded to ``split_text_into_chunks``.
        current_chunk: Index of the chunk currently shown (None treated as 0).
        direction: Offset added to ``current_chunk`` before clamping.

    Returns:
        Tuple ``(pdf_chunks_dict, selected_pdfs, chunk_text, current_chunk)``
        matching the Process button's outputs. ``pdf_chunks_dict`` maps each
        file id to its chunks. The displayed chunk comes from the first
        selected PDF. When nothing is selected, the return value is
        ``(None, None, "No PDFs processed", 0)``.

    Raises:
        requests.RequestException: If downloading the example PDF fails or times out.
    """
    selected_pdfs = _resolve_pdf_paths(uploaded_files, use_example)
    if not selected_pdfs:
        # Slot order must match the outputs: dict, selected PDFs, text, index.
        return None, None, "No PDFs processed", 0

    pdf_chunks_dict = {}
    for pdf_path in selected_pdfs:
        text = pdf_to_text(pdf_path)
        markdown_text = markdownify.markdownify(text)
        file_id = os.path.splitext(os.path.basename(pdf_path))[0]
        markdown_path = os.path.join(output_dir, f"{file_id}.md")
        with open(markdown_path, 'w', encoding='utf-8') as file:
            file.write(markdown_text)
        chunks = split_text_into_chunks(markdown_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap, min_chunk_chars=min_chunk_chars)
        save_chunks(file_id, chunks, output_dir)
        pdf_chunks_dict[file_id] = chunks

    # Only the first selected PDF is browsable in the chunk viewer.
    first_id = os.path.splitext(os.path.basename(selected_pdfs[0]))[0]
    chunks = pdf_chunks_dict.get(first_id, [])
    # Clamp into [0, len(chunks) - 1]; stays 0 (never -1) when there are no chunks.
    current_chunk = max(0, min((current_chunk or 0) + direction, len(chunks) - 1))
    chunk_text = chunks[current_chunk] if chunks else "No chunks available."
    return pdf_chunks_dict, selected_pdfs, chunk_text, current_chunk
def show_chunks(pdf_chunks_dict, selected_pdfs, current_chunk, direction):
    """Move the chunk viewer by ``direction`` and return the chunk to display.

    Only the chunks of the first selected PDF are browsable. The new index is
    clamped to the valid range. It is 0 when there are no chunks, so the
    stored index never becomes negative.

    Args:
        pdf_chunks_dict: Mapping of file id to chunk list, or None.
        selected_pdfs: Paths of the processed PDFs, or a falsy value if none.
        current_chunk: Index currently shown (None treated as 0).
        direction: Step to apply, e.g. -1 for previous or 1 for next.

    Returns:
        Tuple ``(chunk_text, current_chunk)``.
    """
    if not selected_pdfs:
        return "No PDF processed.", 0
    file_id = os.path.splitext(os.path.basename(selected_pdfs[0]))[0]
    chunks = (pdf_chunks_dict or {}).get(file_id, [])
    current_chunk = max(0, min((current_chunk or 0) + direction, len(chunks) - 1))
    chunk_text = chunks[current_chunk] if chunks else "No chunks available."
    return chunk_text, current_chunk
def check_api_key(api_key):
    """Validate a Gretel API key and toggle the generate button to match.

    Returns:
        Tuple of a ``gr.update`` whose ``interactive`` flag is whether the key
        validated, and the status string "Valid" or "Invalid". Exceptions other
        than ``GretelClientConfigurationError`` propagate.
    """
    try:
        Gretel(api_key=api_key, validate=True, clear=True)
    except GretelClientConfigurationError:
        return gr.update(interactive=False), "Invalid"
    return gr.update(interactive=True), "Valid"
def generate_synthetic_records(api_key, pdf_chunks_dict, num_records):
    """Generate synthetic question/answer records from randomly sampled PDF chunks.

    ``num_records`` chunks are sampled uniformly with replacement across all
    documents. They are sent as seed rows (columns ``document`` and ``text``)
    to Gretel's Navigator inference API, which fills in the columns described
    in the prompt.

    Args:
        api_key: Gretel API key.
        pdf_chunks_dict: Mapping of document id to its list of chunk strings,
            or None if nothing has been processed yet.
        num_records: Number of records to generate. May arrive as a float
            (e.g. from ``gr.Number``), so it is truncated to ``int``.

    Returns:
        A ``gr.update`` showing the generated DataFrame, with the seed ``text``
        column removed.

    Raises:
        gr.Error: If there are no chunks to sample or ``num_records`` < 1.
    """
    all_chunks = [
        (doc, chunk)
        for doc, chunks in (pdf_chunks_dict or {}).items()
        for chunk in chunks
    ]
    if not all_chunks:
        raise gr.Error("No PDF chunks available. Process at least one PDF first.")
    # range() needs an int; the UI number input may deliver a float.
    num_records = int(num_records or 0)
    if num_records < 1:
        raise gr.Error("Number of records must be at least 1.")

    gretel = Gretel(api_key=api_key, validate=True, clear=True)
    navigator = gretel.factories.initialize_inference_api("navigator")
    INTRO_PROMPT = "From the source text below, create a dataset with the following columns:\n"
    COLUMN_DETAILS = (
        "* `topic`: Select a topic relevant for the given source text.\n"
        "* `user_profile`: The complexity level of the question and truth, categorized into beginner, intermediate, and expert.\n"
        "  - Beginner users are about building foundational knowledge about the product and ask about basic features, benefits, and uses of the product.\n"
        "  - Intermediate users have a deep understanding of the product and are focusing on practical applications, comparisons with other products, and intermediate-level features and benefits.\n"
        "  - Expert users demonstrate in-depth technical knowledge, strategic application, and advanced troubleshooting. This level is for those who need to know the product inside and out, possibly for roles in sales, technical support, or product development.\n"
        "* `question`: Ask a set of unique questions related to the topic that a user might ask. "
        "Questions should be relatively complex and specific enough to be addressed in a short answer.\n"
        "* `answer`: Respond to the question with a clear, textbook quality answer that provides relevant details to fully address the question.\n"
        "* `context`: Copy the exact sentence(s) from the source text and surrounding details from where the answer can be derived.\n"
    )
    PROMPT = INTRO_PROMPT + COLUMN_DETAILS
    GENERATE_PARAMS = {
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40
    }
    # Build the seed frame in one step instead of concatenating inside the loop (quadratic).
    df_in = pd.DataFrame(
        [random.choice(all_chunks) for _ in range(num_records)],
        columns=['document', 'text'],
    )
    df = navigator.edit(PROMPT, seed_data=df_in, **GENERATE_PARAMS)
    # The raw chunk is only seed input; the prompt asks for cited text in `context`.
    df = df.drop(columns=['text'], errors='ignore')
    return gr.update(value=df, visible=True)
# Styling for the logo, wrapped in an HTML <style> element (not bare CSS)
# because it is interpolated into `html_content` and rendered via gr.HTML.
# It centers the logo container horizontally. It also sets
# `pointer-events: none` on the SVG so mouse events (including right-click)
# do not target the image itself.
css = """
<style>
#logo-container {
display: flex;
justify-content: center;
width: 100%;
}
#logo-container svg {
pointer-events: none; /* Disable pointer events on the SVG */
}
</style>
"""
# Logo markup rendered with gr.HTML at the top of the left column. It is the
# `css` <style> block followed by an inline 181x72 SVG logo (appears to be the
# Gretel wordmark), wrapped in the #logo-container div that the CSS centers.
html_content = f"""
{css}
<div id="logo-container">
<svg width="181" height="72" viewBox="0 0 181 72" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_849_78)">
<path d="M53.4437 41.3178V53.5794H44.4782V18.8754H53.4437V27.0498C55.1339 21.1048 58.9552 18.1323 63.144 18.1323C65.3487 18.1323 67.2593 18.5782 68.8025 19.3956L67.2593 27.57C65.863 26.9011 64.0993 26.604 62.0417 26.604C56.3097 26.604 53.4437 31.5085 53.4437 41.3178Z" fill="#3C1AE6"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M103.383 45.9252C100.444 51.573 94.1975 54.3226 87.3631 54.3226C82.366 54.3226 78.1773 52.6134 74.7234 49.2693C71.2694 45.8509 69.5793 41.4664 69.5793 36.1159C69.5793 30.7654 71.2694 26.4553 74.7234 23.1112C78.1773 19.7671 82.366 18.1323 87.3631 18.1323C92.3603 18.1323 96.4755 19.7671 99.7824 23.1112C103.089 26.4553 104.78 30.7654 104.78 36.1159C104.78 37.019 104.715 37.987 104.647 39.0198L104.633 39.2371H78.4712C79.0591 43.7701 82.8805 46.6684 87.951 46.6684C91.5519 46.6684 95.0058 45.1078 96.549 42.2097L103.383 45.9252ZM78.3978 33.0691H96.0346C95.3733 28.3875 91.9194 25.7122 87.4366 25.7122C82.66 25.7122 79.0591 28.5361 78.3978 33.0691Z" fill="#3C1AE6"/>
<path d="M121.87 26.158V53.5794H112.979V26.158H106.732V18.8754H112.979V5.64777H121.87V18.8754H129.146V26.158H121.87Z" fill="#3C1AE6"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M164.903 45.9252C161.963 51.573 155.716 54.3226 148.882 54.3226C143.885 54.3226 139.696 52.6134 136.242 49.2693C132.789 45.8509 131.098 41.4664 131.098 36.1159C131.098 30.7654 132.789 26.4553 136.242 23.1112C139.696 19.7671 143.885 18.1323 148.882 18.1323C153.879 18.1323 157.994 19.7671 161.301 23.1112C164.609 26.4553 166.299 30.7654 166.299 36.1159C166.299 37.0174 166.235 37.9834 166.167 39.0141L166.152 39.2371H139.99C140.578 43.7701 144.399 46.6684 149.47 46.6684C153.072 46.6684 156.525 45.1078 158.069 42.2097L164.903 45.9252ZM139.917 33.0691H157.554C156.893 28.3875 153.439 25.7122 148.956 25.7122C144.179 25.7122 140.578 28.5361 139.917 33.0691Z" fill="#3C1AE6"/>
<path d="M180.597 0V53.5794H171.631V0H180.597Z" fill="#3C1AE6"/>
<path d="M27.1716 19.3782C27.1716 14.947 30.7321 11.3548 35.1241 11.3548V19.3782C35.1764 19.3959 27.1716 19.3782 27.1716 19.3782Z" fill="#3C1AE6"/>
<path d="M34.7984 54.5253C34.7984 64.11 27.2527 71.9206 17.8936 71.9206C8.62804 71.9206 1.13987 64.2655 0.991031 54.8122L0.988777 54.5253H8.94397C8.94397 59.7209 12.9737 63.8921 17.8936 63.8921C22.746 63.8921 26.7325 59.8342 26.8409 54.7381L26.8431 54.5253H34.7984Z" fill="#3C1AE6"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M35.7872 36.4522C35.7872 26.4724 27.7758 18.3822 17.8936 18.3822C8.01121 18.3822 0 26.4724 0 36.4522C0 46.4322 8.01121 54.5224 17.8936 54.5224C27.7758 54.5224 35.7872 46.4322 35.7872 36.4522ZM8.61542 36.4522C8.61542 31.2775 12.7694 27.0826 17.8936 27.0826C23.0178 27.0826 27.1716 31.2775 27.1716 36.4522C27.1716 41.6271 23.0178 45.822 17.8936 45.822C12.7694 45.822 8.61542 41.6271 8.61542 36.4522Z" fill="#3C1AE6"/>
</g>
<defs>
<clipPath id="clip0_849_78">
<rect width="181" height="72" fill="white"/>
</clipPath>
</defs>
</svg>
</div>
"""
# Gradio interface. The left column lets the user pick or upload PDFs, chunk
# them and page through the chunks. The right column validates a Gretel API
# key and generates synthetic question-answer records from the processed chunks.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=3):
            gr.HTML(html_content)
            with gr.Tab("Upload PDF"):
                use_example = gr.Checkbox(label="Continue with Example PDF", value=False, interactive=True)
                uploaded_files = gr.File(label="Upload your PDF files", file_count="multiple")
                chunk_size = gr.Slider(label="Chunk Size (tokens)", minimum=10, maximum=1500, step=10, value=500)
                chunk_overlap = gr.Slider(label="Chunk Overlap (tokens)", minimum=0, maximum=500, step=5, value=100)
                min_chunk_chars = gr.Slider(label="Minimum Chunk Characters", minimum=10, maximum=2500, step=10, value=750)
                process_button = gr.Button("Process PDFs")
                # Session state shared by the process, navigation and generation handlers.
                pdf_chunks_dict = gr.State()
                selected_pdfs = gr.State()
                current_chunk = gr.State(value=0)
                chunk_text = gr.Textbox(label="Chunk Text", lines=10)

                def toggle_use_example(file_list):
                    """Untick the example checkbox; allow ticking it only when no files are uploaded."""
                    return gr.update(value=False, interactive=not file_list)

                uploaded_files.change(
                    toggle_use_example,
                    inputs=[uploaded_files],
                    outputs=[use_example],
                )
                # gr.State(0) supplies direction=0: processing does not move the viewer.
                process_button.click(
                    process_pdfs,
                    inputs=[uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, gr.State(0)],
                    outputs=[pdf_chunks_dict, selected_pdfs, chunk_text, current_chunk],
                )
                with gr.Row():
                    prev_button = gr.Button("Previous Chunk", scale=1)
                    next_button = gr.Button("Next Chunk", scale=1)
                # Constant gr.State(-1) / gr.State(1) inputs give show_chunks its step.
                prev_button.click(
                    show_chunks,
                    inputs=[pdf_chunks_dict, selected_pdfs, current_chunk, gr.State(-1)],
                    outputs=[chunk_text, current_chunk],
                )
                next_button.click(
                    show_chunks,
                    inputs=[pdf_chunks_dict, selected_pdfs, current_chunk, gr.State(1)],
                    outputs=[chunk_text, current_chunk],
                )
        with gr.Column(scale=7):
            gr.Markdown("# Generate Question-Answer pairs")
            with gr.Row():
                api_key_input = gr.Textbox(label="API Key", type="password", placeholder="Enter your API key", scale=2)
                validate_status = gr.Textbox(label="Validation Status", interactive=False, scale=1)
            # precision=0 makes gr.Number return an int rather than a float,
            # since the value is used as a record count.
            num_records = gr.Number(label="Number of Records", value=10, precision=0)
            # Disabled until check_api_key reports a valid key.
            generate_button = gr.Button("Generate Synthetic Records", interactive=False)
            api_key_input.change(
                fn=check_api_key,
                inputs=[api_key_input],
                outputs=[generate_button, validate_status],
            )
            output_df = gr.Dataframe(headers=["document", "topic", "user_profile", "question", "answer", "context"], wrap=True, visible=True)
            generate_button.click(
                fn=generate_synthetic_records,
                inputs=[api_key_input, pdf_chunks_dict, num_records],
                outputs=[output_df],
            )

demo.launch()