import torch
import torch.nn as nn
from functools import partial, cache
from argparse import Namespace
from typing import List, Tuple, Dict, Union, Optional, Literal
from itertools import chain
import random
from models.graph_T5.graph_t5 import T5Config, T5PreTrainedModel, T5EncoderModel, T5Tokenizer
from models.graph_T5.graph_t5.modeling_t5 import T5Attention
import models.graph_T5.graph_t5.modeling_t5
class Graph():
"""
A graph class.
:param g: A list of tuples, where each tuple is a triple (head, r, tail).
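Example (minimal usage sketch):
>>> g = Graph([("dog", "IsA", "animal"), ("cat", "IsA", "animal")])
>>> g.concepts
['animal', 'cat', 'dog']
>>> g.relations
['IsA']
>>> g.num_triplets()
2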
"""
def __init__(
self,
g: Optional[List[Tuple[str, str, str]]] = None
):
self.g = g if g is not None else []  # avoid sharing a mutable default list between instances
self.concepts = self.get_concepts() # list of all concepts in the graph
self.relations = self.get_relations() # list of all relations in the graph
self.relations_multiple = self.get_relations_multiple() # list of all relations in the graph, including duplicate relations
@property
def g(self) -> List[Tuple[str,str,str]]:
return self._g
@g.setter
def g(self, g: List[Tuple[str,str,str]]):
self._g = g
def num_triplets(self) -> int:
"""
Get the number of triplets in the graph.
"""
return len(self.g)
def get_concepts(self) -> List[str]:
"""
Get the concepts in the graph.
"""
concepts = list(set([triplet[i] for triplet in self.g for i in [0, 2]]))
concepts.sort() # not necessary but makes debugging easier
return concepts
def get_relations(self) -> List[str]:
"""
Get the relations in the graph.
"""
relations = list(set(self.get_relations_multiple()))
relations.sort() # not necessary but makes debugging easier
return relations
def get_relations_multiple(self) -> List[str]:
"""
Get the relations in the graph, including duplicate relations.
"""
relations = [triplet[1] for triplet in self.g]
return relations
def __str__(self):
out_str = '\n'.join([str(triplet) for triplet in self.g])
return out_str
class Data(Namespace):
def __init__(self, **kwargs):
super().__init__()
self.__dict__.update(kwargs)
def get_dummy_graph(num_triplets:int=3) -> Graph:
g = [
("dog", "IsA", "animal"),
("cat", "IsA", "animal"),
("black poodle", "IsA", "dog"),
("black cat", "IsA", "cat"),
]
assert num_triplets <=4, "num_triplets must be <= 4"
g = g[:num_triplets]
g = Graph(g)
return g
def r2nl(r: str) -> str:
"""
Convert a relation or concept label to a natural language string. This is the hook for dataset-specific renaming; the default implementation returns the input unchanged.
"""
return r
def _get_str2tok(g:Graph, tokenizer: T5Tokenizer) -> dict[str, list[int]]:
"""
Get a dictionary that maps strings to tokens.
"""
# tokenize concepts and relations
c_tok = tokenizer([r2nl(c) for c in g.concepts], padding=False)['input_ids']
r_tok = tokenizer([r2nl(r) for r in g.relations], padding=False)['input_ids']
tokens = c_tok + r_tok
node_names = g.concepts + g.relations # these are not necessarily all nodes in the Levi Graph, as relations can occur more than once
assert len(tokens) == len(node_names), f"{len(tokens) = }, {len(node_names) = }"
# remove end-of-sequence token
tokens = [toks[:-1] if toks[-1] == tokenizer.eos_token_id else toks for toks in tokens]
# create a dictionary mapping concepts and relations to their tokenized forms
str2tok = {node: tok for node, tok in zip(node_names, tokens)}
str2tok['</s>'] = [tokenizer.eos_token_id]
return str2tok
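# Illustrative sketch of the mapping returned for get_dummy_graph(1) (actual token ids depend on the
# tokenizer; only '</s>' is guaranteed to map to the eos token id):
# {'animal': [...], 'dog': [...], 'IsA': [...], '</s>': [tokenizer.eos_token_id]}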
def _get_graphT5_input_sequence(g:Graph, str2tok:dict, use_eos:bool) -> Tuple[torch.Tensor, dict, torch.Tensor, dict]:
# get input sequence (i.e. sequence that will be fed into the model for this graph)
all_nodes = g.relations_multiple + g.concepts # list of all concepts and relations that will be in the final sequence (i.e. all nodes of the Levi Graph). The order is: first all relations (in the order in which they appear in g.g), then all concepts (in alphabetical order, though the order does not matter here)
if use_eos:
all_nodes.append('</s>')
all_tokens = [str2tok[node] for node in all_nodes] # list of length #nodes, where each element is a list of token ids
indices = {node: [] for node in all_nodes} # dictionary mapping each node to its start- and end-index in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have length 1 for concepts, and length equal to the number of occurrences of the relation in the graph for relations. # WARNING: this assumes that concepts and relations have different names. This is not always the case for REBEL. For concept_indices this is fixed.
num_relation_tokens = sum([len(token) for token in all_tokens[:len(g.relations_multiple)]]) # number of tokens that are relations
num_concept_tokens = sum([len(token) for token in all_tokens[len(g.relations_multiple):len(g.relations_multiple)+len(g.concepts)]]) # number of tokens that are concepts
num_eos_tokens = 1 if use_eos else 0
is_concept = torch.tensor([False] * num_relation_tokens + [True] * num_concept_tokens + [False] * num_eos_tokens, dtype=torch.bool) # tensor with one entry per token, where each element is True if the corresponding token belongs to a concept and False if it belongs to a relation (or the eos token)
index_counter = 0
assert len(all_nodes) == len(all_tokens), (all_nodes, all_tokens)
for node, token in zip(all_nodes, all_tokens):
indices[node].append((index_counter, index_counter + len(token)))
# assert is_concept[index_counter:index_counter+len(token)].all() == (node in g.concepts), f"{is_concept = }, {node = }, {g.concepts = }, {index_counter = }, {len(token) = }, {is_concept[index_counter:index_counter+len(token)] = }"
index_counter += len(token)
concept_indices = {node: [indices[node][-1]] for node in g.concepts} # take the last occurrence and wrap it in a list again, in case a relation shares its name with a concept (concepts are added last).
sequence = torch.tensor(list(chain.from_iterable(all_tokens)), dtype=torch.long)
sequence = sequence.unsqueeze(0) # add batch dimension
is_concept = is_concept.unsqueeze(0) # add batch dimension
return sequence, indices, is_concept, concept_indices
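# Sketch of the resulting layout for get_dummy_graph(1): the sequence is the concatenation of the
# "IsA" tokens (relations first, in triplet order), then the "animal" and "dog" tokens (concepts in
# alphabetical order), then optionally '</s>'; is_concept marks the concept positions, and indices
# maps each node name to its list of (start_index, end_index) spans in that sequence.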
def _get_graphT5_relativeposition_sparsitymask(g:Graph, indices:dict, sequence_length:int, use_eos:bool, eos:str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
### get relative position of each node in the sequence, as well as the sparsity mask ###
# initialize relative position matrix)
relative_position = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.long)
# initialize sparsity mask
sparsity_mask = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.bool)
# initialize use_additional_bucket
use_additional_bucket = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.bool)
# relative positions / sparsity within each node
for start, end in chain.from_iterable(indices.values()):
relative_position[start:end, start:end] = _get_relative_position(end-start)
sparsity_mask[start:end, start:end] = True
# relative position between nodes of the same triplet
relation_counter = {relation: 0 for relation in g.relations} # dictionary mapping each relation to the number of times it has already appeared in the graph
for triplet in g.g:
pos_h = indices[triplet[0]][0] # position of head; tuple (start_index, end_index)
pos_r = indices[triplet[1]][relation_counter[triplet[1]]] # position of relation; tuple (start_index, end_index)
pos_t = indices[triplet[2]][0] # position of tail; tuple (start_index, end_index)
l_h, l_r = pos_h[1] - pos_h[0], pos_r[1] - pos_r[0] # length (i.e. number of tokens) of head and relation
# iterate over all combinations of tokens in each triplet. This implementation is not very elegant, but it is sufficiently fast.
for ih, ph in enumerate(range(pos_h[0], pos_h[1])): # iterate over all head tokens
for ir, pr in enumerate(range(pos_r[0], pos_r[1])): # iterate over all relation tokens
relative_position[ph, pr] = l_h - ih + ir
relative_position[pr, ph] = - (l_h - ih + ir)
sparsity_mask[ph, pr] = True
sparsity_mask[pr, ph] = True
for it, pt in enumerate(range(pos_t[0], pos_t[1])): # iterate over all tail tokens
relative_position[ph, pt] = l_h - ih + l_r + it
relative_position[pt, ph] = - (l_h - ih + l_r + it)
sparsity_mask[ph, pt] = True
sparsity_mask[pt, ph] = True
for ir, pr in enumerate(range(pos_r[0], pos_r[1])): # iterate over all relation tokens
for it, pt in enumerate(range(pos_t[0], pos_t[1])): # iterate over all tail tokens
relative_position[pr, pt] = l_r - ir + it
relative_position[pt, pr] = - (l_r - ir + it)
sparsity_mask[pr, pt] = True
sparsity_mask[pt, pr] = True
relation_counter[triplet[1]] += 1 # the next occurrence of this relation will use its next token span
if use_eos:
assert len(indices['</s>']) == 1, f"{indices['</s>'] = } should have length 1"
pos_eos = indices['</s>'][0] # position of the eos token; tuple (start_index, end_index)
assert pos_eos[0] + 1 == pos_eos[1], pos_eos
pos_eos = pos_eos[0] # position of eos token
if eos == 'bidirectional':
relative_position[:, pos_eos] = +1e6
relative_position[pos_eos, :] = -1e6
relative_position[pos_eos, pos_eos] = 0
sparsity_mask[:, pos_eos] = True
sparsity_mask[pos_eos, :] = True
elif eos == 'unidirectional':
relative_position[:, pos_eos] = 1e6
relative_position[pos_eos, pos_eos] = 0
sparsity_mask[pos_eos, :] = False # no messages from eos to other tokens
sparsity_mask[:, pos_eos] = True
else:
raise ValueError(f'{eos = } is not a valid option.')
relative_position = relative_position.unsqueeze(0) # add batch dimension
sparsity_mask = sparsity_mask.unsqueeze(0) # add batch dimension
use_additional_bucket = use_additional_bucket.unsqueeze(0) # add batch dimension
return relative_position, sparsity_mask, use_additional_bucket
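# Worked example (assuming each node is a single token, for clarity): for a triplet (h, r, t) at
# positions ph, pr, pt, the loops above set relative_position[ph, pr] = 1, [ph, pt] = 2 and
# [pr, pt] = 1 (with negated values for the transposed entries), i.e. the triplet is treated like
# the contiguous sequence "h r t"; all token pairs outside a node or triplet stay masked out.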
def _get_global_graphT5_relativeposition_sparsitymask(g:Graph, indices:dict, sequence_length:int, use_eos:bool, eos:str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
### get relative position of each node in the sequence, as well as the sparsity mask ###
# initialize relative position matrix)
# relative_position = torch.ones(size=(sequence_length, sequence_length), dtype=torch.long) * 1e6 # technically should be float('inf'), but it does not matter
relative_position = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.long)
# initialize sparsity mask
sparsity_mask = torch.ones(size=(sequence_length, sequence_length), dtype=torch.bool) # could switch to None, but then code has to be updated accordingly (in particular get_batch)
# initialize use_additional_bucket
use_additional_bucket = torch.ones(size=(sequence_length, sequence_length), dtype=torch.bool)
# relative positions / sparsity within each node
for start, end in chain.from_iterable(indices.values()):
relative_position[start:end, start:end] = _get_relative_position(end-start)
use_additional_bucket[start:end, start:end] = False
# relative position between nodes of the same triplet
relation_counter = {relation: 0 for relation in g.relations} # dictionary mapping each relation to the number of times it has already appeared in the graph
for triplet in g.g:
pos_h = indices[triplet[0]][0] # position of head; tuple (start_index, end_index)
pos_r = indices[triplet[1]][relation_counter[triplet[1]]] # position of relation; tuple (start_index, end_index)
pos_t = indices[triplet[2]][0] # position of tail; tuple (start_index, end_index)
l_h, l_r = pos_h[1] - pos_h[0], pos_r[1] - pos_r[0] # length (i.e. number of tokens) of head and relation
# iterate over all combinations of tokens in each triplet. This implementation is not very elegant, but it works.
for ih, ph in enumerate(range(pos_h[0], pos_h[1])): # iterate over all head tokens
for ir, pr in enumerate(range(pos_r[0], pos_r[1])): # iterate over all relation tokens
relative_position[ph, pr] = l_h - ih + ir
relative_position[pr, ph] = - (l_h - ih + ir)
use_additional_bucket[ph, pr] = False
use_additional_bucket[pr, ph] = False
for it, pt in enumerate(range(pos_t[0], pos_t[1])): # iterate over all tail tokens
relative_position[ph, pt] = l_h - ih + l_r + it
relative_position[pt, ph] = - (l_h - ih + l_r + it)
use_additional_bucket[ph, pt] = False
use_additional_bucket[pt, ph] = False
for ir, pr in enumerate(range(pos_r[0], pos_r[1])): # iterate over all relation tokens
for it, pt in enumerate(range(pos_t[0], pos_t[1])): # iterate over all tail tokens
relative_position[pr, pt] = l_r - ir + it
relative_position[pt, pr] = - (l_r - ir + it)
use_additional_bucket[pr, pt] = False
use_additional_bucket[pt, pr] = False
relation_counter[triplet[1]] += 1 # the next occurrence of this relation will use its next token span
if use_eos:
assert len(indices['</s>']) == 1, f"{indices['</s>'] = } should have length 1"
pos_eos = indices['</s>'][0] # position of the eos token; tuple (start_index, end_index)
assert pos_eos[0] + 1 == pos_eos[1], pos_eos
pos_eos = pos_eos[0] # position of eos token
if eos == 'bidirectional':
relative_position[:, pos_eos] = +1e6
relative_position[pos_eos, :] = -1e6
relative_position[pos_eos, pos_eos] = 0
sparsity_mask[:, pos_eos] = True
sparsity_mask[pos_eos, :] = True
use_additional_bucket[:, pos_eos] = False
use_additional_bucket[pos_eos, :] = False
elif eos == 'unidirectional':
relative_position[:, pos_eos] = 1e6
relative_position[pos_eos, pos_eos] = 0
sparsity_mask[pos_eos, :] = False # no messages from eos to other tokens
sparsity_mask[:, pos_eos] = True
use_additional_bucket[:, pos_eos] = False
use_additional_bucket[pos_eos, :] = False
else:
raise ValueError(f'{eos = } is not a valid option.')
relative_position = relative_position.unsqueeze(0) # add batch dimension
sparsity_mask = sparsity_mask.unsqueeze(0) # add batch dimension
use_additional_bucket = use_additional_bucket.unsqueeze(0) # add batch dimension
return relative_position, sparsity_mask, use_additional_bucket
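# In contrast to the local variant above, the global variant keeps sparsity_mask fully True and
# instead marks token pairs that do not belong to the same node or triplet via use_additional_bucket,
# so their attention is routed to the additional relative-position bucket for long-ranged
# graph-to-graph connections.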
def graph_to_graphT5(g:Graph, tokenizer:T5Tokenizer, how:str, eos:str)->Data:
"""
Convert a graph to a graphT5 input.
:param g: graph
:param tokenizer: tokenizer
:param how: how to represent the graph. Can be 'local' or 'global' for lGLM and gGLM respectively.
:param eos: end-of-sequence token. Can be `False` for not using an eos token. When using an eos token, there are two ways to use it: `bidirectional` means that the eos token is connected to every other node in the graph, with a relative position of positive infinity (from node to eos) or negative infinity (from eos to node). `unidirectional` means that the eos token is connected to every node in the graph with a relative position of positive infinity (from node to eos), but not the other way around (i.e. no connection from eos to other nodes). This means that nodes do not receive messages from the eos token, which preserves locality when using the local GLM.
"""
if not isinstance(g, Graph):
g = Graph(g)
eos = str(eos)
assert eos in ['False', 'bidirectional', 'unidirectional'], f"{eos = } must be either 'False', 'bidirectional', or 'unidirectional'"
use_eos:bool = eos != 'False'
str2tok = _get_str2tok(g, tokenizer) # get a dictionary mapping concepts and relations to their tokenized forms
sequence, indices, is_concept, concept_indices = _get_graphT5_input_sequence(g, str2tok, use_eos) # get input sequence (i.e. the sequence that will be fed into the model for this graph)
sequence_length = sequence.shape[1]
if how == 'local':
relative_position, sparsity_mask, use_additional_bucket = _get_graphT5_relativeposition_sparsitymask(g, indices, sequence_length, use_eos, eos)
num_additional_buckets = 0 # lGLM does not use additional buckets
elif how == 'global':
relative_position, sparsity_mask, use_additional_bucket = _get_global_graphT5_relativeposition_sparsitymask(g, indices, sequence_length, use_eos, eos)
num_additional_buckets = 1 # gGLM uses 1 additional bucket for long-ranged G2G connections
else:
raise ValueError(f"how must be either 'local' or 'global', but is {how}")
input_ids = sequence
data = Data(input_ids=input_ids, relative_position=relative_position, sparsity_mask=sparsity_mask, use_additional_bucket=use_additional_bucket, indices=indices, is_concept=is_concept, concept_indices=concept_indices, num_additional_buckets=num_additional_buckets)
return data
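# Minimal usage sketch (assumes a T5 tokenizer is available, e.g. tokenizer = T5Tokenizer.from_pretrained('t5-small')):
# data = graph_to_graphT5(get_dummy_graph(), tokenizer, how='global', eos='False')
# data.input_ids.shape          # (1, sequence_length)
# data.relative_position.shape  # (1, sequence_length, sequence_length)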
@cache
def _get_relative_position(size):
return torch.tensor([[i - j for i in range(size)] for j in range(size)], dtype=torch.long)
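# Example: _get_relative_position(3) returns
# tensor([[ 0,  1,  2],
#         [-1,  0,  1],
#         [-2, -1,  0]])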
def get_embedding(
sequence_embedding: torch.Tensor,
indices: Dict[str, List[Tuple[int, int]]],
concept: str,
embedding_aggregation: str = "mean",
):
"""
Returns the embedding of a concept.
:param sequence_embedding: the embedding of the whole sequence. shape: (sequence_length, embedding_size)
:param indices: dictionary mapping each node to its start- and end-index in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have a length of 1 for concepts.
:param concept: the concept for which the embedding should be returned
:param embedding_aggregation: how the embedding of a concept should be aggregated. Either "mean" or "seq". "mean" returns the mean over all tokens of the concept. "seq" returns the embeddings of all tokens of the concept.
:return: the aggregated embedding of the concept. shape (1, embedding_size) or (number_of_tokens, embedding_size).
"""
assert concept in indices.keys(), f"{concept = } is not a node in the graph. {indices = }"
assert len(indices[concept]) == 1, f"{concept = } is not a concept, as concepts occur only once in the graph. {indices = }"
start, end = indices[concept][0]
sequence_embedding = sequence_embedding[start:end, :]
if embedding_aggregation == "mean":
return torch.mean(sequence_embedding, dim=0, keepdim=True)
elif embedding_aggregation == "seq":
return sequence_embedding
else:
raise NotImplementedError(f"{embedding_aggregation = } is not supported. Use either 'mean' or 'seq'.")
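# Illustrative sketch: given an encoder output emb of shape (sequence_length, embedding_size) and the
# indices from graph_to_graphT5, get_embedding(emb, data.indices, 'dog') returns the (1, embedding_size)
# mean-pooled embedding of the tokens belonging to the concept 'dog'.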
def add_text_to_graph_data(data, text, tokenizer, use_text):
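"""
Append the tokenized `text` to a graph Data object in place: input_ids is extended and an is_graph mask is added;
for graph transformers, relative_position, sparsity_mask, use_additional_bucket and num_additional_buckets are
extended as well. If use_text is 'False', '', False or None, the data is left unchanged. Returns None.
"""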
if use_text in {'False', '', False, None}:
return None
text_seq = torch.tensor(tokenizer(text, padding=False)['input_ids']).unsqueeze(0)
new_input_ids = torch.cat([data.input_ids, text_seq], dim=1)
old_seq_len = data.input_ids.shape[1]
text_seq_len = text_seq.shape[1]
new_seq_len = new_input_ids.shape[1]
new_is_graph = torch.zeros(size=(1, new_seq_len), dtype=torch.bool)
new_is_graph[:, :old_seq_len] = True
if data.relative_position is None: # sequence transformer
assert data.sparsity_mask is None
assert data.use_additional_bucket is None
data.input_ids = new_input_ids
data.is_graph = new_is_graph
return None
new_relative_position = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.relative_position.dtype)
new_relative_position[:, :old_seq_len, :old_seq_len] = data.relative_position
new_relative_position[:, old_seq_len:, old_seq_len:] = _get_relative_position(text_seq_len)
new_sparsity_mask = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.sparsity_mask.dtype)
new_sparsity_mask[:, :old_seq_len, :old_seq_len] = data.sparsity_mask
new_sparsity_mask[:, old_seq_len:, old_seq_len:] = True
new_use_additional_bucket = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.use_additional_bucket.dtype)
new_use_additional_bucket[:, :old_seq_len, :old_seq_len] = data.use_additional_bucket
new_use_additional_bucket[:, old_seq_len:, old_seq_len:] = False # could change that if we want T2T and local G2G relations to be learned separately
if use_text in {'FullyConnected', True}:
new_sparsity_mask[:, old_seq_len:, :old_seq_len] = True
new_sparsity_mask[:, :old_seq_len, old_seq_len:] = True
new_use_additional_bucket[:, old_seq_len:, :old_seq_len] = True
new_use_additional_bucket[:, :old_seq_len, old_seq_len:] = True
new_relative_position[:, old_seq_len:, :old_seq_len] = data.num_additional_buckets
new_relative_position[:, :old_seq_len, old_seq_len:] = data.num_additional_buckets + 1
new_num_additional_buckets = data.num_additional_buckets + 2
else:
raise ValueError(f"unknown use_text {use_text} (type {type(use_text)})")
data.input_ids = new_input_ids
data.relative_position = new_relative_position
data.sparsity_mask = new_sparsity_mask
data.use_additional_bucket = new_use_additional_bucket
data.num_additional_buckets = new_num_additional_buckets
data.is_graph = new_is_graph
return None
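# Sketch of the resulting joint layout: graph tokens keep their original relative positions and masks,
# the appended text tokens attend to each other with ordinary sequential positions, and (for
# use_text='FullyConnected' or True) two extra relative-position buckets handle text-to-graph and
# graph-to-text attention.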
class DataProcessor():
@staticmethod
def encode_graph(tokenizer, g:Union[Graph,list[tuple[str,str,str]]], text:Optional[str]=None, how:Literal['global', 'local']='global', eos:str="False")->Data:
"""
convert graph to suitable input for the model.
:param tokenizer: tokenizer
:param g: graph
:param text: text to add to the graph. Can be None if no text should be added.
:param how: how to represent the graph. Can be 'local' or 'global' for lGLM and gGLM respectively.
:param eos: end-of-sequence token. Can be `False` for not using an eos token. This is the method used in the paper. When using an eos token, there are two ways to use it: `bidirectional` means that the eos token is connected to every other node in the graph. `unidirectional` means that the eos token is connected to every node in the graph (from node to eos), but not the other way around (i.e. no connection from eos to other nodes). This means that nodes do not receive messages from the eos token, which preserves locality when using the local GLM.
:return: Data object
"""
if not isinstance(g, Graph):
g = Graph(g)
data = graph_to_graphT5(g, tokenizer, how, eos)
if text is not None:
add_text_to_graph_data(data, text, tokenizer, use_text=True)
return data
@staticmethod
def to_batch(data_instances:list[Data], tokenizer, max_seq_len:Optional[int]=None, device:str='cpu', **kwargs)->dict:
"""
converts list of data instances to batched inputs for GLM forward call.
:param data_instances: list of Data instances
:param tokenizer: tokenizer (used for the pad token id)
:param max_seq_len: maximum sequence length
:param device: device
:return: dictionary with keys 'input_ids', 'relative_position', 'sparsity_mask', and 'use_additional_bucket'
"""
current_max_seq_len = max([data.input_ids.shape[1] for data in data_instances])
if max_seq_len is None:
max_seq_len = current_max_seq_len
else:
max_seq_len = min(max_seq_len, current_max_seq_len)
if data_instances[0].relative_position is None:
assert data_instances[0].sparsity_mask is None
assert data_instances[0].use_additional_bucket is None
is_sequence_transformer = True
else:
assert data_instances[0].sparsity_mask is not None
assert data_instances[0].use_additional_bucket is not None
is_sequence_transformer = False
# initialize tensors
input_ids = torch.ones((len(data_instances), max_seq_len), dtype=torch.long, device=device) * tokenizer.pad_token_id
if is_sequence_transformer:
relative_position = None
sparsity_mask = None
use_additional_bucket = None
else:
relative_position = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.long, device=device)
sparsity_mask = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
use_additional_bucket = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
# fill tensors
for i, data in enumerate(data_instances):
instance_len = min(data.input_ids.shape[1], max_seq_len)
input_ids[i, :instance_len] = data.input_ids[:, :instance_len]
if not is_sequence_transformer:
relative_position[i, :instance_len, :instance_len] = data.relative_position[:, :instance_len, :instance_len]
sparsity_mask[i, :instance_len, :instance_len] = data.sparsity_mask[:, :instance_len, :instance_len]
use_additional_bucket[i, :instance_len, :instance_len] = data.use_additional_bucket[:, :instance_len, :instance_len]
model_input = {
'input_ids': input_ids,
'relative_position': relative_position,
'sparsity_mask': sparsity_mask,
'use_additional_bucket': use_additional_bucket,
**kwargs
}
return model_input
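# End-to-end sketch (assumes a T5 tokenizer; the model name is illustrative):
# tokenizer = T5Tokenizer.from_pretrained('t5-small')
# data = DataProcessor.encode_graph(tokenizer, [('dog', 'IsA', 'animal')], how='global')
# batch = DataProcessor.to_batch([data], tokenizer)
# batch['input_ids'].shape  # (1, sequence_length)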
@staticmethod
def get_embedding(sequence_embedding:torch.Tensor, indices:Dict[str,List[Tuple[int, int]]], concept:str, embedding_aggregation:str="mean"):
"""
Returns embedding of a concept.
:param sequence_embedding: the embedding of the whole sequence. shape: (sequence_length, embedding_size)
:param indices: dictionary mapping each node to its start- and end-index in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have a length of 1 for concepts. indices is part of the Data object.
:param concept: the concept for which the embedding should be returned.
:param embedding_aggregation: how the embedding of a concept should be aggregated. Either "mean" or "seq". "mean" returns the mean over all tokens of the concept. "seq" returns the embeddings of all tokens of the concept.
:return: the aggregated embedding of the concept. shape (1, embedding_size) or (number_of_tokens, embedding_size).
"""
return get_embedding(sequence_embedding, indices, concept, embedding_aggregation)