from src.helper import read_pdf, token_decode
import gradio as gr
import torch


class LegalNER:
    def __init__(self, model, tokenizer, ids_to_labels, check_point='IndoBERT (IndoLEM)', label_all_tokens=True):
        self.model = model
        self.tokenizer = tokenizer
        self.check_point = check_point
        self.label_all_tokens = label_all_tokens
        self.prediction_label = []
        self.data_token = []
        self.ids_to_labels = ids_to_labels
        self.label_extraction = []
        self.tokenizer_decode = ''
        # Maps BIO entity tags to the Indonesian display names shown in the UI
        self.label_convert = {'B_VERN': 'Nomor Putusan',
                              'B_DEFN': 'Nama Terdakwa',
                              'B_CRIA': 'Tindak Pidana',
                              'B_ARTV': 'Melanggar KUHP',
                              'B_PENA': 'Tuntutan Hukum',
                              'B_PUNI': 'Putusan Hukum',
                              'B_TIMV': 'Tanggal Putusan',
                              'B_JUDP': 'Hakim Ketua',
                              'B_JUDG': 'Hakim Anggota',
                              'B_REGI': 'Panitera',
                              'B_PROS': 'Penuntut Umum',
                              'B_ADVO': 'Pengacara'}

    def align_word_ids(self, texts):
        tokenized_inputs = self.tokenizer(texts, padding='max_length', max_length=512, truncation=True)
        word_ids = tokenized_inputs.word_ids()

        previous_word_idx = None
        label_ids = []

        # Mark special tokens with -100 so they can be masked out later; real
        # tokens get 1 (subword pieces too, when label_all_tokens is set)
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(1)
            else:
                label_ids.append(1 if self.label_all_tokens else -100)
            previous_word_idx = word_idx

        return label_ids

    def labelToText(self):
        prev_tag = 'O'
        result = {}
        temp = ''

        # Merge the tokens back into one string per entity label
        for i, word in enumerate(self.data_token):
            if self.prediction_label[i] != 'O':
                if prev_tag == 'O' and temp != '':
                    temp = ''
                if '##' in word:
                    temp += word.replace('##', '')
                else:
                    temp += ' ' + word
            else:
                if temp != '':
                    result[prev_tag.replace('I_', 'B_')] = temp.strip()
                    temp = ''
            prev_tag = self.prediction_label[i]

        return result

    def dis_pdf_prediction(self):
        # Choose the best prediction per entity: keep the longest span found
        # across all sentences of the document
        entity_result = {}
        for extraction in self.label_extraction:
            for label, span in extraction.items():
                if label not in entity_result or len(entity_result[label]) < len(span):
                    entity_result[label] = span

        # Format the extracted entities as a numbered list; the tab counts are
        # tuned per label so the '=' columns line up in the output textbox
        result = ''
        for i, (label, data) in enumerate(entity_result.items()):
            if label in ['B_PENA', 'B_ARTV', 'B_PROS']:
                result += f'{i+1}. {self.label_convert[label]}\t = {data.capitalize()}\n'
            elif label in ['B_JUDP', 'B_CRIA']:
                result += f'{i+1}. {self.label_convert[label]}\t\t\t = {data.capitalize()}\n'
            elif label in ['B_ADVO', 'B_REGI']:
                result += f'{i+1}. {self.label_convert[label]}\t\t\t\t\t = {data.capitalize()}\n'
            else:
                result += f'{i+1}. {self.label_convert[label]}\t\t = {data.capitalize()}\n'

        return result

    def dis_text_prediction(self):
        result = []
        temp_result = {}
        count_huruf = 0
        temp_word = ''
        temp_label = ''
        temp_count_huruf = 0

        for i, (word, label) in enumerate(zip(self.data_token, self.prediction_label)):
            if label != 'O':
                # A new entity word starts: flush the buffered one first
                if temp_word != '' and '##' not in word:
                    temp_result['entity'] = temp_label
                    temp_result['word'] = temp_word
                    temp_result['start'] = temp_count_huruf
                    temp_result['end'] = temp_count_huruf + len(temp_word)
                    result.append(temp_result)
                    temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}

                if '##' in word:
                    temp_word += word.replace('##', '')
                else:
                    temp_label = label
                    temp_word = word
                    temp_count_huruf = count_huruf

            # Flush whatever is still buffered once the last token is reached
            if i == len(self.data_token) - 1 and temp_word != '':
                temp_result['entity'] = temp_label
                temp_result['word'] = temp_word
                temp_result['start'] = temp_count_huruf
                temp_result['end'] = temp_count_huruf + len(temp_word)
                result.append(temp_result)
                temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}

            # Track character offsets: subword pieces drop their '##' prefix,
            # whole words account for a trailing space
            if '##' in word:
                count_huruf += len(word) - 2
            else:
                count_huruf += len(word) + 1

        return result

    def fit_transform(self, texts, progress=gr.Progress()):
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        if use_cuda:
            self.model = self.model.cuda()

        # Start fresh so extractions from a previous run are not mixed in
        self.label_extraction = []

        # Load the fold-4 weights matching the selected pretrained backbone
        file_check_point = 'model/IndoLEM/model_fold_4.pth' if self.check_point == 'IndoBERT (IndoLEM)' else 'model/IndoNLU/model_fold_4.pth'
        model_weights = torch.load(file_check_point, map_location=device)
        self.model.load_state_dict(model_weights)

        for text in progress.tqdm(texts, desc="Ekstraksi Entitas"):
            tokenized = self.tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
            input_ids = tokenized['input_ids'].to(device)
            mask = tokenized['attention_mask'].to(device)

            logits = self.model(input_ids, mask, None)
            label_ids = torch.Tensor(self.align_word_ids(text)).unsqueeze(0).to(device)

            # Drop logits at special-token positions, then map the argmax ids
            # back to their BIO tag names
            logits_clean = logits[0][label_ids != -100]
            predictions = logits_clean.argmax(dim=1).tolist()
            prediction_label = [self.ids_to_labels[i] for i in predictions]

            input_ids_conv = self.tokenizer.convert_ids_to_tokens(tokenized['input_ids'][0])
            data_token = [word for word in input_ids_conv if word not in ['[CLS]', '[SEP]', '[PAD]']]

            self.tokenizer_decode = token_decode(input_ids_conv)
            self.data_token = data_token
            self.prediction_label = prediction_label

            labelConv = self.labelToText()
            if labelConv:
                self.label_extraction.append(labelConv)

    def predict(self, doc):
        # Raw text is returned as entity spans for highlighting; a PDF path is
        # read, split into sentences on ';', and summarised per entity label
        if not doc.endswith('.pdf'):
            self.fit_transform([doc.strip()])
            return self.dis_text_prediction()
        else:
            file_pdf = read_pdf(doc)
            sentence_file = file_pdf.split(';')
            self.fit_transform(sentence_file)
            return self.dis_pdf_prediction()
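
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the fine-tuned `model` object and the
# `ids_to_labels` mapping come from the training pipeline, which is not part
# of this module, and the tokenizer name below is an assumption based on the
# 'IndoBERT (IndoLEM)' checkpoint option. The example sentence and PDF path
# are placeholders.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('indolem/indobert-base-uncased')
#   ner = LegalNER(model, tokenizer, ids_to_labels)  # model/ids_to_labels from training
#
#   # Raw text returns a list of {'entity', 'word', 'start', 'end'} spans,
#   # ready for gr.HighlightedText; a .pdf path returns the summary string.
#   spans = ner.predict('Putusan Nomor ... atas nama terdakwa ...')
#   summary = ner.predict('data/putusan_contoh.pdf')
# ---------------------------------------------------------------------------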