Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
"""Implements tracking of constraints for a beam item.

A list of constraints is given as a list of one or more token
sequences, each of length at least one token. For example, for the input sentence

> Die maschinelle Übersetzung ist schwer zu kontrollieren.

we could have the constraints:
* to influence
* hard

There are two implementations:
* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints.
* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints.

The difference is that in the first, the constraints are assumed to be
in order; the algorithm will permit zero or more tokens between them.
In the second, the constraints are not ordered, so many orderings will
be explored.

The same sequence can be present any number of times, and will appear
that many times in the output.
"""
from collections import Counter | |
from typing import List, Optional, Set, Tuple | |
import torch | |
class ConstraintState:
    """Common base class for the per-hypothesis constraint trackers.

    It holds no state of its own; the concrete subclasses define all
    behavior.
    """

    def __init__(self):
        # Nothing to set up at the base level.
        super().__init__()
def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor:
    """Takes a list of list of constraints in tensor form (a list of
    tensor constraints for each sentence) and transforms it into a
    packed Tensor. For example, here is a batch of size 3 with 3, 0,
    and 1 constraints:

        [ [ [3 1 2], [3], [4 5 6 7], ]
          [],
          [ [1 8 9 10 1 4 11 12], ]
        ]

    Its corresponding packed structure is:

        [ [ 3  3  1  2  0  3  0  4  5  6  7  0],
          [ 0  0  0  0  0  0  0  0  0  0  0  0],
          [ 1  1  8  9 10  1  4 11 12  0  0  0] ]

    The packed tensor has shape (batch size, maxlen), where
    maxlen is defined below. Each row contains concatenated
    constraint tokens for that sentence, with 0 appended after
    each constraint. The first item in each row is the number
    of constraints for that sentence. So maxlen is the maximum
    of

        (number of constraints) + (sum length of constraints) + 1

    across all sentences in the batch.

    Because 0 is the terminator, constraints must not themselves contain
    token id 0 (``unpack_constraints`` splits rows on the first 0).

    Args:
        batch_constraints: one list per sentence; each constraint is a 1-D
            sequence of token ids (a tensor, or anything ``torch.as_tensor``
            accepts, such as a list of ints).

    Returns:
        A ``torch.long`` tensor of shape (batch size, maxlen). maxlen is at
        least 1, even for an empty batch or a batch without constraints.
    """
    # Per sentence: 1 slot for the count, plus each constraint's tokens and
    # its 0 terminator.
    max_constraints_len = max(
        (
            1 + sum(len(constraint) + 1 for constraint in sentence_constraints)
            for sentence_constraints in batch_constraints
        ),
        default=1,
    )

    batch_size = len(batch_constraints)
    constraints_tensor = torch.zeros(
        (batch_size, max_constraints_len), dtype=torch.long
    )
    for i, sentence_constraints in enumerate(batch_constraints):
        constraints_tensor[i, 0] = len(sentence_constraints)
        offset = 1
        for constraint in sentence_constraints:
            this_len = len(constraint)
            constraints_tensor[i, offset : offset + this_len] = torch.as_tensor(
                constraint
            )
            # Skip over the 0 terminator, which is already present from zeros().
            offset += this_len + 1
    return constraints_tensor
def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    Split *one row* of a packed constraint tensor (the constraints of a
    single sentence, in the layout written by ``pack_constraints``) back
    into a list of constraint tensors.

    Element 0 of the row is the constraint count; each constraint follows
    as a run of tokens closed by a 0. The returned tensors are slices of
    ``constraint_tensor``.
    """
    count = constraint_tensor[0]
    flat = constraint_tensor.tolist()

    pieces = []
    start = 1
    for _ in range(count):
        # The first 0 at or after ``start`` closes the current constraint.
        stop = flat.index(0, start)
        pieces.append(constraint_tensor[start:stop])
        start = stop + 1
    return pieces
class ConstraintNode:
    """
    Represents a node in a trie managing unordered constraints.

    Each non-root node holds one token. A path from the root down to a node
    with ``terminal > 0`` spells out one (or, if repeated, several identical)
    constraint(s).
    """

    def __init__(
        self, token: Optional[int] = None, parent: Optional["ConstraintNode"] = None
    ):
        # The token associated with this node (None for the root)
        self.token = int(token) if token is not None else None
        # The parent (None at the root)
        self.parent = parent
        # Number of constraints that end exactly at this node (0 = none).
        # A count rather than a flag, because the same sequence may be added
        # more than once.
        self.terminal = 0
        # Child nodes, keyed by token id
        self.children = {}
        # The cumulative number of constraints from this point in the
        # trie forward
        self.num_constraints = 0

    def id(self):
        """Return the token stored at this node (None for the root)."""
        return self.token

    def __str__(self):
        term = self.terminal != 0
        return f"[{self.token}].{term}#{self.num_constraints}"

    def __getitem__(self, key: int):
        """Return the child reached by token ``key``, or None if absent."""
        return self.children.get(key, None)

    def next_tokens(self) -> Set[int]:
        """The set of child labels."""
        return set(self.children.keys())

    @staticmethod
    def create(constraints: List[List[int]]) -> "ConstraintNode":
        """Build a trie from a list of token sequences and return its root."""
        root = ConstraintNode()
        for sequence in constraints:
            root.add_sequence(sequence)
        return root

    @staticmethod
    def print_graph(node: "ConstraintNode") -> str:
        """Render the subtree under ``node`` as a nested, parenthesized string."""
        if len(node.children) == 0:
            return str(node)
        else:
            s = f"({node}"
            for child in node.children.values():
                s += " " + ConstraintNode.print_graph(child)
            s += ")"
            return s

    def token_counts(self) -> Counter:
        """Returns a counter of the number of times each token is used
        in a constraint.

        For every node below this one, that node's ``num_constraints`` is
        added to the count for its token.
        """
        token_counts = Counter()
        kids = list(self.children.values())
        while kids:
            kid = kids.pop()
            # Key on the token value. ``kid.id`` without a call is a bound
            # method, which would make every node its own key.
            token_counts[kid.token] += kid.num_constraints
            kids.extend(kid.children.values())
        return token_counts

    def tokens(self) -> Set[int]:
        """Returns the set of tokens in constraints."""
        return set(self.token_counts().keys())

    def add_sequence(self, sequence: List[int]):
        """Adds a constraint, represented as a list of integers, to
        the trie.

        The last node of the path gets ``terminal`` incremented. It and all
        of its ancestors (including this node) get ``num_constraints``
        incremented.
        """
        assert len(sequence) > 0
        token = int(sequence[0])
        if token not in self.children:
            self.children[token] = ConstraintNode(token, parent=self)

        node = self.children[token]
        if len(sequence) == 1:
            node.terminal += 1
            node.num_constraints += 1
            parent = node.parent
            while parent is not None:
                parent.num_constraints += 1
                parent = parent.parent
        else:
            node.add_sequence(sequence[1:])
class UnorderedConstraintState(ConstraintState):
    """
    Records progress through the set of constraints for each item in the beam
    using a trie.

    ``advance`` never mutates ``self``; it returns a new state with copied
    counters.
    """

    def __init__(
        self, node: ConstraintNode, copy_from: Optional["ConstraintState"] = None
    ):
        super().__init__()
        # The trie node this state currently sits on
        self.node = node

        if copy_from is None:
            # A fresh state: ``node`` is the root of the constraint trie.
            self.root = node
            # How many times each terminal node's constraint has been completed
            self.completed = Counter()
            # Per-node visit counts for the tokens consumed so far. Visits on
            # an abandoned path are subtracted again when ``advance`` rewinds.
            self.generated = Counter()
            # The set of tokens that appear in any constraint
            self.needed_tokens = self.root.tokens()
        else:
            # Copy the counters so the source state stays untouched
            self.completed = Counter(copy_from.completed)
            self.generated = Counter(copy_from.generated)
            self.root = copy_from.root
            # Share the (never mutated) token set so every state has the attribute
            self.needed_tokens = copy_from.needed_tokens

        # Mark the node as generated
        if self.node != self.root:
            self.generated[node] += 1

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        """Build the root state from one row of a packed constraint tensor."""
        constraint_list = unpack_constraints(constraint_tensor)
        constraint_trie_root = ConstraintNode.create(constraint_list)
        return UnorderedConstraintState(constraint_trie_root)

    def __str__(self):
        gen_str = ",".join([str(node) for node in self.generated])
        return f"{self.name()}/{self.bank()}({gen_str})x{self.num_completed()}"

    def __copy__(self):
        copied_state = UnorderedConstraintState(self.node, copy_from=self)
        return copied_state

    def copy(self):
        return self.__copy__()

    def name(self):
        """"ROOT" at the trie root, otherwise the current node's token as a string."""
        # Read the token directly; comparing the ``id`` method object to
        # None is never true.
        if self.node.token is None:
            return "ROOT"
        else:
            return str(self.node.token)

    def is_root(self):
        return self.node == self.root

    def bank(self):
        """Total of all per-node visit counts in ``generated``."""
        return sum(self.generated.values())

    def num_completed(self):
        """The number of constraints (not constraint tokens) that are completed.

        In addition to the already-completed states, we need to account for the
        current state, which might get marked as completed when another token
        is generated.
        """
        in_final = self.node.terminal and self.completed[self.node] < self.node.terminal
        return sum(self.completed.values()) + int(in_final)

    def finished(self):
        """True once every constraint in the trie counts as completed."""
        return self.root.num_constraints - self.num_completed() == 0

    def token_counts(self):
        return self.root.token_counts()

    def tokens(self):
        return self.root.tokens()

    def num_constraint_tokens(self):
        """Total number of constraint tokens, counting repeats."""
        return sum(self.token_counts().values())

    def next_tokens(self) -> Set[int]:
        """Returns the set of tokens that could come next.

        These are (a) all tokens extending the root state and, for
        non-root states, additionally all tokens extending the current
        state.
        """
        if self.node != self.root:
            return self.root.next_tokens().union(self.node.next_tokens())
        else:
            return self.root.next_tokens()

    def advance(self, token: int):
        """Reads in a token and advances the state. Here's how it works.

        We can advance to the next state if:
        - there is a matching child
        - its path isn't blocked

        A path is blocked when all constraints that are descendants of
        that node have already been generated, in the current state.

        If we are not able to advance from the current state, we "fall
        off the graph" and return to the root state. There, we again
        try to advance, checking the same criteria.

        In any case, when falling off the graph, we need to do some
        bookkeeping. We:
        - check whether any constraints were met (all prefixes of
          current state)
        - if one is found, mark it as completed
        - adjust visited nodes accordingly
        """
        token = int(token)

        next_state = None
        child = self.node[token]
        if child is not None and self.generated[child] < child.num_constraints:
            next_state = UnorderedConstraintState(child, copy_from=self)

        def rewind():
            """If we're mid-trie and an "illegal" token is chosen next, we need
            to reset our state to the root state. However, along the way, we need
            to check whether a prefix of the current trie state represents a state
            we could mark as completed.

            Reads ``self`` (the old state) and mutates ``next_state``.
            """
            node = self.node
            while node != self.root:
                if node.terminal and self.completed[node] < node.terminal:
                    next_state.completed[node] += 1
                    return

                next_state.generated[node] -= 1
                node = node.parent

        # Fall off the graph, check the root
        if next_state is None and token in self.root.next_tokens():
            child = self.root[token]
            # We can only traverse this edge if it's not saturated.
            # NOTE(review): saturation is checked against ``self.generated``
            # *before* rewind() releases the nodes on the abandoned path. So a
            # node occupied only by that path cannot be re-entered: with the
            # constraint [5, 6], the input 5 5 6 ends back at the root. Confirm
            # this is intended.
            if self.generated[child] < child.num_constraints:
                next_state = UnorderedConstraintState(child, copy_from=self)
            else:
                next_state = UnorderedConstraintState(self.root, copy_from=self)

            # Rewind
            rewind()

        elif next_state is None:
            next_state = UnorderedConstraintState(self.root, copy_from=self)
            # Rewind
            rewind()

        return next_state
class ConstraintSequence:
    def __init__(self, sequences: List[List[int]]):
        """Represents a set of possibly multitoken constraints by
        concatenating them and internally recording the end points.

        Attributes set:
            sequences: all constraint tokens, concatenated, as ints.
            endpoints: one bool per entry of ``sequences``; True where a
                constraint ends.
            num_tokens: ``len(sequences)``.
            tokens: the set of distinct token ids.

        Tokens are converted with ``int()`` so that tensor inputs (e.g. the
        output of ``unpack_constraints``) are stored, hashed and compared by
        value. Empty constraints are skipped; otherwise they would add an
        endpoint with no token and misalign ``endpoints`` with ``sequences``.
        """
        self.sequences = []
        self.endpoints = []
        self.num_tokens = 0
        self.tokens = set()
        for sequence in sequences:
            token_ids = [int(token) for token in sequence]
            if not token_ids:
                continue
            self.tokens.update(token_ids)
            self.num_tokens += len(token_ids)
            self.endpoints += [False] * (len(token_ids) - 1) + [True]
            self.sequences += token_ids

    def __getitem__(self, key: int):
        return self.sequences[key]

    def __len__(self):
        return len(self.sequences)

    def __str__(self):
        return str(self.sequences)

    def token_counts(self) -> Counter:
        """Number of occurrences of each token across all constraints."""
        return Counter(self.sequences)
class OrderedConstraintState(ConstraintState):
    """
    Records progress through the set of linear nonbranching constraints with gaps.

    ``state`` is the index into the concatenated ``sequence`` of the last
    constraint token matched; -1 means nothing has been matched (the root).
    """

    def __init__(self, sequence: ConstraintSequence, state: int = -1):
        super().__init__()
        self.sequence = sequence
        self.state = state

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        """Build the root state from one row of a packed constraint tensor."""
        constraint_list = unpack_constraints(constraint_tensor)
        return OrderedConstraintState(ConstraintSequence(constraint_list), -1)

    def __str__(self):
        return f"{self.state}/{self.bank()}x{self.num_completed()}"

    def __copy__(self):
        return OrderedConstraintState(self.sequence, self.state)

    def copy(self):
        return self.__copy__()

    def num_completed(self):
        """Number of constraint end points at or before the current position."""
        if self.state == -1:
            return 0
        return sum(
            1 for is_end in self.sequence.endpoints[: self.state + 1] if is_end
        )

    def is_root(self):
        return self.state == -1

    def name(self):
        if self.state == -1:
            return "ROOT"
        else:
            return str(self.sequence[self.state])

    def bank(self) -> int:
        """Number of constraint tokens matched so far."""
        return self.state + 1

    def finished(self):
        """True once the last token of the last constraint has been matched."""
        return self.state + 1 == len(self.sequence)

    def token_counts(self):
        """Occurrences of each token across all constraints."""
        # Counted here directly; int() normalizes tensor-valued tokens.
        return Counter(int(token) for token in self.sequence.sequences)

    def tokens(self):
        return self.sequence.tokens

    def num_constraint_tokens(self):
        """Total number of constraint tokens, counting repeats."""
        return sum(self.token_counts().values())

    def next_tokens(self) -> Set[int]:
        """Returns the set of tokens that could come next.

        This is the next constraint token (unless all constraints are
        finished). Once past the first position (state > 0), the first
        token of the whole sequence is also included.
        """
        tokens = set()
        if self.state > 0:
            tokens.add(self.sequence[0])
        # finished is a method: the bare attribute is always truthy.
        if not self.finished():
            tokens.add(self.sequence[self.state + 1])
        return tokens

    def advance(self, token: int):
        """Reads in a token and returns the resulting state.

        Rules, in priority order:
        - all constraints finished: accept anything, state unchanged;
        - token equals the next constraint token: move forward one position;
        - current position ends a constraint (we are between constraints):
          accept anything, state unchanged (gaps are allowed);
        - mid-constraint mismatch: restart at position 0 if the token is the
          first constraint token, otherwise go back to the root (-1).

        NOTE(review): the mid-constraint fallback also discards earlier
        constraints that were already completed. Confirm this is intended.
        """
        token = int(token)
        if self.finished():
            # Accept anything
            next_state = self.copy()
        elif self.sequence[self.state + 1] == token:
            # Advance to the next token
            next_state = OrderedConstraintState(self.sequence, self.state + 1)
        elif self.state >= 0 and self.sequence.endpoints[self.state]:
            # Accept anything between constraints (*). The state >= 0 guard
            # avoids reading endpoints[-1] at the root. The root case still
            # ends at -1 via the final branch.
            next_state = self.copy()
        elif token == self.sequence[0]:
            # Start over having generated the first token
            next_state = OrderedConstraintState(self.sequence, 0)
        else:
            # Start over from the root
            next_state = OrderedConstraintState(self.sequence, -1)
        return next_state