nafees369 committed on
Commit
3e35703
1 Parent(s): bf5a713

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -39
app.py CHANGED
@@ -1,74 +1,115 @@
1
 
 
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
 
4
  import fitz # PyMuPDF for PDF handling
 
 
5
 
6
- # Load a pre-trained NER model
7
  model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
8
- model = AutoModelForTokenClassification.from_pretrained(model_name)
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
 
 
11
 
12
- # Function to extract text from a PDF file
13
  def extract_text_from_pdf(file_path):
14
- doc = fitz.open(file_path)
15
- text = ""
16
- for page in doc:
17
- text += page.get_text()
18
- return text.strip()
19
-
20
- # Function to map recognized entities to custom labels
21
- def map_labels(entity_label, label_map):
22
- for custom_label, ner_labels in label_map.items():
23
- if entity_label in ner_labels:
24
- return custom_label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  return None
26
 
27
  # Function to process the text and extract entities based on custom labels
28
  def process_text(file, labels):
29
- # Extract text from the PDF file
30
- text = extract_text_from_pdf(file.name)
31
-
 
 
 
 
 
 
 
 
32
  # Define the custom label mapping
33
  label_map = {
34
  "Name": ["PER"],
35
  "Organization": ["ORG"],
36
  "Location": ["LOC"],
 
37
  "Project": ["MISC"],
38
  "Education": ["MISC"],
39
  }
40
-
41
- # Split the custom labels provided by the user
42
- requested_labels = [label.strip() for label in labels.split(",")]
43
-
44
- # Perform NER on the extracted text
45
- ner_results = ner_pipeline(text)
46
-
47
  # Initialize a dictionary to hold the extracted information
48
  extracted_info = {label: [] for label in requested_labels}
49
-
 
 
 
50
  # Process the NER results
51
  for entity in ner_results:
52
- # Remove subword tokens (##) and map the entity to the custom labels
53
  entity_text = entity['word'].replace("##", "")
54
- mapped_label = map_labels(entity['entity_group'], label_map)
55
-
56
- # If the mapped label is in the requested labels, store the entity
57
- if mapped_label in extracted_info:
58
- extracted_info[mapped_label].append(entity_text)
59
-
 
 
60
  # Format the output
61
  output = ""
62
  for label, entities in extracted_info.items():
63
  if entities:
64
- output += f"{label}: {', '.join(sorted(set(entities)))}\n"
 
 
65
  else:
66
  output += f"{label}: No information found.\n"
67
-
68
  return output.strip()
69
 
70
  # Create Gradio components
71
- file_input = gr.File(label="Upload a PDF file")
72
  label_input = gr.Textbox(label="Enter labels to extract (comma-separated)")
73
  output_text = gr.Textbox(label="Extracted Information")
74
 
@@ -77,9 +118,9 @@ iface = gr.Interface(
77
  fn=process_text,
78
  inputs=[file_input, label_input],
79
  outputs=output_text,
80
- title="NER with Custom Labels from PDF",
81
- description="Upload a PDF file and extract entities based on custom labels."
82
  )
83
 
84
  # Launch the Gradio interface
85
- iface.launch()
 
1
 
2
+
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
5
+ from sentence_transformers import SentenceTransformer, util
6
  import fitz # PyMuPDF for PDF handling
7
+ import torch
8
+ import docx # For DOCX handling
9
 
10
+ # Load pre-trained models
11
  model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
12
+ ner_model = AutoModelForTokenClassification.from_pretrained(model_name)
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ ner_pipeline = pipeline("ner", model=ner_model, tokenizer=tokenizer, aggregation_strategy="simple")
15
+
16
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
17
 
18
+ # Function to extract text from a PDF file with error handling
19
  def extract_text_from_pdf(file_path):
20
+ try:
21
+ doc = fitz.open(file_path)
22
+ text = ""
23
+ for page in doc:
24
+ text += page.get_text()
25
+ return text.strip()
26
+ except Exception as e:
27
+ return f"Error extracting text from PDF: {str(e)}"
28
+
29
+ # Function to extract text from a DOCX file
30
+ def extract_text_from_docx(file_path):
31
+ try:
32
+ doc = docx.Document(file_path)
33
+ text = "\n".join([para.text for para in doc.paragraphs])
34
+ return text.strip()
35
+ except Exception as e:
36
+ return f"Error extracting text from DOCX: {str(e)}"
37
+
38
+ # Function to calculate cosine similarity
39
+ def calculate_similarity(input_label, predefined_labels):
40
+ input_embedding = embedding_model.encode(input_label, convert_to_tensor=True)
41
+ predefined_embeddings = embedding_model.encode(predefined_labels, convert_to_tensor=True)
42
+ cosine_scores = util.pytorch_cos_sim(input_embedding, predefined_embeddings)
43
+ best_match_idx = torch.argmax(cosine_scores).item()
44
+ return predefined_labels[best_match_idx], cosine_scores[0][best_match_idx].item()
45
+
46
+ # Function to map recognized entities to custom labels with cosine similarity
47
+ def map_labels_with_similarity(input_label, label_map):
48
+ predefined_labels = list(label_map.keys())
49
+ best_match_label, similarity_score = calculate_similarity(input_label, predefined_labels)
50
+ if similarity_score > 0.7: # Threshold for considering a match
51
+ return best_match_label
52
  return None
53
 
54
  # Function to process the text and extract entities based on custom labels
55
  def process_text(file, labels):
56
+ # Determine the file type and extract text accordingly
57
+ if file.name.endswith(".pdf"):
58
+ text = extract_text_from_pdf(file.name)
59
+ elif file.name.endswith(".docx"):
60
+ text = extract_text_from_docx(file.name)
61
+ else:
62
+ return "Unsupported file type. Please upload a PDF or DOCX file."
63
+
64
+ if text.startswith("Error"):
65
+ return text # Return the error message if text extraction failed
66
+
67
  # Define the custom label mapping
68
  label_map = {
69
  "Name": ["PER"],
70
  "Organization": ["ORG"],
71
  "Location": ["LOC"],
72
+ "Address": ["LOC"], # Address mapped to Location
73
  "Project": ["MISC"],
74
  "Education": ["MISC"],
75
  }
76
+
77
+ # Split the custom labels provided by the user and handle potential input issues
78
+ requested_labels = [label.strip().capitalize() for label in labels.split(",") if label.strip()]
79
+ if not requested_labels:
80
+ return "No valid labels provided. Please enter valid labels to extract."
81
+
 
82
  # Initialize a dictionary to hold the extracted information
83
  extracted_info = {label: [] for label in requested_labels}
84
+
85
+ # Perform NER on the extracted text
86
+ ner_results = ner_pipeline(text)
87
+
88
  # Process the NER results
89
  for entity in ner_results:
 
90
  entity_text = entity['word'].replace("##", "")
91
+ entity_group = entity['entity_group']
92
+
93
+ # Determine the best matching label using cosine similarity
94
+ for input_label in requested_labels:
95
+ best_match_label = map_labels_with_similarity(input_label, label_map)
96
+ if best_match_label and entity_group in label_map[best_match_label]:
97
+ extracted_info[input_label].append(entity_text)
98
+
99
  # Format the output
100
  output = ""
101
  for label, entities in extracted_info.items():
102
  if entities:
103
+ # Remove duplicates and clean up the entities
104
+ unique_entities = sorted(set(entities))
105
+ output += f"{label}: {', '.join(unique_entities)}\n"
106
  else:
107
  output += f"{label}: No information found.\n"
108
+
109
  return output.strip()
110
 
111
  # Create Gradio components
112
+ file_input = gr.File(label="Upload a PDF or DOCX file")
113
  label_input = gr.Textbox(label="Enter labels to extract (comma-separated)")
114
  output_text = gr.Textbox(label="Extracted Information")
115
 
 
118
  fn=process_text,
119
  inputs=[file_input, label_input],
120
  outputs=output_text,
121
+ title="NER with Custom Labels from PDF or DOCX",
122
+ description="Upload a PDF or DOCX file and extract entities based on custom labels."
123
  )
124
 
125
  # Launch the Gradio interface
126
+ iface.launch()