John Hewitt commited on
Commit
5b1efcc
1 Parent(s): 49ee142

model upload

Browse files
README.md CHANGED
@@ -1,3 +1,6 @@
1
  ---
2
- license: apache-2.0
3
- ---
 
 
 
 
1
  ---
2
+ pipeline_tag: text-generation
3
+ tags:
4
+ - text-generation-inference
5
+ library_name: transformers
6
+ ---
config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BackpackGPT2LMHeadModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_backpack_gpt2.BackpackGPT2Config",
7
+ "AutoModelForCausalLM": "modeling_backpack_gpt2.BackpackGPT2LMHeadModel"
8
+ },
9
+ "activation_function": "gelu_new",
10
+ "attn_pdrop": 0.1,
11
+ "bos_token_id": 50256,
12
+ "embd_pdrop": 0.1,
13
+ "eos_token_id": 50256,
14
+ "initializer_range": 0.02,
15
+ "layer_norm_epsilon": 1e-05,
16
+ "model_type": "gpt2",
17
+ "n_embd": 768,
18
+ "n_head": 12,
19
+ "n_inner": null,
20
+ "n_layer": 12,
21
+ "n_positions": 512,
22
+ "num_senses": 16,
23
+ "reorder_and_upcast_attn": false,
24
+ "resid_pdrop": 0.1,
25
+ "scale_attn_by_inverse_layer_idx": true,
26
+ "scale_attn_weights": true,
27
+ "sense_intermediate_scale": 4,
28
+ "summary_activation": null,
29
+ "summary_first_dropout": 0.1,
30
+ "summary_proj_to_labels": true,
31
+ "summary_type": "cls_index",
32
+ "summary_use_proj": true,
33
+ "transformers_version": "4.29.0.dev0",
34
+ "use_cache": true,
35
+ "vocab_size": 50264
36
+ }
configuration_backpack_gpt2.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
2
+
3
+ class BackpackGPT2Config(GPT2Config):
4
+ """
5
+ This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
6
+ instantiate a Backpack GPT-2 model according to the specified arguments, defining the model architecture.
7
+
8
+ Configuration objects inherit from [`GPT2Config`] and can be used to control the model outputs. Read the
9
+ documentation from [`GPT2Config`] for more information.
10
+
11
+ Args:
12
+ num_senses (`int`, *optional*, defaults to 16):
13
+ The number of sense vectors to define for each word.
14
+ sense_intermediate_scale (`int`, *optional*, defaults ot 4):
15
+ The hidden dimensionality of the sense vector network.
16
+
17
+ Example:
18
+
19
+ ```python
20
+ >>> from transformers import BackpackGPT2Config, BackpackGPT2Model
21
+
22
+ >>> # Initializing a GPT2 configuration
23
+ >>> configuration = BackpackGPT2Config()
24
+
25
+ >>> # Initializing a model (with random weights) from the configuration
26
+ >>> model = BackpackGPT2Model(configuration)
27
+
28
+ >>> # Accessing the model configuration
29
+ >>> configuration = model.config
30
+ """
31
+
32
+ def __init__(self,
33
+ vocab_size=50264,
34
+ num_senses=16,
35
+ sense_intermediate_scale=4,
36
+ n_positions=512,
37
+ scale_attn_by_inverse_layer_idx=True,
38
+ **kwargs,
39
+ ):
40
+ self.num_senses = num_senses
41
+ self.sense_intermediate_scale = sense_intermediate_scale
42
+ super().__init__(vocab_size=vocab_size, n_positions=n_positions, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, **kwargs)
modeling_backpack_gpt2.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+ import torch.utils.checkpoint
7
+ from torch import nn
8
+
9
+ from transformers.activations import ACT2FN
10
+ from transformers.pytorch_utils import Conv1D
11
+ from transformers.utils import (
12
+ ModelOutput,
13
+ logging,
14
+ )
15
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2PreTrainedModel
16
+ from .configuration_backpack_gpt2 import BackpackGPT2Config
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+
21
+ ### Backpack-Specific
22
+ class BackpackGPT2PreTrainedModel(GPT2PreTrainedModel):
23
+ """
24
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
25
+ models.
26
+ """
27
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias"]
28
+
29
+ config_class = BackpackGPT2Config
30
+ base_model_prefix = "backpack"
31
+ is_parallelizable = True
32
+ supports_gradient_checkpointing = False
33
+ _no_split_modules = ["GPT2Block", "BackpackNoMixBlock"]
34
+
35
+ def __init__(self, *inputs, **kwargs):
36
+ super().__init__(*inputs, **kwargs)
37
+
38
+ class BackpackMLP(nn.Module):
39
+
40
+ def __init__(self, embed_dim, intermediate_dim, out_dim, config):
41
+ super().__init__()
42
+ self.c_fc = Conv1D(intermediate_dim, embed_dim)
43
+ self.c_proj = Conv1D(out_dim, intermediate_dim)
44
+ self.act = ACT2FN[config.activation_function]
45
+ self.dropout = nn.Dropout(config.resid_pdrop)
46
+
47
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
48
+ hidden_states = self.c_fc(hidden_states)
49
+ hidden_states = self.act(hidden_states)
50
+ hidden_states = self.c_proj(hidden_states)
51
+ hidden_states = self.dropout(hidden_states)
52
+ return hidden_states
53
+
54
+ class BackpackNoMixBlock(nn.Module):
55
+
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
59
+ self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
60
+ self.mlp = BackpackMLP(config.n_embd, config.n_embd*4, config.n_embd, config)
61
+ self.resid_dropout1 = nn.Dropout(config.resid_pdrop)
62
+ self.resid_dropout2 = nn.Dropout(config.resid_pdrop)
63
+
64
+ def forward(self, hidden_states, residual):
65
+ residual = self.resid_dropout1(hidden_states) + residual
66
+ hidden_states = self.ln_1(residual)
67
+ mlp_out = self.mlp(hidden_states)
68
+ residual = self.resid_dropout2(mlp_out) + residual
69
+ hidden_states = self.ln_2(residual)
70
+ return hidden_states
71
+
72
+
73
+ class BackpackSenseNetwork(nn.Module):
74
+ def __init__(self, config, num_senses, device=None, dtype=None):
75
+ super().__init__()
76
+ self.num_senses = num_senses
77
+ #self.embeddings = embeddings
78
+ self.n_embd = config.n_embd
79
+
80
+ self.dropout = nn.Dropout(config.embd_pdrop)
81
+ self.block = BackpackNoMixBlock(config)
82
+ self.ln = nn.LayerNorm(self.n_embd, eps=config.layer_norm_epsilon)
83
+ self.final_mlp = BackpackMLP(
84
+ embed_dim=config.n_embd,
85
+ intermediate_dim=config.sense_intermediate_scale*config.n_embd,
86
+ out_dim=config.n_embd*config.num_senses,
87
+ config=config,
88
+ )
89
+
90
+ def forward(self, input_embeds):
91
+ residual = self.dropout(input_embeds)
92
+ hidden_states = self.ln(residual)
93
+ hidden_states = self.block(hidden_states, residual)
94
+ senses = self.final_mlp(hidden_states)
95
+ bs, s, nvd = senses.shape
96
+ return senses.reshape(bs, s, self.num_senses, self.n_embd).transpose(1,2) # (bs, nv, s, d)
97
+
98
+ class BackpackWeightNetwork(nn.Module):
99
+
100
+ def __init__(self, num_senses, embed_dim):
101
+ super().__init__()
102
+ self.n_embd = embed_dim
103
+ self.num_senses = num_senses
104
+ self.c_attn = nn.Linear(embed_dim, 2*embed_dim)
105
+ self.softmax_scale = None
106
+
107
+ def forward(self, encoded):
108
+ b, s, d = encoded.shape
109
+ encoded = self.c_attn(encoded) # (b, s, 2*d)
110
+ encoded = encoded.reshape(b, s, 2, self.num_senses, d // self.num_senses) #(b, s, 2, nv, d//nv)
111
+ batch_size, seqlen = encoded.shape[0], encoded.shape[1]
112
+
113
+ # compute scores & mask
114
+ q, k = encoded.unbind(dim=2)
115
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
116
+ scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
117
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
118
+ scores = scores + causal_mask.to(dtype=scores.dtype)
119
+
120
+ return torch.softmax(scores, dim=-1, dtype=q.dtype)
121
+
122
+
123
+ @dataclass
124
+ class BackpackGPT2BaseModelOutput(ModelOutput):
125
+ hidden_states: torch.FloatTensor = None
126
+ contextualization: torch.FloatTensor = None
127
+
128
+ class BackpackGPT2Model(BackpackGPT2PreTrainedModel):
129
+ _keys_to_ignore_on_load_missing = [r".*attn.masked_bias", r".*attn.bias"]
130
+
131
+ def __init__(self, config):
132
+ super().__init__(config)
133
+
134
+ self.embed_dim = config.n_embd
135
+
136
+ self.num_senses = config.num_senses
137
+ self.gpt2_model = GPT2Model(config)
138
+ self.sense_network = BackpackSenseNetwork(config, self.num_senses, self.gpt2_model.wte)
139
+ self.word_embeddings = self.gpt2_model.wte
140
+ self.position_embeddings = self.gpt2_model.wpe
141
+ self.sense_weight_net = BackpackWeightNetwork(self.num_senses, self.embed_dim)
142
+ # Model parallel
143
+ self.model_parallel = False
144
+ self.device_map = None
145
+ self.gradient_checkpointing = False
146
+
147
+ def get_num_senses(self):
148
+ return self.num_senses
149
+
150
+ def get_word_embeddings(self):
151
+ return self.word_embeddings
152
+
153
+ def get_sense_network(self):
154
+ return self.sense_network
155
+
156
+ def forward(self, input_ids, position_ids):
157
+ # Compute senses
158
+ sense_input_embeds = self.word_embeddings(input_ids)
159
+ senses = self.sense_network(sense_input_embeds) # (bs, nv, s, d)
160
+
161
+ # Compute contextualization weights
162
+ contextl_hidden_states = self.gpt2_model(input_ids, position_ids=position_ids).last_hidden_state # (bs, s, d)
163
+ contextualization = self.sense_weight_net(contextl_hidden_states) # (bs, nv, s, s)
164
+
165
+ # Compute resulting outputs
166
+ hidden_states = torch.sum(contextualization @ senses, dim=1) # (bs, nv, s, d) -> (bs, s, d)
167
+ return BackpackGPT2BaseModelOutput(
168
+ hidden_states=hidden_states,
169
+ contextualization=contextualization,
170
+ )
171
+
172
+ def run_with_custom_contextualization(self, input_ids, contextualization):
173
+ # Compute senses
174
+ sense_input_embeds = self.word_embeddings(input_ids)
175
+ senses = self.sense_network(sense_input_embeds) # (bs, nv, s, d)
176
+
177
+ # Compute resulting outputs
178
+ hidden_states = torch.sum(contextualization @ senses, dim=1) # (bs, nv, s, d) -> (bs, s, d)
179
+ return BackpackGPT2BaseModelOutput(
180
+ hidden_states=hidden_states,
181
+ contextualization=contextualization,
182
+ )
183
+
184
+ @dataclass
185
+ class BackpackGPT2LMHeadModelOutput(ModelOutput):
186
+ logits: torch.FloatTensor = None
187
+ contextualization: torch.FloatTensor = None
188
+
189
+ class BackpackGPT2LMHeadModel(BackpackGPT2PreTrainedModel):
190
+ _keys_to_ignore_on_load_missing = [r".*attn.masked_bias", r".*attn.bias"]
191
+
192
+ def __init__(self, config):
193
+ super().__init__(config)
194
+ self.backpack = BackpackGPT2Model(config)
195
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
196
+
197
+ # Model parallel
198
+ self.model_parallel = False
199
+ self.device_map = None
200
+
201
+ self.tie_weights()
202
+
203
+ def tie_weights(self):
204
+ self.lm_head.weight = self.backpack.word_embeddings.weight # also tied with the underlying underlying transf
205
+
206
+ def get_lm_head(self):
207
+ return self.lm_head
208
+
209
+ def forward(self, input_ids, position_ids=None):
210
+ outputs = self.backpack(input_ids, position_ids=position_ids)
211
+ hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
212
+ lm_logits = self.lm_head(hidden_states) # (bs, s, V)
213
+ return BackpackGPT2LMHeadModelOutput(
214
+ logits=lm_logits,
215
+ contextualization=contextualization,
216
+ )
217
+
218
+ def run_with_custom_contextualization(self, input_ids, contextualization):
219
+ outputs = self.backpack.run_with_custom_contextualization(input_ids, contextualization)
220
+ hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
221
+ lm_logits = self.lm_head(hidden_states)
222
+ return BackpackGPT2LMHeadModelOutput(
223
+ logits=lm_logits,
224
+ contextualization=contextualization,
225
+ )
226
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c0db4ac7b9af81ea53a1278a708f8fedf02f98c5ef2b70f6453b2110471f27f
3
+ size 683550781