# gradio-LegalNER / src/legalNER.py
from src.helper import token_decode, read_pdf
import gradio as gr
import torch
class LegalNER():
def __init__(self, model, tokenizer, ids_to_labels, check_point='IndoBERT (IndoLEM)', label_all_tokens=True):
self.model = model
self.tokenizer = tokenizer
self.check_point = check_point
self.label_all_tokens = label_all_tokens
self.prediction_label = ''
self.data_token = ''
self.ids_to_labels = ids_to_labels
self.label_extraction = []
self.tokenizer_decode = ''
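        # Human-readable (Indonesian) display names for each BIO entity tag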
self.label_convert = {'B_VERN' : 'Nomor Putusan',
'B_DEFN' : 'Nama Terdakwa',
'B_CRIA' : 'Tindak Pidana',
'B_ARTV' : 'Melanggar KUHP',
'B_PENA' : 'Tuntutan Hukum',
'B_PUNI' : 'Putusan Hukum',
'B_TIMV' : 'Tanggal Putusan',
'B_JUDP' : 'Hakim Ketua',
'B_JUDG' : 'Hakim Anggota',
'B_REGI' : 'Panitera',
'B_PROS' : 'Penuntut Umum',
'B_ADVO' : 'Pengacara',
}
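    # align_word_ids marks which tokenizer positions should keep a prediction:
    # special tokens get -100 (ignored), word starts get a dummy label of 1,
    # and the resulting mask is used in fit_transform to filter the logits
    # back down to word-level tokens.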
def align_word_ids(self, texts):
tokenized_inputs = self.tokenizer(texts, padding='max_length', max_length=512, truncation=True)
word_ids = tokenized_inputs.word_ids()
previous_word_idx = None
label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens ([CLS], [SEP], [PAD]) carry no prediction
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First subword of a word: keep its prediction (dummy label 1)
                label_ids.append(1)
            else:
                # Remaining subwords: keep them only when label_all_tokens is set
                label_ids.append(1 if self.label_all_tokens else -100)
            previous_word_idx = word_idx
return label_ids
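    # labelToText merges the predicted token stream into entity phrases:
    # '##' subword pieces are glued onto the previous token, consecutive
    # labelled tokens are concatenated, and a finished phrase is stored
    # under its B_ tag when an 'O' token (or the end of input) is reached.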
def labelToText(self):
prev_tag = 'O'
result = {}
temp = ''
        # Merge the tokens back into phrases, grouped by their predicted label
for i, word in enumerate(self.data_token):
if self.prediction_label[i] != 'O':
                # Defensive reset: a new entity starts right after an 'O' run
                if prev_tag == 'O' and temp != '':
                    temp = ''
if '##' in word:
temp += word.replace('##', '')
else:
temp += ' ' + word
else:
if temp != "":
result[prev_tag.replace("I_", "B_")] = temp.strip()
temp = ""
prev_tag = self.prediction_label[i]
        # Flush the last entity if the sequence ended on a labelled token
        if temp != "":
            result[prev_tag.replace("I_", "B_")] = temp.strip()
        return result
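    # dis_pdf_prediction builds the summary shown for PDF inputs: one line
    # per entity type, using the best value extracted across all sentences.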
def dis_pdf_prediction(self):
        # Pick the best extraction for each entity label across all sentences:
        # when a label was extracted more than once, keep the longest value
        entity_result = {}
        for extraction in self.label_extraction:
            for label, value in extraction.items():
                if label not in entity_result or len(entity_result[label]) < len(value):
                    entity_result[label] = value
        # Format the extracted entities as a numbered summary string;
        # the tab runs below roughly align the '=' column across label widths
        result = ''
for i, (label, data) in enumerate(entity_result.items()):
if label in ['B_PENA', 'B_ARTV', 'B_PROS']:
result += f'{i+1}. {self.label_convert[label]}\t = {data.capitalize()}\n'
elif label in ['B_JUDP', 'B_CRIA']:
result += f'{i+1}. {self.label_convert[label]}\t\t\t = {data.capitalize()}\n'
elif label in ['B_ADVO', 'B_REGI']:
result += f'{i+1}. {self.label_convert[label]}\t\t\t\t\t = {data.capitalize()}\n'
else:
result += f'{i+1}. {self.label_convert[label]}\t\t = {data.capitalize()}\n'
return result
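    # dis_text_prediction returns character-offset span dicts
    # ({'entity', 'word', 'start', 'end'}) used to highlight raw-text input.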
def dis_text_prediction(self):
result = []
        temp_result = {}
        count_huruf = 0        # running character offset into the decoded text
        temp_word = ''
        temp_label = ''
        temp_count_huruf = 0   # offset where the current entity word started
        for word, label in zip(self.data_token, self.prediction_label):
            if label != 'O':
                # A new labelled word begins: flush the previous entity, if any
                if temp_word != '' and '##' not in word:
                    temp_result['entity'] = temp_label
                    temp_result['word'] = temp_word
                    temp_result['start'] = temp_count_huruf
                    temp_result['end'] = temp_count_huruf + len(temp_word)
                    result.append(temp_result)
                    temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}
                if '##' in word:
                    # Subword continuation: glue the piece onto the current word
                    temp_word += word.replace('##', '')
                else:
                    temp_label = label
                    temp_word = word
                    temp_count_huruf = count_huruf
            # Advance the character offset: '##' pieces drop the two marker
            # characters, whole words also account for one following space
            if '##' in word:
                count_huruf += len(word) - 2
            else:
                count_huruf += len(word) + 1
        # Flush the final entity once the token stream is exhausted
        # (also covers an entity followed only by trailing 'O' tokens)
        if temp_word != '':
            temp_result['entity'] = temp_label
            temp_result['word'] = temp_word
            temp_result['start'] = temp_count_huruf
            temp_result['end'] = temp_count_huruf + len(temp_word)
            result.append(temp_result)
        return result
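    # fit_transform loads the fine-tuned checkpoint selected in the UI, runs
    # token classification over each text, and stores the extracted entity
    # phrases (via labelToText) in self.label_extraction.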
    def fit_transform(self, texts, progress=gr.Progress()):
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        if use_cuda:
            self.model = self.model.cuda()
        # Reset results from any previous call so extractions do not accumulate
        self.label_extraction = []
        file_check_point = 'model/IndoLEM/model_fold_4.pth' if self.check_point == 'IndoBERT (IndoLEM)' else 'model/IndoNLU/model_fold_4.pth'
        model_weights = torch.load(file_check_point, map_location=device)
        self.model.load_state_dict(model_weights)
        self.model.eval()  # inference mode: disables dropout
        for text in progress.tqdm(texts, desc="Ekstraksi Entitas"):
            tokenized = self.tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
            input_ids = tokenized['input_ids'].to(device)
            mask = tokenized['attention_mask'].to(device)
            with torch.no_grad():  # inference only, no gradients needed
                logits = self.model(input_ids, mask, None)
            label_ids = torch.Tensor(self.align_word_ids(text)).unsqueeze(0).to(device)
            # Keep only logits at positions align_word_ids marked as word tokens
            logits_clean = logits[0][label_ids != -100]
            predictions = logits_clean.argmax(dim=1).tolist()
            prediction_label = [self.ids_to_labels[i] for i in predictions]
            input_ids_conv = self.tokenizer.convert_ids_to_tokens(tokenized['input_ids'][0])
            data_token = [word for word in input_ids_conv if word not in ['[CLS]', '[SEP]', '[PAD]']]
            self.tokenizer_decode = token_decode(input_ids_conv)
            self.data_token = data_token
            self.prediction_label = prediction_label
labelConv = self.labelToText()
if labelConv:
self.label_extraction.append(labelConv)
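    # predict dispatches on the input type: raw text returns highlightable
    # spans, while a PDF path is read, split into sentences, and summarized.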
def predict(self, doc):
        if not doc.lower().endswith('.pdf'):
self.fit_transform([doc.strip()])
return self.dis_text_prediction()
else:
file_pdf = read_pdf(doc)
sentence_file = file_pdf.split(';')
self.fit_transform(sentence_file)
return self.dis_pdf_prediction()
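# Example usage (a minimal sketch): `model`, `tokenizer`, and `ids_to_labels`
# are assumed to be constructed by the app (e.g. in app.py); the names below
# are placeholders, not part of this module.
#
#   ner = LegalNER(model, tokenizer, ids_to_labels, check_point='IndoBERT (IndoLEM)')
#   spans = ner.predict('Terdakwa dijatuhi pidana penjara ...')  # text -> span dicts
#   summary = ner.predict('putusan.pdf')                         # PDF path -> summary string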