Nathan Slaughter committed
Commit 4d17caa
1 Parent(s): b8a0d78

add pytorch manual method
Files changed:
- .github/workflows/python-app.yaml +29 -0
- .gitignore +1 -0
- app.py +8 -0
- app/__init__.py +0 -0
- app/interface.py +113 -0
- app/models.py +31 -0
- app/processing.py +95 -0
- environment.yml +19 -0
- pytest.ini +5 -0
- requirements.txt +7 -0
- tests/__init__.py +0 -0
- tests/conftest.py +14 -0
- tests/test_models.py +20 -0
- tests/test_processing.py +73 -0
.github/workflows/python-app.yaml
ADDED
@@ -0,0 +1,29 @@
+# .github/workflows/python-app.yml
+
+name: Python application
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.8'
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install -r requirements.txt
+        pip install pytest pytest-mock
+    - name: Run tests
+      run: |
+        pytest
.gitignore
ADDED
@@ -0,0 +1 @@
+__pycache__
app.py
ADDED
@@ -0,0 +1,8 @@
+from app.interface import create_interface
+
+def main():
+    interface = create_interface()
+    interface.launch()
+
+if __name__ == "__main__":
+    main()
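Generation with a 7B model can take a while per request; a minimal variant of main() that serializes requests through Gradio's built-in queue (a sketch under that assumption, not part of this commit; Blocks.queue() is standard Gradio API):

def main():
    interface = create_interface()
    interface.queue()  # serialize long-running LLM calls instead of handling them concurrently
    interface.launch()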
app/__init__.py
ADDED
File without changes
app/interface.py
ADDED
@@ -0,0 +1,113 @@
+import gradio as gr
+from .models import LanguageModel
+from .processing import process_file, process_text_input
+
+def create_interface():
+    # Initialize the language model
+    language_model = LanguageModel()
+
+    # Define the Output Format Selector
+    output_format_selector = gr.Radio(
+        choices=["CSV", "JSON"],
+        label="Select Output Format",
+        value="JSON",
+        type="value"
+    )
+
+    # Define the Output Flashcards
+    flashcard_output_file = gr.Textbox(
+        label="Flashcards",
+        lines=20,
+        placeholder="Extracted flashcards will appear here..."
+    )
+    flashcard_output_text = gr.Textbox(
+        label="Flashcards",
+        lines=20,
+        placeholder="Extracted flashcards will appear here..."
+    )
+
+    # Define the Gradio interface function for File Upload
+    def handle_file_upload(file_obj, output_format):
+        try:
+            flashcards = process_file(file_obj, output_format, language_model)
+            return flashcards
+        except ValueError as ve:
+            return str(ve)
+
+    # Define the Gradio interface function for Text Input
+    def handle_text_input(input_text, output_format):
+        try:
+            flashcards = process_text_input(input_text, output_format, language_model)
+            return flashcards
+        except ValueError as ve:
+            return str(ve)
+
+    # Create the Gradio Tabs
+    with gr.Blocks() as interface:
+        gr.Markdown("# Flashcard Extraction Tool")
+        gr.Markdown(
+            "Extract flashcards from uploaded files or directly input text. Choose your preferred output format."
+        )
+        with gr.Tab("Upload File"):
+            with gr.Row():
+                with gr.Column():
+                    file_input = gr.File(
+                        label="Upload a File",
+                        file_types=['.pdf', '.txt', '.md']
+                    )
+                    format_selector = gr.Radio(
+                        choices=["CSV", "JSON"],
+                        label="Select Output Format",
+                        value="JSON",
+                        type="value"
+                    )
+                    submit_file = gr.Button("Extract Flashcards")
+                with gr.Column():
+                    flashcard_output_file = gr.Textbox(
+                        label="Flashcards",
+                        lines=20,
+                        placeholder="Extracted flashcards will appear here..."
+                    )
+            submit_file.click(
+                fn=handle_file_upload,
+                inputs=[file_input, format_selector],
+                outputs=flashcard_output_file
+            )
+
+        with gr.Tab("Input Text"):
+            with gr.Row():
+                with gr.Column():
+                    text_input = gr.Textbox(
+                        label="Enter Text",
+                        lines=20,
+                        placeholder="Type or paste your text here..."
+                    )
+                    format_selector_text = gr.Radio(
+                        choices=["CSV", "JSON"],
+                        label="Select Output Format",
+                        value="JSON",
+                        type="value"
+                    )
+                    submit_text = gr.Button("Extract Flashcards")
+                with gr.Column():
+                    flashcard_output_text = gr.Textbox(
+                        label="Flashcards",
+                        lines=20,
+                        placeholder="Extracted flashcards will appear here..."
+                    )
+            submit_text.click(
+                fn=handle_text_input,
+                inputs=[text_input, format_selector_text],
+                outputs=flashcard_output_text
+            )
+
+        gr.Markdown(
+            """
+            ---
+            **Notes:**
+            - Supported file types: `.pdf`, `.txt`, `.md`.
+            - Ensure that the input text is clear and well-structured for optimal flashcard extraction.
+            """
+        )
+
+    return interface
app/models.py
ADDED
@@ -0,0 +1,31 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+class LanguageModel:
+    def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
+        self.device = self._determine_device()
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype="auto",
+            device_map="auto"
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    def _determine_device(self):
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            return torch.device("mps")
+        else:
+            return torch.device("cpu")
+
+    def generate_flashcards(self, prompt: str, max_new_tokens: int = 1024) -> str:
+        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)
+        with torch.no_grad():
+            output_ids = self.model.generate(
+                inputs.input_ids,
+                max_new_tokens=max_new_tokens,
+                do_sample=True
+            )
+        response = self.tokenizer.decode(output_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)  # decode only the new tokens so the prompt is not echoed back
+        return response
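The TODO in app/processing.py points to Qwen's chat docs; a hedged sketch of what a chat-template variant of generate_flashcards might look like (the method name and message layout are assumptions, not part of this commit; apply_chat_template is standard transformers API):

def generate_flashcards_chat(self, prompt: str, max_new_tokens: int = 1024) -> str:
    # Hypothetical variant: Qwen2.5-Instruct ships a chat template, so wrap the
    # prompt in a messages list instead of feeding it as raw text
    messages = [{"role": "user", "content": prompt}]
    input_ids = self.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(self.model.device)
    with torch.no_grad():
        output_ids = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
        )
    # Decode only the completion, not the templated prompt
    return self.tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)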
app/processing.py
ADDED
@@ -0,0 +1,95 @@
+import os
+import pymupdf4llm
+
+def process_pdf(pdf_path: str) -> str:
+    """
+    Extracts text from a PDF file using pymupdf4llm.
+    """
+    try:
+        text = pymupdf4llm.to_markdown(pdf_path)  # pymupdf4llm's extraction entry point; returns Markdown text
+        return text
+    except Exception as e:
+        raise ValueError(f"Error processing PDF: {str(e)}")
+
+def read_text_file(file_path: str) -> str:
+    """
+    Reads text from a .txt or .md file.
+    """
+    try:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            text = f.read()
+        return text
+    except Exception as e:
+        raise ValueError(f"Error reading text file: {str(e)}")
+
+def format_prompt(output_format: str) -> str:
+    """
+    Formats the prompt based on the output type.
+    """
+    if output_format.lower() == "json":
+        return """You only respond with cards in JSON format. Follow the example below.
+
+EXAMPLE:
+[
+  {"question": "What is AI?", "answer": "Artificial Intelligence."},
+  {"question": "What is ML?", "answer": "Machine Learning."}
+  ...
+]
+"""
+    elif output_format.lower() == "csv":
+        return """You only respond with cards in CSV format. Follow the example below.
+
+EXAMPLE:
+"What is AI?", "Artificial Intelligence."
+"What is ML?", "Machine Learning."
+...
+"""
+
+def extract_flashcards(text: str, output_format: str, language_model) -> str:
+    """
+    Extracts flashcards from the input text using the LLM and formats them in CSV or JSON.
+    """
+    prompt = f"""You are an expert flashcard creator. You always include a single knowledge item per flashcard.
+
+{format_prompt(output_format)}
+
+
+Extract flashcards from the user's text:
+
+{text}
+
+Do not include the prompt or any other unnecessary information in the flashcards.
+Do not include triple ticks (```) or any other code blocks in the flashcards.
+"""
+    # TODO:
+    # see https://qwen.readthedocs.io/en/latest/inference/chat.html
+    # e.g. pipeline = pipeline("text-generation", model="Qwen/Qwen2.5-7B-Instruct")
+    response = language_model.generate_flashcards(prompt)
+    return response
+
+def process_file(file_obj, output_format: str, language_model) -> str:
+    """
+    Processes the uploaded file based on its type and extracts flashcards.
+    """
+    file_path = file_obj.name
+    file_ext = os.path.splitext(file_path)[1].lower()
+
+    if file_ext == '.pdf':
+        text = process_pdf(file_path)
+    elif file_ext in ['.txt', '.md']:
+        text = read_text_file(file_path)
+    else:
+        raise ValueError("Unsupported file type.")
+
+    flashcards = extract_flashcards(text, output_format, language_model)
+    return flashcards
+
+def process_text_input(input_text: str, output_format: str, language_model) -> str:
+    """
+    Processes the input text and extracts flashcards.
+    """
+    if not input_text.strip():
+        raise ValueError("No text provided.")
+
+    flashcards = extract_flashcards(input_text, output_format, language_model)
+    return flashcards
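extract_flashcards returns the model's raw string; when the JSON format is selected, callers could validate it before display. A minimal sketch (parse_json_flashcards is a hypothetical helper, not part of this commit):

import json

def parse_json_flashcards(response: str) -> list:
    # Validate that the reply is the bare JSON list the prompt asks for
    cards = json.loads(response)
    if not isinstance(cards, list):
        raise ValueError("Expected a JSON list of flashcards.")
    return cards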
environment.yml
ADDED
@@ -0,0 +1,19 @@
+name: flashcard-maker
+channels:
+  - conda-forge
+  - pytorch
+  - defaults
+dependencies:
+  - python=3.12
+  - pytorch  # the conda package is "pytorch", not "torch"
+  - torchvision
+  - torchaudio
+  - cudatoolkit=11.7  # Remove or adjust if installing CPU-only
+  - transformers
+  - gradio
+  - librosa
+  - pytest
+  - pytest-mock
+  - pip
+  - pip:
+      - pymupdf4llm
pytest.ini
ADDED
@@ -0,0 +1,5 @@
+# pytest.ini
+
+[pytest]
+filterwarnings =
+    ignore::DeprecationWarning
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+torch  # the pip package is "torch"; "pytorch" is not the PyTorch distribution on PyPI
+transformers
+gradio
+librosa
+pymupdf4llm
+pytest
+pytest-mock  # Added for mocking capabilities
tests/__init__.py
ADDED
File without changes
tests/conftest.py
ADDED
@@ -0,0 +1,14 @@
+import pytest
+from unittest.mock import Mock
+from app.models import LanguageModel
+
+@pytest.fixture
+def language_model():
+    """
+    Fixture to provide a mocked LanguageModel instance.
+    """
+    # Create a mock instance of LanguageModel
+    lm = Mock(spec=LanguageModel)
+    # Mock the generate_flashcards method
+    lm.generate_flashcards.return_value = '{"flashcards": []}'
+    return lm
tests/test_models.py
ADDED
@@ -0,0 +1,20 @@
+# tests/test_models.py
+
+import pytest
+
+def test_generate_flashcards(language_model, mocker):
+    """
+    Test the generate_flashcards method of LanguageModel.
+    """
+    prompt = "Sample prompt for flashcard generation."
+    expected_response = '{"flashcards": [{"Question": "What is AI?", "Answer": "Artificial Intelligence."}]}'
+
+    # Configure the mock to return a specific response
+    language_model.generate_flashcards.return_value = expected_response
+
+    # Call the method
+    response = language_model.generate_flashcards(prompt)
+
+    # Assertions
+    assert response == expected_response
+    language_model.generate_flashcards.assert_called_once_with(prompt)
tests/test_processing.py
ADDED
@@ -0,0 +1,73 @@
+# tests/test_processing.py
+
+import pytest
+from app.processing import process_text_input, process_file
+
+def test_process_text_input_success(language_model):
+    """
+    Test processing of valid text input.
+    """
+    input_text = "This is a sample text for flashcard extraction."
+    output_format = "JSON"
+    expected_output = '{"flashcards": []}'
+
+    result = process_text_input(input_text, output_format, language_model)
+    assert result == expected_output
+    language_model.generate_flashcards.assert_called_once()
+
+def test_process_text_input_empty(language_model):
+    """
+    Test processing of empty text input.
+    """
+    input_text = "   "
+    output_format = "JSON"
+
+    with pytest.raises(ValueError) as excinfo:
+        process_text_input(input_text, output_format, language_model)
+    assert "No text provided." in str(excinfo.value)
+
+def test_process_file_unsupported_type(language_model, tmp_path):
+    """
+    Test processing of an unsupported file type.
+    """
+    # Create a dummy unsupported file
+    dummy_file = tmp_path / "dummy.unsupported"
+    dummy_file.write_text("Unsupported content")
+
+    with pytest.raises(ValueError) as excinfo:
+        process_file(dummy_file, "JSON", language_model)
+    assert "Unsupported file type." in str(excinfo.value)
+
+def test_process_file_pdf(language_model, tmp_path, mocker):
+    """
+    Test processing of a PDF file.
+    """
+    # Mock the process_pdf function
+    mocker.patch('app.processing.process_pdf', return_value="Extracted PDF text.")
+
+    # Create a dummy PDF file
+    dummy_file = tmp_path / "test.pdf"
+    dummy_file.write_text("PDF content")
+
+    expected_output = '{"flashcards": []}'
+
+    result = process_file(dummy_file, "JSON", language_model)
+    assert result == expected_output
+    language_model.generate_flashcards.assert_called_once()
+
+def test_process_file_txt(language_model, tmp_path, mocker):
+    """
+    Test processing of a TXT file.
+    """
+    # Mock the read_text_file function
+    mocker.patch('app.processing.read_text_file', return_value="Extracted TXT text.")
+
+    # Create a dummy TXT file
+    dummy_file = tmp_path / "test.txt"
+    dummy_file.write_text("TXT content")
+
+    expected_output = '{"flashcards": []}'
+
+    result = process_file(dummy_file, "JSON", language_model)
+    assert result == expected_output
+    language_model.generate_flashcards.assert_called_once()
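format_prompt is the one branch point without coverage; a small extra test along these lines would pin down both formats (a sketch, not part of this commit):

from app.processing import format_prompt

def test_format_prompt_branches():
    # The JSON branch shows a JSON example, the CSV branch a CSV example;
    # matching is case-insensitive via .lower()
    assert '"question"' in format_prompt("json")
    assert "CSV format" in format_prompt("CSV")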