Matthijs commited on
Commit
fc122ab
1 Parent(s): 18d700a
config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "gelu",
3
+ "architectures": [
4
+ "DistilBertForSequenceClassification"
5
+ ],
6
+ "attention_dropout": 0.1,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_distilbert_ane.DistilBertConfig",
9
+ "AutoModelForSequenceClassification": "modeling_distilbert_ane.DistilBertForSequenceClassification"
10
+ },
11
+ "dim": 768,
12
+ "dropout": 0.1,
13
+ "finetuning_task": "sst-2",
14
+ "hidden_dim": 3072,
15
+ "id2label": {
16
+ "0": "NEGATIVE",
17
+ "1": "POSITIVE"
18
+ },
19
+ "initializer_range": 0.02,
20
+ "label2id": {
21
+ "NEGATIVE": 0,
22
+ "POSITIVE": 1
23
+ },
24
+ "max_position_embeddings": 512,
25
+ "model_type": "distilbert",
26
+ "n_heads": 12,
27
+ "n_layers": 6,
28
+ "output_past": true,
29
+ "pad_token_id": 0,
30
+ "qa_dropout": 0.1,
31
+ "seq_classif_dropout": 0.2,
32
+ "sinusoidal_pos_embds": false,
33
+ "tie_weights_": true,
34
+ "torch_dtype": "float32",
35
+ "transformers_version": "4.20.0.dev0",
36
+ "vocab_size": 30522
37
+ }
configuration_distilbert_ane.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from transformers.models.distilbert import configuration_distilbert
2
+
3
+ class DistilBertConfig(configuration_distilbert.DistilBertConfig):
4
+ def __init__(self, **kwargs):
5
+ super().__init__(**kwargs)
modeling_distilbert_ane.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022 Apple Inc. All Rights Reserved.
2
+
3
+ # IMPORTANT: This Apple software is supplied to you by Apple
4
+ # Inc. ("Apple") in consideration of your agreement to the following
5
+ # terms, and your use, installation, modification or redistribution of
6
+ # this Apple software constitutes acceptance of these terms. If you do
7
+ # not agree with these terms, please do not use, install, modify or
8
+ # redistribute this Apple software.
9
+
10
+ # In consideration of your agreement to abide by the following terms, and
11
+ # subject to these terms, Apple grants you a personal, non-exclusive
12
+ # license, under Apple's copyrights in this original Apple software (the
13
+ # "Apple Software"), to use, reproduce, modify and redistribute the Apple
14
+ # Software, with or without modifications, in source and/or binary forms;
15
+ # provided that if you redistribute the Apple Software in its entirety and
16
+ # without modifications, you must retain this notice and the following
17
+ # text and disclaimers in all such redistributions of the Apple Software.
18
+ # Neither the name, trademarks, service marks or logos of Apple Inc. may
19
+ # be used to endorse or promote products derived from the Apple Software
20
+ # without specific prior written permission from Apple. Except as
21
+ # expressly stated in this notice, no other rights or licenses, express or
22
+ # implied, are granted by Apple herein, including but not limited to any
23
+ # patent rights that may be infringed by your derivative works or by other
24
+ # works in which the Apple Software may be incorporated.
25
+
26
+ # The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27
+ # MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28
+ # THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29
+ # FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30
+ # OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31
+
32
+ # IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33
+ # OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35
+ # INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36
+ # MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37
+ # AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38
+ # STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39
+ # POSSIBILITY OF SUCH DAMAGE.
40
+
41
+
42
+ import torch
43
+ import torch.nn as nn
44
+
45
+ from transformers.models.distilbert import modeling_distilbert
46
+ from .configuration_distilbert_ane import DistilBertConfig
47
+
48
+ # Note: Original implementation of distilbert uses an epsilon value of 1e-12
49
+ # which is not friendly with the float16 precision that ANE uses by default
50
+ EPS = 1e-7
51
+
52
+ WARN_MSG_FOR_TRAINING_ATTEMPT = \
53
+ "This model is optimized for on-device execution only. " \
54
+ "Please use the original implementation from Hugging Face for training"
55
+
56
+ WARN_MSG_FOR_DICT_RETURN = \
57
+ "coremltools does not support dict outputs. Please set return_dict=False"
58
+
59
+
60
+ class LayerNormANE(nn.Module):
61
+ """ LayerNorm optimized for Apple Neural Engine (ANE) execution
62
+
63
+ Note: This layer only supports normalization over the final dim. It expects `num_channels`
64
+ as an argument and not `normalized_shape` which is used by `torch.nn.LayerNorm`.
65
+ """
66
+
67
+ def __init__(self,
68
+ num_channels,
69
+ clip_mag=None,
70
+ eps=1e-5,
71
+ elementwise_affine=True):
72
+ """
73
+ Args:
74
+ num_channels: Number of channels (C) where the expected input data format is BC1S. S stands for sequence length.
75
+ clip_mag: Optional float value to use for clamping the input range before layer norm is applied.
76
+ If specified, helps reduce risk of overflow.
77
+ eps: Small value to avoid dividing by zero
78
+ elementwise_affine: If true, adds learnable channel-wise shift (bias) and scale (weight) parameters
79
+ """
80
+ super().__init__()
81
+ # Principle 1: Picking the Right Data Format (machinelearning.apple.com/research/apple-neural-engine)
82
+ self.expected_rank = len('BC1S')
83
+
84
+ self.num_channels = num_channels
85
+ self.eps = eps
86
+ self.clip_mag = clip_mag
87
+ self.elementwise_affine = elementwise_affine
88
+
89
+ if self.elementwise_affine:
90
+ self.weight = nn.Parameter(torch.Tensor(num_channels))
91
+ self.bias = nn.Parameter(torch.Tensor(num_channels))
92
+
93
+ self._reset_parameters()
94
+
95
+ def _reset_parameters(self):
96
+ if self.elementwise_affine:
97
+ nn.init.ones_(self.weight)
98
+ nn.init.zeros_(self.bias)
99
+
100
+ def forward(self, inputs):
101
+ input_rank = len(inputs.size())
102
+
103
+ # Principle 1: Picking the Right Data Format (machinelearning.apple.com/research/apple-neural-engine)
104
+ # Migrate the data format from BSC to BC1S (most conducive to ANE)
105
+ if input_rank == 3 and inputs.size(2) == self.num_channels:
106
+ inputs = inputs.transpose(1, 2).unsqueeze(2)
107
+ input_rank = len(inputs.size())
108
+
109
+ assert input_rank == self.expected_rank
110
+ assert inputs.size(1) == self.num_channels
111
+
112
+ if self.clip_mag is not None:
113
+ inputs.clamp_(-self.clip_mag, self.clip_mag)
114
+
115
+ channels_mean = inputs.mean(dim=1, keepdims=True)
116
+
117
+ zero_mean = inputs - channels_mean
118
+
119
+ zero_mean_sq = zero_mean * zero_mean
120
+
121
+ denom = (zero_mean_sq.mean(dim=1, keepdims=True) + self.eps).rsqrt()
122
+
123
+ out = zero_mean * denom
124
+
125
+ if self.elementwise_affine:
126
+ out = (out + self.bias.view(1, self.num_channels, 1, 1)
127
+ ) * self.weight.view(1, self.num_channels, 1, 1)
128
+
129
+ return out
130
+
131
+
132
+ class Embeddings(modeling_distilbert.Embeddings):
133
+ """ Embeddings module optimized for Apple Neural Engine
134
+ """
135
+
136
+ def __init__(self, config):
137
+ super().__init__(config)
138
+ setattr(self, 'LayerNorm', LayerNormANE(config.dim, eps=EPS))
139
+
140
+
141
+ class MultiHeadSelfAttention(modeling_distilbert.MultiHeadSelfAttention):
142
+ """ MultiHeadSelfAttention module optimized for Apple Neural Engine
143
+ """
144
+
145
+ def __init__(self, config):
146
+ super().__init__(config)
147
+
148
+ setattr(
149
+ self, 'q_lin',
150
+ nn.Conv2d(
151
+ in_channels=config.dim,
152
+ out_channels=config.dim,
153
+ kernel_size=1,
154
+ ))
155
+
156
+ setattr(
157
+ self, 'k_lin',
158
+ nn.Conv2d(
159
+ in_channels=config.dim,
160
+ out_channels=config.dim,
161
+ kernel_size=1,
162
+ ))
163
+
164
+ setattr(
165
+ self, 'v_lin',
166
+ nn.Conv2d(
167
+ in_channels=config.dim,
168
+ out_channels=config.dim,
169
+ kernel_size=1,
170
+ ))
171
+
172
+ setattr(
173
+ self, 'out_lin',
174
+ nn.Conv2d(
175
+ in_channels=config.dim,
176
+ out_channels=config.dim,
177
+ kernel_size=1,
178
+ ))
179
+
180
+ def prune_heads(self, heads):
181
+ raise NotImplementedError
182
+
183
+ def forward(self,
184
+ query,
185
+ key,
186
+ value,
187
+ mask,
188
+ head_mask=None,
189
+ output_attentions=False):
190
+ """
191
+ Parameters:
192
+ query: torch.tensor(bs, dim, 1, seq_length)
193
+ key: torch.tensor(bs, dim, 1, seq_length)
194
+ value: torch.tensor(bs, dim, 1, seq_length)
195
+ mask: torch.tensor(bs, seq_length) or torch.tensor(bs, seq_length, 1, 1)
196
+
197
+ Returns:
198
+ weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
199
+ dim, 1, seq_length) Contextualized layer. Optional: only if `output_attentions=True`
200
+ """
201
+ # Parse tensor shapes for source and target sequences
202
+ assert len(query.size()) == 4 and len(key.size()) == 4 and len(
203
+ value.size()) == 4
204
+
205
+ bs, dim, dummy, seqlen = query.size()
206
+ # assert seqlen == key.size(3) and seqlen == value.size(3)
207
+ # assert dim == self.dim
208
+ # assert dummy == 1
209
+
210
+ # Project q, k and v
211
+ q = self.q_lin(query)
212
+ k = self.k_lin(key)
213
+ v = self.v_lin(value)
214
+
215
+ # Validate mask
216
+ if mask is not None:
217
+ expected_mask_shape = [bs, seqlen, 1, 1]
218
+ if mask.dtype == torch.bool:
219
+ mask = mask.logical_not().float() * -1e4
220
+ elif mask.dtype == torch.int64:
221
+ mask = (1 - mask).float() * -1e4
222
+ elif mask.dtype != torch.float32:
223
+ raise TypeError(f"Unexpected dtype for mask: {mask.dtype}")
224
+
225
+ if len(mask.size()) == 2:
226
+ mask = mask.unsqueeze(2).unsqueeze(2)
227
+
228
+ if list(mask.size()) != expected_mask_shape:
229
+ raise RuntimeError(
230
+ f"Invalid shape for `mask` (Expected {expected_mask_shape}, got {list(mask.size())}"
231
+ )
232
+
233
+ if head_mask is not None:
234
+ raise NotImplementedError
235
+
236
+ # Compute scaled dot-product attention
237
+ dim_per_head = self.dim // self.n_heads
238
+ mh_q = q.split(
239
+ dim_per_head,
240
+ dim=1) # (bs, dim_per_head, 1, max_seq_length) * n_heads
241
+ mh_k = k.transpose(1, 3).split(
242
+ dim_per_head,
243
+ dim=3) # (bs, max_seq_length, 1, dim_per_head) * n_heads
244
+ mh_v = v.split(
245
+ dim_per_head,
246
+ dim=1) # (bs, dim_per_head, 1, max_seq_length) * n_heads
247
+
248
+ normalize_factor = float(dim_per_head)**-0.5
249
+ attn_weights = [
250
+ torch.einsum('bchq,bkhc->bkhq', [qi, ki]) * normalize_factor
251
+ for qi, ki in zip(mh_q, mh_k)
252
+ ] # (bs, max_seq_length, 1, max_seq_length) * n_heads
253
+
254
+ if mask is not None:
255
+ for head_idx in range(self.n_heads):
256
+ attn_weights[head_idx] = attn_weights[head_idx] + mask
257
+
258
+ attn_weights = [aw.softmax(dim=1) for aw in attn_weights
259
+ ] # (bs, max_seq_length, 1, max_seq_length) * n_heads
260
+ attn = [
261
+ torch.einsum('bkhq,bchk->bchq', wi, vi)
262
+ for wi, vi in zip(attn_weights, mh_v)
263
+ ] # (bs, dim_per_head, 1, max_seq_length) * n_heads
264
+
265
+ attn = torch.cat(attn, dim=1) # (bs, dim, 1, max_seq_length)
266
+
267
+ attn = self.out_lin(attn)
268
+
269
+ if output_attentions:
270
+ return attn, attn_weights.cat(dim=2)
271
+ else:
272
+ return (attn, )
273
+
274
+
275
+ class FFN(modeling_distilbert.FFN):
276
+ """ FFN module optimized for Apple Neural Engine
277
+ """
278
+
279
+ def __init__(self, config):
280
+ super().__init__(config)
281
+ self.seq_len_dim = 3
282
+
283
+ setattr(
284
+ self, 'lin1',
285
+ nn.Conv2d(
286
+ in_channels=config.dim,
287
+ out_channels=config.hidden_dim,
288
+ kernel_size=1,
289
+ ))
290
+
291
+ setattr(
292
+ self, 'lin2',
293
+ nn.Conv2d(
294
+ in_channels=config.hidden_dim,
295
+ out_channels=config.dim,
296
+ kernel_size=1,
297
+ ))
298
+
299
+
300
+ class TransformerBlock(modeling_distilbert.TransformerBlock):
301
+
302
+ def __init__(self, config):
303
+ super().__init__(config)
304
+ setattr(self, 'attention', MultiHeadSelfAttention(config))
305
+ setattr(self, 'sa_layer_norm', LayerNormANE(config.dim, eps=EPS))
306
+ setattr(self, 'ffn', FFN(config))
307
+ setattr(self, 'output_layer_norm', LayerNormANE(config.dim, eps=EPS))
308
+
309
+
310
+ class Transformer(modeling_distilbert.Transformer):
311
+
312
+ def __init__(self, config):
313
+ super().__init__(config)
314
+ setattr(
315
+ self, 'layer',
316
+ nn.ModuleList(
317
+ [TransformerBlock(config) for _ in range(config.n_layers)]))
318
+
319
+
320
+ class DistilBertModel(modeling_distilbert.DistilBertModel):
321
+ config_class = DistilBertConfig
322
+
323
+ def __init__(self, config):
324
+ super().__init__(config)
325
+ setattr(self, 'embeddings', Embeddings(config))
326
+ setattr(self, 'transformer', Transformer(config))
327
+
328
+ # Register hook for unsqueezing nn.Linear parameters to match nn.Conv2d parameter spec
329
+ self._register_load_state_dict_pre_hook(linear_to_conv2d_map)
330
+
331
+ def _prune_heads(self, heads_to_prune):
332
+ raise NotImplementedError
333
+
334
+
335
+ class DistilBertForMaskedLM(modeling_distilbert.DistilBertForMaskedLM):
336
+ config_class = DistilBertConfig
337
+
338
+ def __init__(self, config):
339
+ super().__init__(config)
340
+ from transformers.activations import get_activation
341
+ setattr(self, 'activation', get_activation(config.activation))
342
+ setattr(self, 'distilbert', DistilBertModel(config))
343
+ setattr(self, 'vocab_transform', nn.Conv2d(config.dim, config.dim, 1))
344
+ setattr(self, 'vocab_layer_norm', LayerNormANE(config.dim, eps=EPS))
345
+ setattr(self, 'vocab_projector',
346
+ nn.Conv2d(config.dim, config.vocab_size, 1))
347
+
348
+ def forward(
349
+ self,
350
+ input_ids=None,
351
+ attention_mask=None,
352
+ head_mask=None,
353
+ inputs_embeds=None,
354
+ labels=None,
355
+ output_attentions=None,
356
+ output_hidden_states=None,
357
+ return_dict=None,
358
+ ):
359
+ if self.training or labels is not None:
360
+ raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
361
+
362
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
363
+ if return_dict:
364
+ raise ValueError(WARN_MSG_FOR_DICT_RETURN)
365
+
366
+ dlbrt_output = self.distilbert(
367
+ input_ids=input_ids,
368
+ attention_mask=attention_mask,
369
+ head_mask=head_mask,
370
+ inputs_embeds=inputs_embeds,
371
+ output_attentions=output_attentions,
372
+ output_hidden_states=output_hidden_states,
373
+ return_dict=False,
374
+ )
375
+ hidden_states = dlbrt_output[0] # (bs, dim, 1, seq_len)
376
+ prediction_logits = self.vocab_transform(
377
+ hidden_states) # (bs, dim, 1, seq_len)
378
+ prediction_logits = self.activation(
379
+ prediction_logits) # (bs, dim, 1, seq_len)
380
+ prediction_logits = self.vocab_layer_norm(
381
+ prediction_logits) # (bs, dim, 1, seq_len)
382
+ prediction_logits = self.vocab_projector(
383
+ prediction_logits) # (bs, dim, 1, seq_len)
384
+ prediction_logits = prediction_logits.squeeze(-1).squeeze(
385
+ -1) # (bs, dim)
386
+
387
+ output = (prediction_logits, ) + dlbrt_output[1:]
388
+ mlm_loss = None
389
+
390
+ return ((mlm_loss, ) + output) if mlm_loss is not None else output
391
+
392
+
393
+ class DistilBertForSequenceClassification(
394
+ modeling_distilbert.DistilBertForSequenceClassification):
395
+ config_class = DistilBertConfig
396
+
397
+ def __init__(self, config):
398
+ super().__init__(config)
399
+ setattr(self, 'distilbert', DistilBertModel(config))
400
+ setattr(self, 'pre_classifier', nn.Conv2d(config.dim, config.dim, 1))
401
+ setattr(self, 'classifier', nn.Conv2d(config.dim, config.num_labels,
402
+ 1))
403
+
404
+ def forward(
405
+ self,
406
+ input_ids=None,
407
+ attention_mask=None,
408
+ head_mask=None,
409
+ inputs_embeds=None,
410
+ labels=None,
411
+ output_attentions=None,
412
+ output_hidden_states=None,
413
+ return_dict=None,
414
+ ):
415
+ if labels is not None or self.training:
416
+ raise NotImplementedError(WARN_MSG_FOR_TRAINING_ATTEMPT)
417
+
418
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
419
+ if return_dict:
420
+ raise ValueError(WARN_MSG_FOR_DICT_RETURN)
421
+
422
+ distilbert_output = self.distilbert(
423
+ input_ids=input_ids,
424
+ attention_mask=attention_mask,
425
+ head_mask=head_mask,
426
+ inputs_embeds=inputs_embeds,
427
+ output_attentions=output_attentions,
428
+ output_hidden_states=output_hidden_states,
429
+ return_dict=False,
430
+ )
431
+ hidden_state = distilbert_output[0] # (bs, dim, 1, seq_len)
432
+ pooled_output = hidden_state[:, :, :, 0:1] # (bs, dim, 1, 1)
433
+ pooled_output = self.pre_classifier(pooled_output) # (bs, dim, 1, 1)
434
+ pooled_output = nn.ReLU()(pooled_output) # (bs, dim, 1, 1)
435
+ logits = self.classifier(pooled_output) # (bs, num_labels, 1, 1)
436
+ logits = logits.squeeze(-1).squeeze(-1) # (bs, num_labels)
437
+
438
+ output = (logits, ) + distilbert_output[1:]
439
+ loss = None
440
+
441
+ return ((loss, ) + output) if loss is not None else output
442
+
443
+
444
+ class DistilBertForQuestionAnswering(
445
+ modeling_distilbert.DistilBertForQuestionAnswering):
446
+ config_class = DistilBertConfig
447
+
448
+ def __init__(self, config):
449
+ super().__init__(config)
450
+ setattr(self, 'distilbert', DistilBertModel(config))
451
+ setattr(self, 'qa_outputs', nn.Conv2d(config.dim, config.num_labels,
452
+ 1))
453
+
454
+ def forward(
455
+ self,
456
+ input_ids=None,
457
+ attention_mask=None,
458
+ head_mask=None,
459
+ inputs_embeds=None,
460
+ start_positions=None,
461
+ end_positions=None,
462
+ output_attentions=None,
463
+ output_hidden_states=None,
464
+ return_dict=None,
465
+ ):
466
+
467
+ if self.training or start_positions is not None or end_positions is not None:
468
+ raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
469
+
470
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
471
+ if return_dict:
472
+ raise ValueError(WARN_MSG_FOR_DICT_RETURN)
473
+
474
+ distilbert_output = self.distilbert(
475
+ input_ids=input_ids,
476
+ attention_mask=attention_mask,
477
+ head_mask=head_mask,
478
+ inputs_embeds=inputs_embeds,
479
+ output_attentions=output_attentions,
480
+ output_hidden_states=output_hidden_states,
481
+ return_dict=False,
482
+ )
483
+ hidden_states = distilbert_output[0] # (bs, dim, 1, max_query_len)
484
+
485
+ hidden_states = self.dropout(
486
+ hidden_states) # (bs, dim, 1, max_query_len)
487
+ logits = self.qa_outputs(hidden_states) # (bs, 2, 1, max_query_len)
488
+ start_logits, end_logits = logits.split(
489
+ 1, dim=1) # (bs, 1, 1, max_query_len) * 2
490
+ start_logits = start_logits.squeeze().contiguous(
491
+ ) # (bs, max_query_len)
492
+ end_logits = end_logits.squeeze().contiguous() # (bs, max_query_len)
493
+
494
+ output = (start_logits, end_logits) + distilbert_output[1:]
495
+ total_loss = None
496
+
497
+ return ((total_loss, ) + output) if total_loss is not None else output
498
+
499
+
500
+ class DistilBertForTokenClassification(
501
+ modeling_distilbert.DistilBertForTokenClassification):
502
+
503
+ def __init__(self, config):
504
+ super().__init__(config)
505
+ setattr(self, 'distilbert', DistilBertModel(config))
506
+ setattr(self, 'classifier',
507
+ nn.Conv2d(config.hidden_size, config.num_labels, 1))
508
+
509
+ def forward(
510
+ self,
511
+ input_ids=None,
512
+ attention_mask=None,
513
+ head_mask=None,
514
+ inputs_embeds=None,
515
+ labels=None,
516
+ output_attentions=None,
517
+ output_hidden_states=None,
518
+ return_dict=None,
519
+ ):
520
+ if self.training or labels is not None:
521
+ raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
522
+
523
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
524
+ if return_dict:
525
+ raise ValueError(WARN_MSG_FOR_DICT_RETURN)
526
+
527
+ outputs = self.distilbert(
528
+ input_ids,
529
+ attention_mask=attention_mask,
530
+ head_mask=head_mask,
531
+ inputs_embeds=inputs_embeds,
532
+ output_attentions=output_attentions,
533
+ output_hidden_states=output_hidden_states,
534
+ return_dict=False,
535
+ )
536
+
537
+ sequence_output = outputs[0] # (bs, dim, 1, seq_len)
538
+ logits = self.classifier(
539
+ sequence_output) # (bs, num_labels, 1, seq_len)
540
+ logits = logits.squeeze(2).transpose(1, 2) # (bs, seq_len, num_labels)
541
+
542
+ output = (logits, ) + outputs[1:]
543
+ loss = None
544
+ return ((loss, ) + output) if loss is not None else output
545
+
546
+
547
+ class DistilBertForMultipleChoice(
548
+ modeling_distilbert.DistilBertForMultipleChoice):
549
+ config_class = DistilBertConfig
550
+
551
+ def __init__(self, config):
552
+ super().__init__(config)
553
+ setattr(self, 'distilbert', DistilBertModel(config))
554
+ setattr(self, 'pre_classifier', nn.Conv2d(config.dim, config.dim, 1))
555
+ setattr(self, 'classifier', nn.Conv2d(config.dim, 1, 1))
556
+
557
+ def forward(
558
+ self,
559
+ input_ids=None,
560
+ attention_mask=None,
561
+ head_mask=None,
562
+ inputs_embeds=None,
563
+ labels=None,
564
+ output_attentions=None,
565
+ output_hidden_states=None,
566
+ return_dict=None,
567
+ ):
568
+ if self.training or labels is not None:
569
+ raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
570
+
571
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
572
+ if return_dict:
573
+ raise ValueError(WARN_MSG_FOR_DICT_RETURN)
574
+
575
+ num_choices = input_ids.shape[
576
+ 1] if input_ids is not None else inputs_embeds.shape[1]
577
+
578
+ input_ids = input_ids.view(
579
+ -1, input_ids.size(-1)) if input_ids is not None else None
580
+ attention_mask = attention_mask.view(
581
+ -1,
582
+ attention_mask.size(-1)) if attention_mask is not None else None
583
+ inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2),
584
+ inputs_embeds.size(-1))
585
+ if inputs_embeds is not None else None)
586
+
587
+ outputs = self.distilbert(
588
+ input_ids,
589
+ attention_mask=attention_mask,
590
+ head_mask=head_mask,
591
+ inputs_embeds=inputs_embeds,
592
+ output_attentions=output_attentions,
593
+ output_hidden_states=output_hidden_states,
594
+ return_dict=False,
595
+ )
596
+
597
+ hidden_state = outputs[0] # (bs * num_choices, dim, 1, seq_len)
598
+ pooled_output = hidden_state[:, :, :,
599
+ 0:1] # (bs * num_choices, dim, 1, 1)
600
+ pooled_output = self.pre_classifier(
601
+ pooled_output) # (bs * num_choices, dim, 1, 1)
602
+ pooled_output = nn.ReLU()(
603
+ pooled_output) # (bs * num_choices, dim, 1, 1)
604
+ logits = self.classifier(pooled_output) # (bs * num_choices, 1, 1, 1)
605
+ logits = logits.squeeze() # (bs * num_choices)
606
+
607
+ reshaped_logits = logits.view(-1, num_choices) # (bs, num_choices)
608
+
609
+ output = (reshaped_logits, ) + outputs[1:]
610
+ loss = None
611
+
612
+ return ((loss, ) + output) if loss is not None else output
613
+
614
+
615
+ def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
616
+ missing_keys, unexpected_keys, error_msgs):
617
+ """ Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights
618
+ """
619
+ for k in state_dict:
620
+ is_internal_proj = all(substr in k for substr in ['lin', '.weight'])
621
+ is_output_proj = all(substr in k
622
+ for substr in ['classifier', '.weight'])
623
+ if is_internal_proj or is_output_proj:
624
+ if len(state_dict[k].shape) == 2:
625
+ state_dict[k] = state_dict[k][:, :, None, None]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1200fcc3f752c222525b7740abcd87f3aa26a12cd5d5589cf32763458eb9958
3
+ size 267853297