from src.helper import *
import gradio as gr
import torch


class LegalNER():
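    """Named-entity extraction for Indonesian court verdict documents.

    Wraps a fine-tuned IndoBERT token-classification model (either the IndoLEM
    or the IndoNLU checkpoint) and maps BIO tags such as B_VERN (verdict number)
    and B_DEFN (defendant name) to the display labels in `label_convert`.
    """
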
    def __init__(self, model, tokenizer, ids_to_labels, check_point='IndoBERT (IndoLEM)', label_all_tokens=True):
        self.model = model
        self.tokenizer = tokenizer
        self.check_point = check_point
        self.label_all_tokens = label_all_tokens
        self.prediction_label = ''
        self.data_token = ''
        self.ids_to_labels = ids_to_labels
        self.label_extraction = []
        self.tokenizer_decode = ''
        self.label_convert = {'B_VERN': 'Nomor Putusan',
                              'B_DEFN': 'Nama Terdakwa',
                              'B_CRIA': 'Tindak Pidana',
                              'B_ARTV': 'Melanggar KUHP',
                              'B_PENA': 'Tuntutan Hukum',
                              'B_PUNI': 'Putusan Hukum',
                              'B_TIMV': 'Tanggal Putusan',
                              'B_JUDP': 'Hakim Ketua',
                              'B_JUDG': 'Hakim Anggota',
                              'B_REGI': 'Panitera',
                              'B_PROS': 'Penuntut Umum',
                              'B_ADVO': 'Pengacara',
                              }

    def align_word_ids(self, texts):
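        """Build the label mask for one tokenized text.

        Special tokens ([CLS], [SEP], [PAD]) get -100 so they can be filtered
        out of the logits later; word pieces get 1, mirroring the alignment
        used when the model was trained.
        """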
        tokenized_inputs = self.tokenizer(texts, padding='max_length', max_length=512, truncation=True)
        word_ids = tokenized_inputs.word_ids()
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens are ignored
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First sub-token of a new word
                label_ids.append(1)
            else:
                # Remaining sub-tokens of the same word
                label_ids.append(1 if self.label_all_tokens else -100)
            previous_word_idx = word_idx
        return label_ids

    def labelToText(self):
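        """Merge predicted tokens into one string per entity label.

        WordPiece continuations ('##...') are glued onto the previous token,
        and each collected span is stored under its tag (I_ tags are folded
        into the matching B_ tag).
        """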
        prev_tag = 'O'
        result = {}
        temp = ''
        # Combine consecutive tokens into one string per predicted label
        for i, word in enumerate(self.data_token):
            if self.prediction_label[i] != 'O':
                if prev_tag == 'O' and temp != '':
                    temp = ''
                if '##' in word:
                    temp += word.replace('##', '')
                else:
                    temp += ' ' + word
            else:
                if temp != "":
                    result[prev_tag.replace("I_", "B_")] = temp.strip()
                    temp = ""
            prev_tag = self.prediction_label[i]
        return result

    def dis_pdf_prediction(self):
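        """Format the entities extracted from a PDF as a numbered summary.

        When the same label was found in several sentences, the longest span
        wins; the tab widths per label keep the '=' signs roughly aligned.
        """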
        # Keep the best (longest) prediction for every entity label
        entity_result = {}
        for i in self.label_extraction:
            if len(list(i.keys())) > 1:
                for y in i.items():
                    if y[0] not in entity_result:
                        entity_result[y[0]] = y[1]
                    else:
                        if len(entity_result[y[0]]) < len(y[1]):
                            entity_result[y[0]] = y[1]
            else:
                # Single-label extraction: add it only if the label is new
                label, data = tuple(i.items())[0]
                if label not in entity_result:
                    entity_result[label] = data

        # Convert the extracted entities into a numbered, aligned text block
        result = ''
        for i, (label, data) in enumerate(entity_result.items()):
            if label in ['B_PENA', 'B_ARTV', 'B_PROS']:
                result += f'{i+1}. {self.label_convert[label]}\t = {data.capitalize()}\n'
            elif label in ['B_JUDP', 'B_CRIA']:
                result += f'{i+1}. {self.label_convert[label]}\t\t\t = {data.capitalize()}\n'
            elif label in ['B_ADVO', 'B_REGI']:
                result += f'{i+1}. {self.label_convert[label]}\t\t\t\t\t = {data.capitalize()}\n'
            else:
                result += f'{i+1}. {self.label_convert[label]}\t\t = {data.capitalize()}\n'
        return result

    def dis_text_prediction(self):
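        """Build entity dicts (entity, word, start, end) for plain-text input.

        Character offsets are accumulated from token lengths so the spans can
        be highlighted in the Gradio UI.
        """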
        result = []
        temp_result = {}
        count_huruf = 0
        temp_word = ''
        temp_label = ''
        temp_count_huruf = 0
        for i, (word, label) in enumerate(zip(self.data_token, self.prediction_label)):
            if label != 'O':
                # A new labelled word starts: flush the previous entity span
                if temp_word != '' and '##' not in word:
                    temp_result['entity'] = temp_label
                    temp_result['word'] = temp_word
                    temp_result['start'] = temp_count_huruf
                    temp_result['end'] = temp_count_huruf + (len(temp_word))
                    result.append(temp_result)
                    temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}
                if '##' in word:
                    # WordPiece continuation: append to the current span
                    temp_word += word.replace('##', '')
                else:
                    temp_label = label
                    temp_word = word
                    temp_count_huruf = count_huruf
            if i == len(self.data_token) - 1 and temp_word != '':
                # Flush the span that is still open at the end of the text
                temp_result['entity'] = temp_label
                temp_result['word'] = temp_word
                temp_result['start'] = temp_count_huruf
                temp_result['end'] = temp_count_huruf + (len(temp_word))
                result.append(temp_result)
                temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}
            # Track character offsets: '##' tokens add their length minus the
            # '##' prefix, ordinary tokens add their length plus one space
            if '##' in word:
                count_huruf += len(word) - 2
            else:
                count_huruf += len(word) + 1
        return result

    def fit_transform(self, texts, progress=gr.Progress()):
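        """Run NER inference over a list of sentences.

        Loads the fold-4 weights matching the selected checkpoint, predicts a
        tag for every word, and appends each sentence's entity dictionary to
        self.label_extraction.
        """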
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        if use_cuda:
            self.model = self.model.cuda()

        # Load the fine-tuned weights that match the selected checkpoint
        file_check_point = 'model/IndoLEM/model_fold_4.pth' if self.check_point == 'IndoBERT (IndoLEM)' else 'model/IndoNLU/model_fold_4.pth'
        model_weights = torch.load(file_check_point, map_location=device)
        self.model.load_state_dict(model_weights)

        for text in progress.tqdm(texts, desc="Ekstraksi Entitas"):
            tokenized = self.tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
            input_ids = tokenized['input_ids'].to(device)
            mask = tokenized['attention_mask'].to(device)

            logits = self.model(input_ids, mask, None)
            label_ids = torch.Tensor(self.align_word_ids(text)).unsqueeze(0).to(device)

            # Keep only positions that carry a label and map them back to tags
            logits_clean = logits[0][label_ids != -100]
            predictions = logits_clean.argmax(dim=1).tolist()
            prediction_label = [self.ids_to_labels[i] for i in predictions]

            input_ids_conv = self.tokenizer.convert_ids_to_tokens(tokenized['input_ids'][0])
            data_token = [word for word in input_ids_conv if word not in ['[CLS]', '[SEP]', '[PAD]']]

            self.tokenizer_decode = token_decode(input_ids_conv)
            self.data_token = data_token
            self.prediction_label = prediction_label

            labelConv = self.labelToText()
            if labelConv:
                self.label_extraction.append(labelConv)

    def predict(self, doc):
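        """Extract entities from a raw text string or a PDF file path.

        Plain text is returned as highlight spans (dis_text_prediction); a PDF
        is read, split into sentences on ';', and summarised with
        dis_pdf_prediction.
        """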
        if '.pdf' not in doc:
            # Raw text input: return highlight spans
            self.fit_transform([doc.strip()])
            return self.dis_text_prediction()
        else:
            # PDF input: read the file, split into sentences, return a summary
            file_pdf = read_pdf(doc)
            sentence_file = file_pdf.split(';')
            self.fit_transform(sentence_file)
            return self.dis_pdf_prediction()
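
# A minimal usage sketch (not part of the original Space): the helper and the
# label mapping below are assumptions for illustration only. The fine-tuned
# model, tokenizer, and ids_to_labels are expected to be built elsewhere (for
# example in src.helper), matching the checkpoints under model/IndoLEM and
# model/IndoNLU.
#
#   model, tokenizer = build_model_and_tokenizer('IndoBERT (IndoLEM)')   # hypothetical helper
#   ids_to_labels = {0: 'O', 1: 'B_VERN', 2: 'B_DEFN'}                   # illustrative subset
#   ner = LegalNER(model, tokenizer, ids_to_labels)
#   spans = ner.predict('Putusan Nomor 123/Pid.B/2020/PN Jkt ...')       # raw text -> highlight spans
#   summary = ner.predict('data/putusan.pdf')                            # PDF path -> formatted summary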