|
from utils.finetune import Graph2TextModule |
|
from typing import Dict, List, Tuple, Union, Optional |
|
import torch |
|
import re |
|
|
|
if torch.cuda.is_available(): |
|
DEVICE = 'cuda' |
|
else: |
|
DEVICE = 'cpu' |
|
print('CUDA NOT AVAILABLE') |
|
|
|
CHECKPOINT = 'base/t5-base_13881_val_avg_bleu=68.1000-step_count=5.ckpt' |
|
MAX_LENGTH = 384 |
|
SEED = 42 |
|
|
|
|
|
class VerbModule:
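    '''
    Wraps a fine-tuned T5 graph-to-text checkpoint (Graph2TextModule) to verbalise
    RDF-style triples into English sentences, keeping track of <unk> tokens so that
    out-of-vocabulary characters can be restored afterwards.
    '''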
|
|
|
    def __init__(self, override_args: Optional[Dict[str, str]] = None):
|
|
|
if not override_args: |
|
override_args = {} |
|
        self.g2t_module = Graph2TextModule.load_from_checkpoint(CHECKPOINT, strict=False, **override_args)
        # Keep the model on the same device the encoded inputs are moved to during generation
        self.g2t_module.to(DEVICE)
        self.tokenizer = self.g2t_module.tokenizer
|
|
|
        self.vocab = self.tokenizer.get_vocab()
        # Normalise some full-width Japanese punctuation to ASCII before tokenising labels
        self.convert_some_japanese_characters = True
        # Number of context tokens kept on each side of an <unk> when building replacements
        self.unk_char_replace_sliding_window_size = 2
        # One {span-with-<unk>: original-text} dict per label registered via add_label_to_unk_replacer
        self.unknowns = []
|
|
|
def __generate_verbalisations_from_inputs(self, inputs: Union[str, List[str]]): |
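        '''
        Tokenise the raw input string(s) and run beam-search generation with the
        fine-tuned T5 model, returning the generated token-id sequences.
        '''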
|
try: |
|
inputs_encoding = self.tokenizer.prepare_seq2seq_batch( |
|
inputs, truncation=True, max_length=MAX_LENGTH, return_tensors='pt' |
|
) |
|
inputs_encoding = {k: v.to(DEVICE) for k, v in inputs_encoding.items()} |
|
|
|
self.g2t_module.model.eval() |
|
with torch.no_grad(): |
|
gen_output = self.g2t_module.model.generate( |
|
inputs_encoding['input_ids'], |
|
attention_mask=inputs_encoding['attention_mask'], |
|
use_cache=True, |
|
                    decoder_start_token_id=self.g2t_module.decoder_start_token_id,
                    num_beams=self.g2t_module.eval_beams,
                    max_length=self.g2t_module.eval_max_length,
                    length_penalty=1.0
|
) |
|
except Exception: |
|
print(inputs) |
|
raise |
|
|
|
return gen_output |
|
|
|
    def __decode_ids_to_string_custom(
        self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
    ) -> str:
        '''
        Adapted from the Hugging Face implementation
        (https://github.com/huggingface/transformers/blob/198c335d219a5eb4d3f124fdd1ce1a9cd9f78a9b/src/transformers/tokenization_utils_fast.py#L537),
        because the official `tokenizer.decode` treats all special tokens the same,
        whereas here every special token is dropped from the decoded sentence EXCEPT
        the <unk> token, which is replaced later on.
        '''
|
        filtered_tokens = self.tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)

        sub_texts = []
|
current_sub_text = [] |
|
for token in filtered_tokens: |
|
            # Skip every special token except <unk>, which is handled downstream
            if (skip_special_tokens
                    and token != self.tokenizer.unk_token
                    and token in self.tokenizer.all_special_tokens):
                continue
            else:
                current_sub_text.append(token)
|
if current_sub_text: |
|
sub_texts.append(self.tokenizer.convert_tokens_to_string(current_sub_text)) |
|
text = " ".join(sub_texts) |
|
|
|
if clean_up_tokenization_spaces: |
|
clean_text = self.tokenizer.clean_up_tokenization(text) |
|
return clean_text |
|
else: |
|
return text |
|
|
|
    def __decode_sentences(self, encoded_sentences: Union[List[int], List[List[int]], torch.Tensor]):
        if isinstance(encoded_sentences, str):
            encoded_sentences = [encoded_sentences]
        decoded_sentences = [self.__decode_ids_to_string_custom(ids, skip_special_tokens=True) for ids in encoded_sentences]
        return decoded_sentences
|
|
|
def verbalise_sentence(self, inputs: Union[str, List[str]]): |
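        '''
        Verbalise one or more already-formatted model inputs
        ('translate Graph to English: <H> ... <R> ... <T> ...') into English sentences.
        Returns a single string for a single input, otherwise a list of strings.
        '''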
|
        if isinstance(inputs, str):
            inputs = [inputs]
|
|
|
gen_output = self.__generate_verbalisations_from_inputs(inputs) |
|
|
|
decoded_sentences = self.__decode_sentences(gen_output) |
|
|
|
if len(decoded_sentences) == 1: |
|
return decoded_sentences[0] |
|
else: |
|
return decoded_sentences |
|
|
|
def verbalise_triples(self, input_triples: Union[Dict[str, str], List[Dict[str, str]], List[List[Dict[str, str]]]]): |
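        '''
        Verbalise triples given as a single dict, a list of dicts, or a list of lists
        of dicts (one inner list per sentence). Each dict must have 'subject',
        'predicate' and 'object' keys; the triples are formatted into the
        'translate Graph to English: <H> ... <R> ... <T> ...' input expected by the model.
        '''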
|
        if isinstance(input_triples, dict):
            input_triples = [input_triples]
|
|
|
verbalisation_inputs = [] |
|
for triple in input_triples: |
|
            if isinstance(triple, dict):
|
assert 'subject' in triple |
|
assert 'predicate' in triple |
|
assert 'object' in triple |
|
verbalisation_inputs.append( |
|
f'translate Graph to English: <H> {triple["subject"]} <R> {triple["predicate"]} <T> {triple["object"]}' |
|
) |
|
            elif isinstance(triple, list):
|
input_sentence = ['translate Graph to English:'] |
|
for subtriple in triple: |
|
assert 'subject' in subtriple |
|
assert 'predicate' in subtriple |
|
assert 'object' in subtriple |
|
input_sentence.append(f'<H> {subtriple["subject"]}') |
|
input_sentence.append(f'<R> {subtriple["predicate"]}') |
|
input_sentence.append(f'<T> {subtriple["object"]}') |
|
verbalisation_inputs.append( |
|
' '.join(input_sentence) |
|
) |
|
|
|
return self.verbalise_sentence(verbalisation_inputs) |
|
|
|
    def verbalise(self, input: Union[str, List, Dict]):
        try:
            # Plain strings (or lists of strings) are assumed to be ready-made model inputs;
            # anything else (dicts, lists of dicts, lists of lists of dicts) is treated as triples
            if isinstance(input, str) or (isinstance(input, list) and isinstance(input[0], str)):
                return self.verbalise_sentence(input)
            else:
                return self.verbalise_triples(input)
|
except Exception: |
|
print(f'ERROR VERBALISING {input}') |
|
raise |
|
|
|
def add_label_to_unk_replacer(self, label: str): |
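        '''
        Register a label so that <unk> tokens produced for its rare characters can later
        be mapped back to the original text. The label is tokenised and, for every <unk>,
        a window of up to N surrounding tokens is turned into a regular expression that
        recovers the original span from the label.
        '''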
|
N = self.unk_char_replace_sliding_window_size |
|
self.unknowns.append({}) |
|
|
|
|
|
        # Map some full-width (Japanese) punctuation to its ASCII equivalent before
        # tokenising, so these characters are not turned into <unk>
        if self.convert_some_japanese_characters:
            label = label.replace('(', '(')
            label = label.replace(')', ')')
            label = label.replace('〈', '<')
            label = label.replace('/', '/')
            label = label.replace('〉', '>')
|
|
|
label_encoded = self.tokenizer.encode(label) |
|
        label_tokens = self.tokenizer.convert_ids_to_tokens(label_encoded)

        # Keep <unk> tokens, but drop end-of-sequence and padding tokens
        label_tokens = [t for t in label_tokens if t not in [
            self.tokenizer.eos_token, self.tokenizer.pad_token
        ]]
|
|
|
label_token_to_string = self.tokenizer.convert_tokens_to_string(label_tokens) |
|
unk_token_to_string = self.tokenizer.convert_tokens_to_string([self.tokenizer.unk_token]) |
|
|
|
|
|
|
|
        # Find all (possibly space-separated) runs of <unk> in the tokenised label
        match_unks_in_label = re.findall('(?:(?: )*<unk>(?: )*)+', label_token_to_string)
        if len(match_unks_in_label) > 0:
            # If the whole label decodes to <unk>, map the run directly to the full label
            if match_unks_in_label[0] == label_token_to_string:
                self.unknowns[-1][label_token_to_string.strip()] = label
|
|
|
else: |
|
|
|
|
|
                # Otherwise, slide a window of up to N tokens around every <unk> and build
                # a regex from that window to recover the original characters from the label
                for idx, token in enumerate(label_tokens):
                    idx_before = max(0, idx-N)
                    idx_ahead = min(len(label_tokens), idx+N+1)
|
|
|
|
|
|
|
if token == self.tokenizer.unk_token: |
|
|
|
|
|
                        if len(match_unks_in_label) > 1:
                            # With several <unk> runs in the label, shrink/shift the window so the
                            # context used for this <unk> does not include a neighbouring <unk>
                            while self.tokenizer.unk_token in label_tokens[idx+1:idx_ahead]:
                                idx_before = max(0, idx_before-1)
                                idx_ahead = min(idx+2, idx_ahead-1)
                            while self.tokenizer.unk_token in label_tokens[idx_before:idx]:
                                idx_before = min(idx-1, idx_before+2)
|
|
|
|
|
span = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx_ahead]) |
|
|
|
                        # <unk> is the first real token: anchor the recovery regex at the label start
                        if idx == 1 and label_tokens[0] == '▁':
|
|
|
to_replace = '^' + re.escape(span).replace( |
|
re.escape(unk_token_to_string), |
|
'.+?' |
|
) |
|
|
|
replaced_span = re.search( |
|
to_replace, |
|
label |
|
)[0] |
|
self.unknowns[-1][span.strip()] = replaced_span |
|
|
|
                        # <unk> is the last token before end-of-sequence: anchor the regex at the label end
                        elif idx == len(label_tokens)-2 and label_tokens[-1] == self.tokenizer.eos_token:
|
|
|
pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx]) |
|
pre_idx_unk_counts = pre_idx.count(unk_token_to_string) |
|
to_replace = re.escape(span).replace( |
|
re.escape(unk_token_to_string), |
|
f'[^{re.escape(pre_idx)}]+?' |
|
) + '$' |
|
|
|
if pre_idx.strip() == '': |
|
                                to_replace = to_replace.replace('[^]', r'(?<=\s)[^a-zA-Z0-9]')
|
|
|
replaced_span = re.search( |
|
to_replace, |
|
label |
|
)[0] |
|
self.unknowns[-1][span.strip()] = replaced_span |
|
|
|
|
|
                        # <unk> somewhere in the middle of the label: build the regex from its context window
                        else:
|
|
|
pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx]) |
|
|
|
to_replace = re.escape(span).replace( |
|
re.escape(unk_token_to_string), |
|
f'[^{re.escape(pre_idx)}]+?' |
|
) |
|
|
|
|
|
if pre_idx.strip() == '': |
|
                                to_replace = to_replace.replace('[^]', r'(?<=\s)[^a-zA-Z0-9]')
|
|
|
replaced_span = re.search( |
|
to_replace, |
|
label |
|
) |
|
|
|
if replaced_span: |
|
span = re.sub(r'\s([?.!",](?:\s|$))', r'\1', span.strip()) |
|
self.unknowns[-1][span] = replaced_span[0] |
|
|
|
def replace_unks_on_sentence(self, sentence: str, loop_n : int = 3, empty_after : bool = False): |
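        '''
        Substitute the <unk> spans previously registered via add_label_to_unk_replacer
        back into `sentence`, looping at most `loop_n` times; if `empty_after` is True,
        the stored replacements are cleared afterwards.
        '''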
|
|
|
while '<unk>' in sentence and loop_n > 0: |
|
loop_n -= 1 |
|
for unknowns in self.unknowns: |
|
for k,v in unknowns.items(): |
|
|
|
if k == '<unk>' and loop_n > 0: |
|
continue |
|
|
|
                    # If the span is not found verbatim, try capitalising it to match a
                    # sentence-initial occurrence ...
                    if k not in sentence and k[0] == k[0].lower() and k[0].upper() == sentence[0]:
                        k = k[0].upper() + k[1:]
                        v = v[0].upper() + v[1:]
                    # ... or collapse repeated whitespace inside the span
                    elif k not in sentence and len(re.findall(r'\s{2,}', k)) > 0:
                        k = re.sub(r'\s+', ' ', k)
|
|
|
sentence = sentence.replace(k.strip(),v.strip(),1) |
|
|
|
|
|
sentence = re.sub(r'\s+', ' ', sentence).strip() |
|
|
|
sentence = re.sub(r'\s([?.!",](?:\s|$))', r'\1', sentence) |
|
if empty_after: |
|
self.unknowns = [] |
|
return sentence |
|
|
|
if __name__ == '__main__': |
|
|
|
verb_module = VerbModule() |
|
verbs = verb_module.verbalise('translate Graph to English: <H> World Trade Center <R> height <T> 200 meter <H> World Trade Center <R> is a <T> tower') |
|
print(verbs) |
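
    # A minimal follow-up sketch (assumes the checkpoint above is available): the same
    # triples expressed as dictionaries and passed as one grouped input, followed by the
    # <unk>-replacement step. The entity labels are illustrative only.
    triples = [[
        {'subject': 'World Trade Center', 'predicate': 'height', 'object': '200 meter'},
        {'subject': 'World Trade Center', 'predicate': 'is a', 'object': 'tower'},
    ]]
    triple_verb = verb_module.verbalise_triples(triples)
    verb_module.add_label_to_unk_replacer('World Trade Center')
    print(verb_module.replace_unks_on_sentence(triple_verb, empty_after=True))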