kaizen9 committed
Commit 5046ddf
1 Parent(s): 7aaf32c

Update attention.py

Files changed (1):
  1. attention.py +17 -0
attention.py CHANGED
@@ -10,6 +10,23 @@ from torch import nn
 from .fc import FC_CLASS_REGISTRY
 from .norm import NORM_CLASS_REGISTRY
 
+
+def is_flash_v2_installed(v2_version: str='2.0.0'):
+    assert version.parse(v2_version) >= version.parse('2.0.0')
+    try:
+        import flash_attn as flash_attn
+    except:
+        return False
+    return version.parse(flash_attn.__version__) >= version.parse(v2_version)
+
+def is_flash_v1_installed():
+    try:
+        import flash_attn as flash_attn
+    except:
+        return False
+    return version.parse(flash_attn.__version__) < version.parse('2.0.0')
+
+
 def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
     if original_is_causal and num_query_tokens != num_key_tokens:
         if num_query_tokens != 1:
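
Note that the added helpers call `version.parse`, which is not imported in this hunk; the commit evidently assumes `version` (presumably `packaging.version`) is already in scope in attention.py. Below is a minimal standalone sketch under that assumption, with the import made explicit and the bare `except:` narrowed to `except ImportError:` (same behavior when flash-attn is simply absent). The `pick_attn_impl` dispatcher is a hypothetical illustration of how the two checks might gate the flash-attn v1/v2 code paths; it is not part of the commit.

from packaging import version

def is_flash_v2_installed(v2_version: str = '2.0.0'):
    # Only v2-series versions are meaningful thresholds for this check.
    assert version.parse(v2_version) >= version.parse('2.0.0')
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) >= version.parse(v2_version)

def is_flash_v1_installed():
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) < version.parse('2.0.0')

# Hypothetical dispatcher (illustration only, not in the commit).
def pick_attn_impl() -> str:
    if is_flash_v2_installed():
        return 'flash_v2'  # flash-attn >= 2.0.0 installed
    if is_flash_v1_installed():
        return 'flash_v1'  # flash-attn < 2.0.0 installed
    return 'torch'         # no flash-attn; fall back to plain PyTorch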