from typing import List, Iterable, Tuple
from functools import partial
import numpy as np
import torch
import json

from utils.token_processing import fix_byte_spaces
from utils.gen_utils import map_nlist


def round_return_value(attentions, ndigits=5):
    """Rounding must happen right before it's passed back to the frontend because there is a little numerical error that's introduced converting back to lists
    
    attentions: {
        'aa': {
            left
            right
            att
        }
    }
    
    """
    rounder = partial(round, ndigits=ndigits)
    nested_rounder = partial(map_nlist, rounder)
    new_out = attentions  # Modify in place (no copy) to save memory
    new_out["aa"]["att"] = nested_rounder(attentions["aa"]["att"])

    return new_out
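# A minimal sketch of the rounding step, assuming `map_nlist` applies the given
# function to every element of a nested list (values are illustrative):
#
#   payload = {"aa": {"left": [...], "right": [...],
#                     "att": [[[0.1234567, 0.8765433]]]}}
#   round_return_value(payload)
#   # payload["aa"]["att"] is now [[[0.12346, 0.87654]]]; "left"/"right" are untouched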

def flatten_batch(x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
    """Remove the batch dimension of every tensor inside the Iterable container `x`"""
    return tuple([x_.squeeze(0) for x_ in x])

def squeeze_contexts(x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
    """Combine the last two dimensions of the context."""
    shape = x[0].shape
    new_shape = shape[:-2] + (-1,)
    return tuple([x_.view(new_shape) for x_ in x])

def add_blank(xs: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
    """Prepend a blank (zero) tensor so the embeddings have n_layers + 1 entries;
    the extra entry corresponds to the final output embedding."""

    return (torch.zeros_like(xs[0]),) + xs
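# Shape sketch for the three helpers above (sizes are hypothetical):
#
#   att  = tuple(torch.rand(1, 12, 8, 8) for _ in range(12))  # (batch, head, seq, seq) per layer
#   att  = flatten_batch(att)        # each element -> (12, 8, 8)
#
#   ctxs = tuple(torch.rand(8, 12, 64) for _ in range(12))    # (seq, head, d_head) per layer
#   ctxs = squeeze_contexts(ctxs)    # each element -> (8, 768)
#
#   embs = tuple(torch.rand(8, 768) for _ in range(12))
#   embs = add_blank(embs)           # len(embs) == 13; embs[0] is all zeros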

class TransformerOutputFormatter:
    def __init__(
        self,
        sentence: str,
        tokens: List[str],
        special_tokens_mask: List[int],
        att: Tuple[torch.Tensor], 
        topk_words: List[List[str]],
        topk_probs: List[List[float]],
        model_config
    ):
        assert len(tokens) > 0, "Cannot have an empty token output!"

        modified_att = flatten_batch(att)

        self.sentence = sentence
        self.tokens = tokens
        self.special_tokens_mask = special_tokens_mask
        self.attentions = modified_att
        self.topk_words = topk_words
        self.topk_probs = topk_probs
        self.model_config = model_config

        try:
            # GPT-style configs (n_layer / n_head / n_embd)
            self.n_layer = self.model_config.n_layer
            self.n_head = self.model_config.n_head
            self.hidden_dim = self.model_config.n_embd
        except AttributeError:
            # BERT-style configs (num_hidden_layers / num_attention_heads / hidden_size)
            self.n_layer = self.model_config.num_hidden_layers
            self.n_head = self.model_config.num_attention_heads
            self.hidden_dim = self.model_config.hidden_size

        self.__len = len(tokens)  # Number of tokens in the input
        assert self.__len == self.attentions[0].shape[-1], "Attentions don't represent the passed tokens!"
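    # A hypothetical construction (a sketch only; the attention tensors and top-k
    # lists would normally come from the model pipeline, and `transformers` is
    # assumed to be available):
    #
    #   from transformers import GPT2Config
    #   cfg = GPT2Config()  # exposes n_layer, n_head, n_embd
    #   formatter = TransformerOutputFormatter(
    #       sentence="Hello world",
    #       tokens=["Hello", "world"],
    #       special_tokens_mask=[0, 0],
    #       att=tuple(torch.rand(1, cfg.n_head, 2, 2) for _ in range(cfg.n_layer)),
    #       topk_words=[["world", "!"], [".", "!"]],
    #       topk_probs=[[0.6, 0.1], [0.5, 0.2]],
    #       model_config=cfg,
    #   )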
    
    def to_json(self, layer: int, ndigits: int = 5):
        """The original API expects the following response:

        aa: {
            att: number[][][]
            left: List[str]
            right: List[str]
        }
        """
        # Convert the attention tensors for the requested layer to lists and round the values

        rounder = partial(round, ndigits=ndigits)
        nested_rounder = partial(map_nlist, rounder)

        def tolist(tens): return [t.tolist() for t in tens]

        def to_resp(tok: str, topk_words, topk_probs):
            return {
                "text": tok,
                "topk_words": topk_words,
                "topk_probs": nested_rounder(topk_probs)
            }

        side_info = [
            to_resp(t, w, p)
            for t, w, p in zip(self.tokens, self.topk_words, self.topk_probs)
        ]

        out = {"aa": {
            "att": nested_rounder(tolist(self.attentions[layer])),
            "left": side_info,
            "right": side_info
        }}

        return out
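    # Hypothetical response for a two-token input (values illustrative, not from a real model):
    #
    #   formatter.to_json(layer=0)
    #   # => {"aa": {
    #   #        "att":  [[[0.9, 0.1], [0.4, 0.6]], ...],  # n_head x seq x seq
    #   #        "left": [{"text": "Hello", "topk_words": [...], "topk_probs": [...]},
    #   #                 {"text": "world", "topk_words": [...], "topk_probs": [...]}],
    #   #        "right": <same entries as "left">
    #   #    }}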

    def display_tokens(self, tokens):
        return fix_byte_spaces(tokens)
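    # Note: `fix_byte_spaces` (imported above) presumably maps byte-level BPE markers
    # such as GPT-2's "Ġ" back to ordinary spaces for display.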

    def __repr__(self):
        lim = 50
        if len(self.sentence) > lim: s = self.sentence[:lim - 3] + "..."
        else: s = self.sentence[:lim]

        return f"TransformerOutput({s})"

    def __len__(self):
        return self.__len
        
def to_numpy(x): 
    """Embeddings, contexts, and attentions are stored as torch.Tensors in a tuple. Convert this to a numpy array
    for storage in hdf5"""
    return np.array([x_.detach().numpy() for x_ in x])

def to_searchable(t: torch.Tensor) -> np.ndarray:
    """Convert a single tensor to a float32 numpy array."""
    return t.detach().numpy().astype(np.float32)
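# Storage sketch (hedged): given a `formatter` like the hypothetical one sketched
# above, the tuple of attention tensors can be packed into one ndarray before
# writing to an (assumed) HDF5 dataset, or converted to float32 for a search index.
#
#   arr = to_numpy(formatter.attentions)          # shape: (n_layer, n_head, seq, seq)
#   vec = to_searchable(formatter.attentions[0])  # float32 ndarray for layer 0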