xinyu1205 committed on
Commit
64c1dc7
1 Parent(s): eee384b

Rename models/med.py to models/bert.py

Browse files
Files changed (1) hide show
  1. models/{med.py → bert.py} +8 -4
models/{med.py → bert.py} RENAMED
@@ -224,6 +224,12 @@ class BertSelfAttention(nn.Module):
224
 
225
  past_key_value = (key_layer, value_layer)
226
 
 
 
 
 
 
 
227
  # Take the dot product between "query" and "key" to get the raw attention scores.
228
  attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
229
 
@@ -392,12 +398,10 @@ class BertLayer(nn.Module):
392
  mode=None,
393
  ):
394
 
395
- if mode == 'mlr':
396
-
397
  assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
398
 
399
- # print('attention_output.shape',attention_output.shape)
400
- # print('encoder_hidden_states.shape',encoder_hidden_states.shape)
401
  cross_attention_outputs = self.crossattention(
402
  hidden_states,
403
  attention_mask,
 
224
 
225
  past_key_value = (key_layer, value_layer)
226
 
227
+ # compatible with higher versions of transformers
228
+ if key_layer.shape[0] > query_layer.shape[0]:
229
+ key_layer = key_layer[:query_layer.shape[0], :, :, :]
230
+ attention_mask = attention_mask[:query_layer.shape[0], :, :]
231
+ value_layer = value_layer[:query_layer.shape[0], :, :, :]
232
+
233
  # Take the dot product between "query" and "key" to get the raw attention scores.
234
  attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
235
 
 
398
  mode=None,
399
  ):
400
 
401
+ if mode == 'tagging':
402
+
403
  assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
404
 
 
 
405
  cross_attention_outputs = self.crossattention(
406
  hidden_states,
407
  attention_mask,