# extract_api/main.py
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks
from pydantic import BaseModel
import fitz  # PyMuPDF
import tempfile, os, json, logging
from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration

# Initialize FastAPI app and logging
app = FastAPI()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load model and tokenizer
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base", legacy=False)
qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=-1)  # CPU-based inference
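# Note (assumption, not stated in the original): flan-t5's tokenizer reports a
# model_max_length of 512 by default, which is what motivates the 512-token
# chunking in chunk_text() below; longer prompts risk being truncated.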
# Define the structured response model
class Education(BaseModel):
    degree: str
    university: str
    graduation_year: str


class ExtractedInfo(BaseModel):
    work_experience: str
    education: Education
    professional_course_detail: str
    software_usage: str
    safety_course_detail: str
    hse_description: str
    good_conduct_certificate: str
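# When the endpoint later constructs ExtractedInfo(**structured_data), pydantic
# coerces the nested "education" dict produced by merge_outputs() into the
# Education model, so no manual conversion is needed.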
def extract_text_from_pdf(pdf_path: str) -> str:
    """Extracts text from the uploaded PDF."""
    with fitz.open(pdf_path) as doc:
        return "".join(page.get_text() for page in doc).strip()


def chunk_text(text: str, max_tokens: int = 512) -> list:
    """Splits the text into manageable chunks that fit within the token limit."""
    tokens = tokenizer.encode(text, add_special_tokens=False)
    return [tokenizer.decode(tokens[i:i + max_tokens], skip_special_tokens=True)
            for i in range(0, len(tokens), max_tokens)]
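# Illustrative example: a resume of ~1,200 tokens yields three chunks of at
# most 512 tokens each. Note that decoding token slices back to text can
# normalize whitespace slightly relative to the original input.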
def process_chunk(chunk: str) -> dict:
    """Uses the model to extract structured JSON data from a text chunk."""
    prompt = f"""
    Extract the following information in JSON format:
    {{
        "work_experience": "<Summarized work experience>",
        "education": {{
            "degree": "<Degree obtained>",
            "university": "<University attended>",
            "graduation_year": "<Year of graduation>"
        }},
        "professional_course_detail": "<Professional courses>",
        "software_usage": "<Software tools used>",
        "safety_course_detail": "<Safety courses completed>",
        "hse_description": "<HSE practices>",
        "good_conduct_certificate": "<Good conduct certificate>"
    }}
    Resume text: {chunk}
    """
    try:
        response = qa_pipeline(prompt, max_new_tokens=150)
        generated_text = response[0]["generated_text"]
        # Extract the JSON object from the generated text
        json_start = generated_text.find("{")
        json_end = generated_text.rfind("}") + 1
        if json_start == -1 or json_end == 0:
            logger.warning("No JSON object found in model output.")
            return {}
        return json.loads(generated_text[json_start:json_end])
    except (json.JSONDecodeError, IndexError) as e:
        logger.warning(f"Failed to parse JSON: {e}")
        return {}
def merge_outputs(chunks: list) -> dict:
    """Combines multiple chunk outputs into a single structured result."""
    merged = {
        "work_experience": "",
        "education": {"degree": "", "university": "", "graduation_year": ""},
        "professional_course_detail": "",
        "software_usage": "",
        "safety_course_detail": "",
        "hse_description": "",
        "good_conduct_certificate": ""
    }
    for chunk in chunks:
        chunk_output = process_chunk(chunk)
        for key, value in chunk_output.items():
            if key not in merged:
                continue  # Ignore keys the model invented
            if isinstance(value, dict):
                merged[key].update(value)
            elif not merged[key]:
                merged[key] = value
    return merged
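# Merge semantics: for string fields, the first non-empty value wins across
# chunks; for the nested education dict, dict.update() applies, so a later
# chunk can overwrite an earlier chunk's value for the same key.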
@app.post("/process_cv/", response_model=ExtractedInfo)
async def process_cv(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    """Processes a PDF resume and returns structured information in JSON format."""
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed.")
    # NamedTemporaryFile avoids the race condition of the deprecated tempfile.mktemp
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
        f.write(await file.read())
        temp_path = f.name
    try:
        # Extract text from the uploaded PDF
        text = extract_text_from_pdf(temp_path)
        if not text:
            raise HTTPException(status_code=400, detail="No extractable text found in the PDF.")
        # Process the text in chunks and merge the output
        chunks = chunk_text(text)
        structured_data = merge_outputs(chunks)
        return ExtractedInfo(**structured_data)
    finally:
        os.remove(temp_path)
        background_tasks.add_task(clean_temp_files)
def clean_temp_files():
    """Cleans up temporary PDF files."""
    for filename in os.listdir(tempfile.gettempdir()):
        if filename.endswith(".pdf"):
            try:
                os.remove(os.path.join(tempfile.gettempdir(), filename))
            except Exception as e:
                logger.warning(f"Failed to delete {filename}: {e}")
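# Local-run sketch (assumes uvicorn and curl are installed; these are the
# standard invocations, not part of this module):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl -X POST -F "file=@resume.pdf" http://localhost:8000/process_cv/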