ccdv committed on
Commit
5ca8863
1 Parent(s): 69d00ee

bos_token + readme

Browse files
Files changed (1) hide show
  1. modeling_lsg_bart.py +38 -11
modeling_lsg_bart.py CHANGED
@@ -55,9 +55,9 @@ class LSGBartConfig(BartConfig):
55
  self.sparsity_factor = sparsity_factor
56
  self.sparsity_type = sparsity_type
57
 
58
- if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
59
  logger.warning(
60
- "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
61
  setting sparsity_type=None, computation will skip sparse attention")
62
  self.sparsity_type = None
63
 
@@ -345,7 +345,7 @@ class LSGAttentionProduct(nn.Module):
345
  return x.reshape(*x.size()[:-2], n_blocks, -1, d)
346
 
347
 
348
- class LSGBartEncoderAttention(BaseSelfAttention):
349
  '''
350
  Compute local attention with overlapping blocs
351
  Use global attention for tokens with highest norm
@@ -380,15 +380,16 @@ class LSGBartEncoderAttention(BaseSelfAttention):
380
  "lsh": self.get_sparse_tokens_with_lsh,
381
  "stride": self.get_sparse_tokens_with_stride,
382
  "block_stride": self.get_sparse_tokens_with_block_stride,
 
383
  }
384
 
385
  self.sparsity_type = config.sparsity_type
386
- self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
387
 
388
  if config.sparsity_type == "lsh":
389
  self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
390
 
391
- def get_sparse_tokens_with_norm(self, keys, values, mask):
392
 
393
  if self.sparsity_factor == 1:
394
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -416,7 +417,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
416
 
417
  return keys, values, mask
418
 
419
- def get_sparse_tokens_with_pooling(self, keys, values, mask):
420
 
421
  if self.sparsity_factor == 1:
422
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -439,7 +440,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
439
  mask *= torch.finfo(mask.dtype).min
440
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
441
 
442
- def get_sparse_tokens_with_stride(self, keys, values, mask):
443
 
444
  if self.sparsity_factor == 1:
445
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -455,7 +456,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
455
 
456
  return keys, values, mask
457
 
458
- def get_sparse_tokens_with_block_stride(self, keys, values, mask):
459
 
460
  if self.sparsity_factor == 1:
461
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -475,11 +476,14 @@ class LSGBartEncoderAttention(BaseSelfAttention):
475
 
476
  return keys, values, mask
477
 
478
- def get_sparse_tokens_with_lsh(self, keys, values, mask):
479
 
480
  if self.sparsity_factor == 1:
481
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
482
 
 
 
 
483
  block_size = min(self.block_size, self.sparse_block_size)
484
  keys = self.chunk(keys, block_size)
485
  values = self.chunk(values, block_size)
@@ -526,6 +530,29 @@ class LSGBartEncoderAttention(BaseSelfAttention):
526
 
527
  return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  def forward(
530
  self,
531
  hidden_states,
@@ -595,7 +622,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
595
  sparse_key, sparse_value, sparse_mask = (None, None, None)
596
 
597
  if self.sparse_block_size and self.sparsity_factor > 0:
598
- sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
599
 
600
  # Expand masks on heads
601
  attention_mask = attention_mask.expand(-1, h, -1, -1)
@@ -632,7 +659,7 @@ class LSGBartEncoderLayer(BartEncoderLayer):
632
  def __init__(self, config):
633
 
634
  super().__init__(config)
635
- self.self_attn = LSGBartEncoderAttention(
636
  config=config,
637
  embed_dim=self.embed_dim,
638
  num_heads=config.encoder_attention_heads,
 
55
  self.sparsity_factor = sparsity_factor
56
  self.sparsity_type = sparsity_type
57
 
58
+ if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
59
  logger.warning(
60
+ "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
61
  setting sparsity_type=None, computation will skip sparse attention")
62
  self.sparsity_type = None
63
 
 
345
  return x.reshape(*x.size()[:-2], n_blocks, -1, d)
346
 
347
 
348
+ class LSGBartEncoderSelfAttention(BaseSelfAttention):
349
  '''
350
  Compute local attention with overlapping blocs
351
  Use global attention for tokens with highest norm
 
380
  "lsh": self.get_sparse_tokens_with_lsh,
381
  "stride": self.get_sparse_tokens_with_stride,
382
  "block_stride": self.get_sparse_tokens_with_block_stride,
383
+ "bos_pooling": self.get_sparse_tokens_with_bos_pooling
384
  }
385
 
386
  self.sparsity_type = config.sparsity_type
387
+ self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
388
 
389
  if config.sparsity_type == "lsh":
390
  self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
391
 
392
+ def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
393
 
394
  if self.sparsity_factor == 1:
395
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
417
 
418
  return keys, values, mask
419
 
420
+ def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
421
 
422
  if self.sparsity_factor == 1:
423
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
440
  mask *= torch.finfo(mask.dtype).min
441
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
442
 
443
+ def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
444
 
445
  if self.sparsity_factor == 1:
446
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
456
 
457
  return keys, values, mask
458
 
459
+ def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
460
 
461
  if self.sparsity_factor == 1:
462
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
476
 
477
  return keys, values, mask
478
 
479
+ def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
480
 
481
  if self.sparsity_factor == 1:
482
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
483
 
484
+ if self.sparsity_factor == self.sparse_block_size:
485
+ return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
486
+
487
  block_size = min(self.block_size, self.sparse_block_size)
488
  keys = self.chunk(keys, block_size)
489
  values = self.chunk(values, block_size)
 
530
 
531
  return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
532
 
533
def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
    """
    Pool every block of `self.sparsity_factor` keys/values into one sparse token.

    Inside each block the tokens are averaged with softmax weights given by
    their scaled dot-product with the first query position
    (presumably the BOS token, per the method name -- confirm with caller).

    Args:
        queries: query tensor; only index 0 of its second-to-last axis is used.
        keys: key tensor, split into blocks by `self.chunk`.
        values: value tensor, split the same way as `keys`.
        mask: additive attention mask; must not be None. Masked positions are
            expected to hold `torch.finfo(mask.dtype).min`
            (assumption based on the comparison below -- TODO confirm).

    Returns:
        Tuple `(keys, values, mask)`. When `self.sparsity_factor == 1` the
        inputs are returned unchanged apart from expanding the mask over the
        head axis. Otherwise keys/values hold one pooled vector per block and
        the mask holds one entry per block: `finfo.min` when every position
        of the block was masked, 0 otherwise.
    """
    if self.sparsity_factor == 1:
        return keys, values, mask.expand(-1, keys.size()[1], -1, -1)

    # Add a block axis so the first query broadcasts against every block.
    queries = queries.unsqueeze(-3)

    # Split the sequence axis into blocks of `sparsity_factor` positions.
    # The mask is transposed so its sequence axis sits where chunk() expects
    # it, then transposed back to line up with the score layout.
    mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
    keys = self.chunk(keys, self.sparsity_factor)
    values = self.chunk(values, self.sparsity_factor)

    # chunk() output is 5-D; only batch, heads and head dim are needed below.
    n, h, _, _, d = keys.size()

    # Score every token of every block against the first query only.
    # The mask was already dereferenced above, so no None guard is needed here.
    scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
    scores = torch.softmax(scores + mask, dim=-1)

    # Softmax-weighted average of each block -> one token per block.
    keys = scores @ keys
    values = scores @ values

    # Per-block mask: keep finfo.min only when the whole block is masked.
    # amax is used instead of mean: summing several finfo.min values can
    # overflow to -inf when accumulating in the mask dtype, which would make
    # a fully masked block fail the equality test and become unmasked.
    masked_value = torch.finfo(mask.dtype).min
    mask = mask.amax(dim=-1)
    mask[mask != masked_value] = 0

    return (
        keys.reshape(n, h, -1, d),
        values.reshape(n, h, -1, d),
        mask.expand(-1, h, -1, -1).transpose(-1, -2),
    )
555
+
556
  def forward(
557
  self,
558
  hidden_states,
 
622
  sparse_key, sparse_value, sparse_mask = (None, None, None)
623
 
624
  if self.sparse_block_size and self.sparsity_factor > 0:
625
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
626
 
627
  # Expand masks on heads
628
  attention_mask = attention_mask.expand(-1, h, -1, -1)
 
659
  def __init__(self, config):
660
 
661
  super().__init__(config)
662
+ self.self_attn = LSGBartEncoderSelfAttention(
663
  config=config,
664
  embed_dim=self.embed_dim,
665
  num_heads=config.encoder_attention_heads,