NER / app.py
nafees369's picture
Update app.py
3e35703 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from sentence_transformers import SentenceTransformer, util
import fitz # PyMuPDF for PDF handling
import torch
import docx # For DOCX handling
# Load pre-trained models
# Everything below runs once at import time; the functions in this file read
# `ner_pipeline` and `embedding_model` as module-level globals.
# Token-classification checkpoint used for NER. process_text maps its
# PER / ORG / LOC / MISC entity groups onto user-facing labels.
model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
ner_model = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# process_text relies on the aggregated result format: each result is read
# through its 'word' and 'entity_group' keys.
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=tokenizer, aggregation_strategy="simple")
# Sentence-embedding model used by calculate_similarity to compare label strings.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Function to extract text from a PDF file with error handling
def extract_text_from_pdf(file_path):
    """Return the concatenated text of every page of the PDF at *file_path*.

    The text of all pages is joined without a separator and stripped of
    leading/trailing whitespace.

    Errors are reported in-band rather than raised: on any failure the
    return value is a string of the form
    ``"Error extracting text from PDF: <reason>"``, which callers detect
    by its prefix.
    """
    try:
        # The context manager closes the document (and its file handle)
        # even when reading a page fails part-way through.
        with fitz.open(file_path) as doc:
            return "".join(page.get_text() for page in doc).strip()
    except Exception as e:
        # Deliberately broad: this is the boundary that turns any extraction
        # failure into a user-visible message.
        return f"Error extracting text from PDF: {str(e)}"
# Function to extract text from a DOCX file
def extract_text_from_docx(file_path):
    """Return the paragraph text of the DOCX file at *file_path*.

    Paragraphs are joined with newlines and the result is stripped of
    leading/trailing whitespace. Only ``doc.paragraphs`` is read.

    On any failure the function returns (does not raise) a string of the
    form ``"Error extracting text from DOCX: <reason>"``.
    """
    try:
        document = docx.Document(file_path)
        paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
        joined = "\n".join(paragraph_texts)
    except Exception as exc:
        return f"Error extracting text from DOCX: {str(exc)}"
    else:
        return joined.strip()
# Function to calculate cosine similarity
def calculate_similarity(input_label, predefined_labels):
    """Pick the predefined label most similar to *input_label*.

    Both arguments are encoded with the module-level ``embedding_model``
    and compared by cosine similarity.

    Returns:
        A ``(label, score)`` tuple: the entry of *predefined_labels* with
        the highest cosine score, and that score as a Python float.
    """
    query_embedding = embedding_model.encode(input_label, convert_to_tensor=True)
    candidate_embeddings = embedding_model.encode(predefined_labels, convert_to_tensor=True)
    scores = util.pytorch_cos_sim(query_embedding, candidate_embeddings)

    # Index of the highest-scoring entry of the score matrix.
    winner_idx = torch.argmax(scores).item()
    winner_label = predefined_labels[winner_idx]
    winner_score = scores[0][winner_idx].item()
    return winner_label, winner_score
# Function to map recognized entities to custom labels with cosine similarity
def map_labels_with_similarity(input_label, label_map, threshold=0.7):
    """Resolve a user-supplied label to the closest key of *label_map*.

    Args:
        input_label: Label text entered by the user.
        label_map: Mapping whose keys are the known labels.
        threshold: Minimum cosine similarity (exclusive) for the best
            candidate to count as a match. Defaults to 0.7.

    Returns:
        The best-matching key of *label_map*, or ``None`` when the map is
        empty or the best score does not exceed *threshold*.
    """
    predefined_labels = list(label_map)
    if not predefined_labels:
        # Nothing to compare against; scoring an empty candidate list would
        # fail when selecting the best score.
        return None

    best_match_label, similarity_score = calculate_similarity(input_label, predefined_labels)
    return best_match_label if similarity_score > threshold else None
# Function to process the text and extract entities based on custom labels
def process_text(file, labels):
    """Extract named entities from an uploaded PDF/DOCX, grouped by label.

    Args:
        file: The uploaded document, given either as a filesystem path
            string or as an object exposing that path as ``.name``.
            ``None`` (nothing uploaded) is reported as a message.
        labels: Comma-separated label names requested by the user. Each
            one is stripped, capitalised and matched by embedding
            similarity against the built-in label map.

    Returns:
        A report with one ``"<label>: <entities>"`` line per requested
        label (entities de-duplicated and sorted, or
        ``"No information found."``), or a single message explaining why
        nothing could be extracted.
    """
    if file is None:
        return "No file uploaded. Please upload a PDF or DOCX file."

    # Accept both a plain path and a file-like wrapper carrying `.name`.
    file_path = file if isinstance(file, str) else file.name

    # Compare the extension case-insensitively so "REPORT.PDF" is accepted.
    lowered_path = file_path.lower()
    if lowered_path.endswith(".pdf"):
        text = extract_text_from_pdf(file_path)
    elif lowered_path.endswith(".docx"):
        text = extract_text_from_docx(file_path)
    else:
        return "Unsupported file type. Please upload a PDF or DOCX file."

    # The extractors report failure in-band with this prefix. Matching the
    # full prefix (not just "Error") avoids misclassifying documents whose
    # own text happens to begin with the word "Error".
    if text.startswith("Error extracting text from "):
        return text

    # Custom label -> entity groups emitted by the NER pipeline.
    label_map = {
        "Name": ["PER"],
        "Organization": ["ORG"],
        "Location": ["LOC"],
        "Address": ["LOC"],  # Address mapped to Location
        "Project": ["MISC"],
        "Education": ["MISC"],
    }

    requested_labels = [
        label.strip().capitalize()
        for label in (labels or "").split(",")
        if label.strip()
    ]
    if not requested_labels:
        return "No valid labels provided. Please enter valid labels to extract."

    extracted_info = {label: [] for label in requested_labels}

    # Resolve each requested label to its entity groups once, up front: the
    # similarity lookup does not depend on the entity being processed, so
    # repeating it per entity only repeats the embedding work.
    label_groups = {}
    for label in extracted_info:
        matched_label = map_labels_with_similarity(label, label_map)
        label_groups[label] = label_map[matched_label] if matched_label else ()

    for entity in ner_pipeline(text):
        entity_text = entity['word'].replace("##", "")
        entity_group = entity['entity_group']
        for label, groups in label_groups.items():
            if entity_group in groups:
                extracted_info[label].append(entity_text)

    # Format the output: one line per requested label.
    report_lines = []
    for label, entities in extracted_info.items():
        if entities:
            unique_entities = sorted(set(entities))
            report_lines.append(f"{label}: {', '.join(unique_entities)}")
        else:
            report_lines.append(f"{label}: No information found.")
    return "\n".join(report_lines).strip()
# Create Gradio components
# Uploaded document; process_text only handles names ending in .pdf or .docx.
file_input = gr.File(label="Upload a PDF or DOCX file")
# Free-text, comma-separated labels; process_text splits this on commas.
label_input = gr.Textbox(label="Enter labels to extract (comma-separated)")
# Displays the string returned by process_text (report or error message).
output_text = gr.Textbox(label="Extracted Information")
# Create the Gradio interface
# The two inputs are passed, in this order, as process_text(file, labels).
iface = gr.Interface(
fn=process_text,
inputs=[file_input, label_input],
outputs=output_text,
title="NER with Custom Labels from PDF or DOCX",
description="Upload a PDF or DOCX file and extract entities based on custom labels."
)
# Launch the Gradio interface
# NOTE(review): this runs at module import time; there is no
# `if __name__ == "__main__":` guard around it.
iface.launch()