Gong Baitao commited on
Commit
32554d7
1 Parent(s): ff17c45

Update modeling_cpmbee.py

Browse files
Files changed (1) hide show
  1. modeling_cpmbee.py +244 -4
modeling_cpmbee.py CHANGED
@@ -451,7 +451,7 @@ class CpmBeeEncoder(nn.Module):
451
  hidden_states, attn_weights, current_key_value = layer_outputs
452
  if output_attentions:
453
  all_self_attns += (attn_weights,)
454
- if current_key_value is not None:
455
  current_key_values = current_key_values + (current_key_value,)
456
 
457
  hidden_states = self.output_layernorm(hidden_states)
@@ -734,6 +734,125 @@ class CpmBeeModel(CpmBeePreTrainedModel):
734
  config_class=_CONFIG_FOR_DOC,
735
  )
736
  def forward(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
737
  self,
738
  input_ids: torch.Tensor,
739
  input_id_sub: Optional[torch.Tensor] = None,
@@ -1140,6 +1259,127 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
1140
  config_class=_CONFIG_FOR_DOC,
1141
  )
1142
  def forward(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1143
  self,
1144
  input_ids: Optional[torch.Tensor] = None,
1145
  input_id_sub: Optional[torch.Tensor] = None,
@@ -1234,7 +1474,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
1234
  """
1235
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1236
 
1237
- model_output = self.cpmbee(
1238
  input_ids,
1239
  input_id_sub,
1240
  position,
@@ -1533,7 +1773,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
1533
  # init inference
1534
  model_inputs, input_ids = self.prepare_inputs_for_generation(input_ids, batch_size, **model_kwargs)
1535
  pred_start_index = input_ids.size(-1)
1536
- outputs = self(
1537
  **model_inputs,
1538
  return_dict=True,
1539
  output_attentions=output_attentions,
@@ -1578,7 +1818,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
1578
  input_ids, batch_size, beam_scorer, **model_kwargs
1579
  )
1580
 
1581
- outputs = self(
1582
  **model_inputs,
1583
  return_dict=True,
1584
  output_attentions=output_attentions,
 
451
  hidden_states, attn_weights, current_key_value = layer_outputs
452
  if output_attentions:
453
  all_self_attns += (attn_weights,)
454
+ if current_key_values is not None:
455
  current_key_values = current_key_values + (current_key_value,)
456
 
457
  hidden_states = self.output_layernorm(hidden_states)
 
734
  config_class=_CONFIG_FOR_DOC,
735
  )
736
  def forward(
737
+ self,
738
+ input_ids: torch.Tensor,
739
+ input_id_sub: Optional[torch.Tensor] = None,
740
+ length: Optional[torch.Tensor] = None,
741
+ context: Optional[torch.Tensor] = None,
742
+ sample_ids: Optional[torch.Tensor] = None,
743
+ num_segments: Optional[torch.Tensor] = None,
744
+ segment: Optional[torch.Tensor] = None,
745
+ segment_rel_offset: Optional[torch.Tensor] = None,
746
+ segment_rel: Optional[torch.Tensor] = None,
747
+ span: Optional[Dict] = None,
748
+ output_attentions: Optional[bool] = None,
749
+ output_hidden_states: Optional[bool] = None,
750
+ past_key_values: Optional[List] = None,
751
+ use_cache: Optional[bool] = None,
752
+ return_dict: Optional[bool] = None,
753
+ **kwargs,
754
+ ):
755
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
756
+ output_hidden_states = (
757
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
758
+ )
759
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
760
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
761
+
762
+ # dummy setting for common tests
763
+ if input_id_sub is None:
764
+ dtype, device = input_ids.dtype, input_ids.device
765
+ batch, seq_length = input_ids.size()
766
+ segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
767
+ context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
768
+ position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
769
+ input_id_sub = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
770
+ segment_rel_offset = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
771
+ segment_rel = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
772
+ num_segments = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
773
+ sample_ids = torch.zeros_like(input_ids)
774
+
775
+ with torch.no_grad():
776
+ batch = input_ids.size(0)
777
+ seqlen = input_ids.size(1)
778
+ device = input_ids.device
779
+
780
+ # calc segment bucket
781
+ segment_rel_2d = torch.masked_fill(
782
+ segment[:, :, None] * num_segments[:, :, None]
783
+ + segment[:, None, :]
784
+ + segment_rel_offset[:, :, None],
785
+ ~(
786
+ (sample_ids[:, :, None] == sample_ids[:, None, :])
787
+ & (span[:, None, :] == span[:, :, None])
788
+ ), # not in the same span or sample
789
+ 0, # avoid torch.gather overflow
790
+ ).view(batch, seqlen * seqlen)
791
+
792
+ segment_bucket = torch.gather(
793
+ input=segment_rel,
794
+ dim=1,
795
+ index=segment_rel_2d.long(),
796
+ ).view(batch, seqlen, seqlen)
797
+
798
+ segment_bucket.masked_fill_(
799
+ ~(
800
+ (sample_ids[:, :, None] == sample_ids[:, None, :])
801
+ & (span[:, None, :] == span[:, :, None])
802
+ ), # not in the same span or sample
803
+ 1, # bucket is used for in-context samples
804
+ )
805
+
806
+ # directional mask
807
+ directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(
808
+ seqlen, device=device
809
+ ).view(-1, 1)
810
+ # sample mask
811
+ sample_mask_2d = (sample_ids[:, :, None] == 0) | (
812
+ sample_ids[:, :, None] == sample_ids[:, None, :]
813
+ )
814
+ # context mask
815
+ attention_mask = context[:, None, :] | (
816
+ context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
817
+ )
818
+ # span mask
819
+ attention_mask = (
820
+ attention_mask & sample_mask_2d & (span[:, None, :] == span[:, :, None])
821
+ )
822
+ # length mask
823
+ mask_1d = (
824
+ torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
825
+ )
826
+ attention_mask = (
827
+ mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
828
+ )
829
+ position = torch.arange(seqlen, device=device).expand(batch, seqlen)
830
+
831
+ hidden_states = self.input_embedding(input_ids, input_id_sub)
832
+ position_bias = self.position_bias(position, position, segment_bucket)
833
+ hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder(
834
+ hidden_states,
835
+ attention_mask,
836
+ position_bias,
837
+ output_attentions,
838
+ output_hidden_states,
839
+ past_key_values=None,
840
+ use_cache=False
841
+ )
842
+
843
+ if not return_dict:
844
+ return tuple(
845
+ v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None
846
+ )
847
+
848
+ return BaseModelOutputWithPast(
849
+ last_hidden_state=hidden_states,
850
+ past_key_values=present_key_values,
851
+ hidden_states=all_hidden_states,
852
+ attentions=all_attentions,
853
+ )
854
+
855
+ def inference(
856
  self,
857
  input_ids: torch.Tensor,
858
  input_id_sub: Optional[torch.Tensor] = None,
 
1259
  config_class=_CONFIG_FOR_DOC,
1260
  )
1261
  def forward(
1262
+ self,
1263
+ input_ids: Optional[torch.Tensor] = None,
1264
+ input_id_sub: Optional[torch.Tensor] = None,
1265
+ length: Optional[torch.Tensor] = None,
1266
+ context: Optional[torch.Tensor] = None,
1267
+ sample_ids: Optional[torch.Tensor] = None,
1268
+ num_segments: Optional[torch.Tensor] = None,
1269
+ segment: Optional[torch.Tensor] = None,
1270
+ segment_rel_offset: Optional[torch.Tensor] = None,
1271
+ segment_rel: Optional[torch.Tensor] = None,
1272
+ span: Optional[torch.Tensor] = None,
1273
+ output_attentions: Optional[bool] = None,
1274
+ output_hidden_states: Optional[bool] = None,
1275
+ past_key_values: Optional[List] = None,
1276
+ use_cache: Optional[bool] = None,
1277
+ labels: Optional[torch.Tensor] = None,
1278
+ return_dict: Optional[bool] = None,
1279
+ ext_table_ids: Optional[torch.Tensor] = None, # (ext_table_size) int32
1280
+ ext_table_sub: Optional[torch.Tensor] = None, # (ext_table_size) int32
1281
+ **kwargs,
1282
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1283
+ r"""
1284
+ Args:
1285
+ input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
1286
+ Indices of input sequence tokens in the vocabulary.
1287
+
1288
+ Indices can be obtained using [`CPMBeeTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1289
+ [`PreTrainedTokenizer.__call__`] for details.
1290
+
1291
+ [What are input IDs?](../glossary#input-ids)
1292
+ input_id_sub (`torch.Tensor` of shape `(batch_size, seq_len)`):
1293
+ Subscription of input sequence tokens in the vocabulary.
1294
+
1295
+ Subscription of normal text will be zero while the special tokens of each group will be the 0, 1, 2,
1296
+ ... <ans_0>, <ans_1>, <ans_2> ... belongs to group <ans>. <mask_0>, <mask_1>, <mask_2> ... belongs to
1297
+ group <mask>.
1298
+ length (`torch.Tensor` of shape `(batch_size)`):
1299
+ The length of sequences in batch.
1300
+ context (`torch.Tensor` of shape `(batch_size, seq_len)`):
1301
+ Whether this token id is context or not. If is context, the value is 1. If not, the value is 0. If a
1302
+ token id is context, it does not need to be predicted.
1303
+ sample_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
1304
+ Give a sample id to every token id. The token ids with same sample ids belongs to the same sample.
1305
+ num_segments (`torch.Tensor` of shape `(batch_size, seq_len)`):
1306
+ Total number of segments in the current input.
1307
+ segment (`torch.Tensor` of shape `(batch_size, seq_len)`):
1308
+ Give a segment id to every token id. The token ids with same segment ids belongs to the same sample.
1309
+
1310
+ Generally, a string key or value in input data will be a segment. For example, input {"input": "hello,
1311
+ ", "<ans>": ""}, the segments includes: "input", "hello, ", "<ans>" and "".
1312
+ segment_rel_offset (`torch.Tensor` of shape `(batch_size, seq_len)`):
1313
+ The offset of segment rel.
1314
+ segment_rel (`torch.Tensor` of shape `(batch_size, seq_len)`):
1315
+ The segment relevance. A relative implementation of measuring the importance of segments.
1316
+ span (`Dict[str, Union[torch.Tensor, List]]`):
1317
+ Span will record every input_ids shape.
1318
+ output_attentions (`bool`, *optional*):
1319
+ Whether or not to return the attentions tensors of all attention layers.
1320
+ output_hidden_states (`bool`, *optional*):
1321
+ Whether or not to return the hidden states of all layers.
1322
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1323
+ A dummy arguments for CPMBee. The `past_states` contains pre-computed hidden-states (key and values in
1324
+ the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values`
1325
+ input) and other history arguments to speed up sequential decoding.
1326
+ use_cache (`bool`, *optional*):
1327
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1328
+ (see `past_key_values`).
1329
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1330
+ Labels for computing the masked language modeling loss.
1331
+ return_dict (`bool`, *optional*):
1332
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1333
+ ext_table_ids (`torch.Tensor`, *optional*):
1334
+ ext_table ids for embedding projection.
1335
+ ext_table_sub (`torch.Tensor`, *optional*):
1336
+ ext_table subscriptions for embedding projection.
1337
+ """
1338
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1339
+
1340
+ model_output = self.cpmbee(
1341
+ input_ids,
1342
+ input_id_sub,
1343
+ length,
1344
+ context,
1345
+ sample_ids,
1346
+ num_segments,
1347
+ segment,
1348
+ segment_rel_offset,
1349
+ segment_rel,
1350
+ span,
1351
+ output_attentions,
1352
+ output_hidden_states,
1353
+ past_key_values,
1354
+ use_cache,
1355
+ return_dict,
1356
+ )
1357
+ hidden_states = model_output.last_hidden_state if return_dict else model_output[0]
1358
+
1359
+ if ext_table_ids is not None:
1360
+ ext_table = self.cpmbee.input_embedding(ext_table_ids, ext_table_sub)
1361
+ else:
1362
+ ext_table = None
1363
+ logits = self.cpmbee.input_embedding.projection(hidden_states, ext_table)
1364
+
1365
+ loss = None
1366
+ if labels is not None:
1367
+ loss_func = nn.CrossEntropyLoss()
1368
+ loss = loss_func(logits.view(-1, logits.size(-1)), labels.long().view(-1))
1369
+
1370
+ if not return_dict:
1371
+ output = (logits,) + model_output[1:]
1372
+ return ((loss,) + output) if loss is not None else output
1373
+
1374
+ return CausalLMOutputWithPast(
1375
+ loss=loss,
1376
+ logits=logits,
1377
+ past_key_values=model_output.past_key_values,
1378
+ hidden_states=model_output.hidden_states,
1379
+ attentions=model_output.attentions,
1380
+ )
1381
+
1382
+ def inference(
1383
  self,
1384
  input_ids: Optional[torch.Tensor] = None,
1385
  input_id_sub: Optional[torch.Tensor] = None,
 
1474
  """
1475
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1476
 
1477
+ model_output = self.cpmbee.inference(
1478
  input_ids,
1479
  input_id_sub,
1480
  position,
 
1773
  # init inference
1774
  model_inputs, input_ids = self.prepare_inputs_for_generation(input_ids, batch_size, **model_kwargs)
1775
  pred_start_index = input_ids.size(-1)
1776
+ outputs = self.inference(
1777
  **model_inputs,
1778
  return_dict=True,
1779
  output_attentions=output_attentions,
 
1818
  input_ids, batch_size, beam_scorer, **model_kwargs
1819
  )
1820
 
1821
+ outputs = self.inference(
1822
  **model_inputs,
1823
  return_dict=True,
1824
  output_attentions=output_attentions,