# Prove_KCL/utils/verbalisation_module.py
from utils.finetune import Graph2TextModule
from typing import Dict, List, Tuple, Union, Optional
import torch
import re
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
    print('CUDA NOT AVAILABLE')
CHECKPOINT = 'base/t5-base_13881_val_avg_bleu=68.1000-step_count=5.ckpt'
MAX_LENGTH = 384
SEED = 42
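
# The checkpoint above is a T5-based graph-to-text model, so every input it is given should be a
# linearised triple prompt of the form
#   'translate Graph to English: <H> subject <R> predicate <T> object [<H> ... <R> ... <T> ...]'
# as built by VerbModule.verbalise_triples below and shown in the __main__ example at the bottom of this file.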

class VerbModule():
    def __init__(self, override_args: Optional[Dict[str, str]] = None):
        # Model
        if not override_args:
            override_args = {}
        self.g2t_module = Graph2TextModule.load_from_checkpoint(CHECKPOINT, strict=False, **override_args)
        self.tokenizer = self.g2t_module.tokenizer
        # Keep the model on the same device as the encoded inputs (see __generate_verbalisations_from_inputs)
        self.g2t_module.model.to(DEVICE)

        # Unk replacer
        self.vocab = self.tokenizer.get_vocab()
        self.convert_some_japanese_characters = True
        self.unk_char_replace_sliding_window_size = 2
        self.unknowns = []
    def __generate_verbalisations_from_inputs(self, inputs: Union[str, List[str]]):
        try:
            inputs_encoding = self.tokenizer.prepare_seq2seq_batch(
                inputs, truncation=True, max_length=MAX_LENGTH, return_tensors='pt'
            )
            inputs_encoding = {k: v.to(DEVICE) for k, v in inputs_encoding.items()}

            self.g2t_module.model.eval()
            with torch.no_grad():
                gen_output = self.g2t_module.model.generate(
                    inputs_encoding['input_ids'],
                    attention_mask=inputs_encoding['attention_mask'],
                    use_cache=True,
                    decoder_start_token_id=self.g2t_module.decoder_start_token_id,
                    num_beams=self.g2t_module.eval_beams,
                    max_length=self.g2t_module.eval_max_length,
                    length_penalty=1.0
                )
        except Exception:
            print(inputs)
            raise

        return gen_output
    '''
    We create this function as an alteration of [this one](https://github.com/huggingface/transformers/blob/198c335d219a5eb4d3f124fdd1ce1a9cd9f78a9b/src/transformers/tokenization_utils_fast.py#L537),
    mainly because the official 'tokenizer.decode' treats all special tokens the same, while we want to drop all
    special tokens from the decoded sentence EXCEPT for the <unk> token, which we will replace later on.
    '''
    def __decode_ids_to_string_custom(
        self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
    ) -> str:
        filtered_tokens = self.tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
        # Do not remove special tokens yet

        # To avoid mixing byte-level and unicode for byte-level BPE,
        # we need to build the string separately for added tokens and byte-level tokens
        # cf. https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        for token in filtered_tokens:
            if skip_special_tokens and \
               token != self.tokenizer.unk_token and \
               token in self.tokenizer.all_special_tokens:
                continue
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.tokenizer.convert_tokens_to_string(current_sub_text))
        text = " ".join(sub_texts)

        if clean_up_tokenization_spaces:
            clean_text = self.tokenizer.clean_up_tokenization(text)
            return clean_text
        else:
            return text
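
    # Illustrative note (hypothetical token sequence): if the generated ids decode to the tokens
    # ['<pad>', '▁', '<unk>', '▁is', '▁a', '▁city', '</s>'], the stock tokenizer.decode with
    # skip_special_tokens=True would drop the <unk> as well, giving 'is a city', while the custom decoder
    # above keeps it, giving '<unk> is a city', so replace_unks_on_sentence can later restore the label.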
    def __decode_sentences(self, encoded_sentences: Union[str, List[str]]):
        if type(encoded_sentences) == str:
            encoded_sentences = [encoded_sentences]
        decoded_sentences = [self.__decode_ids_to_string_custom(i, skip_special_tokens=True) for i in encoded_sentences]
        return decoded_sentences

    def verbalise_sentence(self, inputs: Union[str, List[str]]):
        if type(inputs) == str:
            inputs = [inputs]

        gen_output = self.__generate_verbalisations_from_inputs(inputs)
        decoded_sentences = self.__decode_sentences(gen_output)

        if len(decoded_sentences) == 1:
            return decoded_sentences[0]
        else:
            return decoded_sentences
    def verbalise_triples(self, input_triples: Union[Dict[str, str], List[Dict[str, str]], List[List[Dict[str, str]]]]):
        if type(input_triples) == dict:
            input_triples = [input_triples]

        verbalisation_inputs = []
        for triple in input_triples:
            if type(triple) == dict:
                assert 'subject' in triple
                assert 'predicate' in triple
                assert 'object' in triple
                verbalisation_inputs.append(
                    f'translate Graph to English: <H> {triple["subject"]} <R> {triple["predicate"]} <T> {triple["object"]}'
                )
            elif type(triple) == list:
                input_sentence = ['translate Graph to English:']
                for subtriple in triple:
                    assert 'subject' in subtriple
                    assert 'predicate' in subtriple
                    assert 'object' in subtriple
                    input_sentence.append(f'<H> {subtriple["subject"]}')
                    input_sentence.append(f'<R> {subtriple["predicate"]}')
                    input_sentence.append(f'<T> {subtriple["object"]}')
                verbalisation_inputs.append(
                    ' '.join(input_sentence)
                )

        return self.verbalise_sentence(verbalisation_inputs)
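
    # Usage sketch (hypothetical data and instance name `vm`): verbalise_triples accepts a single triple dict,
    # a flat list of independent triples, or a list of triple groups that each become one sentence, e.g.
    #   vm.verbalise_triples({'subject': 'Berlin', 'predicate': 'capital of', 'object': 'Germany'})
    #   vm.verbalise_triples([[{'subject': 'Berlin', 'predicate': 'capital of', 'object': 'Germany'},
    #                          {'subject': 'Berlin', 'predicate': 'population', 'object': '3.7 million'}]])
    # Every variant ends up as a 'translate Graph to English: <H> ... <R> ... <T> ...' prompt for verbalise_sentence.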
    def verbalise(self, input: Union[str, List, Dict]):
        try:
            if (type(input) == str) or (type(input) == list and type(input[0]) == str):
                return self.verbalise_sentence(input)
            elif (type(input) == dict) or (type(input) == list and type(input[0]) == dict):
                return self.verbalise_triples(input)
            else:
                return self.verbalise_triples(input)
        except Exception:
            print(f'ERROR VERBALISING {input}')
            raise
    def add_label_to_unk_replacer(self, label: str):
        N = self.unk_char_replace_sliding_window_size
        self.unknowns.append({})

        # Some pre-processing of labels to normalise some characters
        if self.convert_some_japanese_characters:
            label = label.replace('（', '(')
            label = label.replace('）', ')')
            label = label.replace('〈', '<')
            label = label.replace('／', '/')
            label = label.replace('〉', '>')

        label_encoded = self.tokenizer.encode(label)
        label_tokens = self.tokenizer.convert_ids_to_tokens(label_encoded)

        # Here, we also remove the </s> (eos) and <pad> tokens from the replacing key, because:
        # 1) When the whole label is unks, label_token_to_string would be '<unk></s>', meaning the replacing key
        #    (which is the same) would only replace the <unk> if it appeared at the end of the sentence, which is not
        #    the desired effect. But since this means ANY <unk> would be replaced by this key, it is best to only
        #    apply all-<unk> keys on the last replacing pass.
        # 2) In other cases, where the unk appears in the label but not in its entirety (e.g. at the start or end),
        #    the replacing key might include the starting <pad> token or the ending <eos> token, again forcing the
        #    replacement to only happen if the label appears at the end of the sentence.
        label_tokens = [t for t in label_tokens if t not in [
            self.tokenizer.eos_token, self.tokenizer.pad_token
        ]]
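
        # Hypothetical illustration: for a label such as 'Γαία', the T5 SentencePiece vocabulary may map the
        # Greek characters to <unk>, so label_tokens could end up as ['▁', '<unk>'] after the filter above;
        # the replacing key built below would then pair that '<unk>' span back with the original text 'Γαία'.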
        label_token_to_string = self.tokenizer.convert_tokens_to_string(label_tokens)
        unk_token_to_string = self.tokenizer.convert_tokens_to_string([self.tokenizer.unk_token])
        #print(label_encoded, label_tokens, label_token_to_string)

        match_unks_in_label = re.findall('(?:(?: )*<unk>(?: )*)+', label_token_to_string)
        if len(match_unks_in_label) > 0:
            # If the whole label is made of UNKs
            if match_unks_in_label[0] == label_token_to_string:
                #print('Label is all unks')
                self.unknowns[-1][label_token_to_string.strip()] = label
            # Else, there should be non-UNK characters in the label
            else:
                #print('Label is NOT all unks')
                # Analyse the label with a sliding window of size N (N tokens before, N tokens ahead)
                for idx, token in enumerate(label_tokens):
                    idx_before = max(0, idx - N)
                    idx_ahead = min(len(label_tokens), idx + N + 1)
                    # Found a UNK
                    if token == self.tokenizer.unk_token:
                        # In case of multiple UNKs, exclude UNKs seen after this one, expanding the window to the other side if possible
                        if len(match_unks_in_label) > 1:
                            #print(idx)
                            #print(label_tokens)
                            #print(label_tokens[idx_before:idx_ahead])
                            #print('HERE!')
                            # Reduce on the right, expanding on the left
                            while self.tokenizer.unk_token in label_tokens[idx+1:idx_ahead]:
                                idx_before = max(0, idx_before - 1)
                                idx_ahead = min(idx + 2, idx_ahead - 1)
                                #print(label_tokens[idx_before:idx_ahead])
                            # Now just reduce on the left
                            while self.tokenizer.unk_token in label_tokens[idx_before:idx]:
                                idx_before = min(idx - 1, idx_before + 2)
                                #print(label_tokens[idx_before:idx_ahead])
                        span = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx_ahead])
                        # First token of the label is UNK
                        if idx == 1 and label_tokens[0] == '▁':
                            #print('Label begins with unks')
                            to_replace = '^' + re.escape(span).replace(
                                re.escape(unk_token_to_string),
                                '.+?'
                            )
                            replaced_span = re.search(
                                to_replace,
                                label
                            )[0]
                            self.unknowns[-1][span.strip()] = replaced_span
                        # Last token of the label is UNK
                        elif idx == len(label_tokens) - 2 and label_tokens[-1] == self.tokenizer.eos_token:
                            #print('Label ends with unks')
                            pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])
                            pre_idx_unk_counts = pre_idx.count(unk_token_to_string)
                            to_replace = re.escape(span).replace(
                                re.escape(unk_token_to_string),
                                f'[^{re.escape(pre_idx)}]+?'
                            ) + '$'
                            if pre_idx.strip() == '':
                                to_replace = to_replace.replace('[^]', r'(?<=\s)[^a-zA-Z0-9]')
                            replaced_span = re.search(
                                to_replace,
                                label
                            )[0]
                            self.unknowns[-1][span.strip()] = replaced_span
                        # A token in the middle of the label is UNK
                        else:
                            #print('Label has unks in the middle')
                            pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])
                            to_replace = re.escape(span).replace(
                                re.escape(unk_token_to_string),
                                f'[^{re.escape(pre_idx)}]+?'
                            )
                            # If there is nothing before the unk (e.g. because the previous token is also an unk),
                            # to_replace would begin with an empty character class '[^]', which is invalid, so fall
                            # back to a boundary-anchored non-alphanumeric match instead.
                            if pre_idx.strip() == '':
                                to_replace = to_replace.replace('[^]', r'(?<=\s)[^a-zA-Z0-9]')
                            replaced_span = re.search(
                                to_replace,
                                label
                            )
                            if replaced_span:
                                span = re.sub(r'\s([?.!",](?:\s|$))', r'\1', span.strip())
                                self.unknowns[-1][span] = replaced_span[0]
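
    # Sketch of the resulting state (hypothetical values): after add_label_to_unk_replacer('Γαία') and
    # add_label_to_unk_replacer('Tokyo'), self.unknowns might look like [{'<unk>': 'Γαία'}, {}], i.e. one dict
    # per label, mapping each '<unk>'-bearing span (as rendered by the tokenizer) to the original surface text;
    # labels whose tokens are all known contribute an empty dict.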
    def replace_unks_on_sentence(self, sentence: str, loop_n: int = 3, empty_after: bool = False):
        # Loop through the stored labels in case they appear repeatedly, at most loop_n times (three by default)
        while '<unk>' in sentence and loop_n > 0:
            loop_n -= 1
            for unknowns in self.unknowns:
                for k, v in unknowns.items():
                    # Leave the all-unk labels to be replaced on the last pass
                    if k == '<unk>' and loop_n > 0:
                        continue
                    # In case the key is missing because the first letter of the sentence has been uppercased
                    if k not in sentence and k[0] == k[0].lower() and k[0].upper() == sentence[0]:
                        k = k[0].upper() + k[1:]
                        v = v[0].upper() + v[1:]
                    # In case the key is missing because it contains a double space where it should not
                    elif k not in sentence and len(re.findall(r'\s{2,}', k)) > 0:
                        k = re.sub(r'\s+', ' ', k)
                    #print(k, '/', v, '/', sentence)
                    sentence = sentence.replace(k.strip(), v.strip(), 1)
                    #sentence = re.sub(k, v, sentence)
        # Remove leftover double spaces
        sentence = re.sub(r'\s+', ' ', sentence).strip()
        # Remove spaces before punctuation
        sentence = re.sub(r'\s([?.!",](?:\s|$))', r'\1', sentence)
        if empty_after:
            self.unknowns = []
        return sentence

if __name__ == '__main__':
    verb_module = VerbModule()
    verbs = verb_module.verbalise('translate Graph to English: <H> World Trade Center <R> height <T> 200 meter <H> World Trade Center <R> is a <T> tower')
    print(verbs)
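
    # Further usage (illustrative, equivalent to the prompt string above): the same pair of triples expressed
    # as dicts and passed through verbalise_triples, which builds the linearised prompt itself.
    example_triples = [
        {'subject': 'World Trade Center', 'predicate': 'height', 'object': '200 meter'},
        {'subject': 'World Trade Center', 'predicate': 'is a', 'object': 'tower'},
    ]
    print(verb_module.verbalise_triples([example_triples]))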