khrek's picture
Update models.py
2e16cc5 verified
raw
history blame
2.42 kB
import torch
import sentencepiece
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
import os
import spacy
import spacy_transformers
import zipfile
from collections import defaultdict
class Models:
    """Bundle of the NER models used by this module.

    Two models are loaded once, at construction time:

    * ``self.ner`` -- a Hugging Face ``'ner'`` pipeline built on
      ``Jean-Baptiste/camembert-ner-with-dates`` with
      ``aggregation_strategy="simple"``.
    * ``self.custom_ner`` -- a spaCy pipeline loaded from
      ``spacy_model_v2/output/model-best`` next to this file, unpacked from
      ``spacy_model_v2.zip`` if that directory does not exist yet.
    """

    # Hugging Face hub id of the general-purpose NER model.
    NER_MODEL_NAME = "Jean-Baptiste/camembert-ner-with-dates"
    # Entities must score strictly above this value to be kept.
    DEFAULT_SCORE_THRESHOLD = 0.75

    def __init__(self) -> None:
        self.load_trained_models()

    def load_trained_models(self) -> None:
        """Load the transformer NER pipeline and the custom spaCy model.

        If ``spacy_model_v2/output/model-best`` (relative to this file) is
        missing, ``spacy_model_v2.zip`` is extracted into ``spacy_model_v2/``
        next to this file. The archive is looked up next to this file first
        and, for backward compatibility, at ``./spacy_model_v2.zip`` relative
        to the current working directory otherwise.

        Raises:
            FileNotFoundError: if extraction is needed but no archive exists
                at either location (raised by ``zipfile.ZipFile``).
        """
        tokenizer = AutoTokenizer.from_pretrained(self.NER_MODEL_NAME)
        model = AutoModelForTokenClassification.from_pretrained(self.NER_MODEL_NAME)
        self.ner = pipeline('ner', model=model, tokenizer=tokenizer,
                            aggregation_strategy="simple")

        current_directory = os.path.dirname(os.path.realpath(__file__))
        extract_dir = os.path.join(current_directory, 'spacy_model_v2')
        custom_ner_path = os.path.join(extract_dir, 'output', 'model-best')
        if not os.path.exists(custom_ner_path):
            archive_path = os.path.join(current_directory, 'spacy_model_v2.zip')
            if not os.path.exists(archive_path):
                # Historical location: the archive used to be opened relative
                # to the current working directory only.
                archive_path = "./spacy_model_v2.zip"
            with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                # Unpack next to this file (not into the CWD) so the result
                # lands exactly where ``custom_ner_path`` points.
                zip_ref.extractall(extract_dir)
        self.custom_ner = spacy.load(custom_ner_path)

    def extract_ner(self, text, threshold=DEFAULT_SCORE_THRESHOLD):
        """Return ``(dates, organisations, locations)`` found in *text*.

        Runs ``self.ner`` on *text*. Only entities whose ``score`` is strictly
        greater than *threshold* are kept. Their ``word`` values are grouped
        by ``entity_group``.

        Args:
            text: Input passed to the NER pipeline.
            threshold: Exclusive minimum score for an entity to be kept.

        Returns:
            A 3-tuple of lists for the ``DATE``, ``ORG`` and ``LOC`` groups,
            each in the order the pipeline reported them (possibly empty).
        """
        grouped = defaultdict(list)
        for entity in self.ner(text):
            if entity['score'] > threshold:
                grouped[entity['entity_group']].append(entity['word'])
        # ``grouped`` is a defaultdict, so an absent group yields [] directly;
        # other groups are simply never read.
        return grouped['DATE'], grouped['ORG'], grouped['LOC']

    def get_ner(self, text, recover_text, threshold=DEFAULT_SCORE_THRESHOLD):
        """Extract dates, organisations and locations, with a fallback text.

        Each category that comes back empty for *text* is replaced by the
        corresponding result for *recover_text*. The fallback text is run
        through the pipeline only when at least one category is empty.

        Args:
            text: Primary input text.
            recover_text: Fallback input used for empty categories.
            threshold: Forwarded to :meth:`extract_ner`.

        Returns:
            ``(dates, companies, locations)`` as lists.
        """
        dates, companies, locations = self.extract_ner(text, threshold)
        if dates and companies and locations:
            # Everything found: skip the second pipeline run.
            return dates, companies, locations
        alt_dates, alt_companies, alt_locations = self.extract_ner(recover_text, threshold)
        return (dates or alt_dates,
                companies or alt_companies,
                locations or alt_locations)

    def get_custom_ner(self, text):
        """Run the custom spaCy model on *text* and group entities by label.

        Returns:
            A ``defaultdict(list)`` mapping each entity's ``label_`` to the
            list of entity texts carrying that label, in document order.
        """
        grouped = defaultdict(list)
        for ent in self.custom_ner(text).ents:
            grouped[ent.label_].append(ent.text)
        return grouped