Update attention.py
Browse files- attention.py +17 -0
attention.py
CHANGED
@@ -10,6 +10,23 @@ from torch import nn
|
|
10 |
from .fc import FC_CLASS_REGISTRY
|
11 |
from .norm import NORM_CLASS_REGISTRY
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
|
14 |
if original_is_causal and num_query_tokens != num_key_tokens:
|
15 |
if num_query_tokens != 1:
|
|
|
10 |
from .fc import FC_CLASS_REGISTRY
|
11 |
from .norm import NORM_CLASS_REGISTRY
|
12 |
|
13 |
+
|
14 |
+
def is_flash_v2_installed(v2_version: str = '2.0.0'):
    """Return True if flash-attn version 2 (>= *v2_version*) is installed.

    Args:
        v2_version: Minimum flash-attn version to require. Must itself be
            at least '2.0.0' — this helper only answers questions about the
            v2 line of flash-attn.

    Returns:
        bool: True when ``flash_attn`` is importable and its
        ``__version__`` is >= ``v2_version``; False when the package is
        not installed.
    """
    assert version.parse(v2_version) >= version.parse('2.0.0')
    try:
        # Original used the redundant alias `import flash_attn as flash_attn`.
        import flash_attn
    # Bare `except:` would also swallow KeyboardInterrupt/SystemExit;
    # only a failed import should mean "not installed".
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) >= version.parse(v2_version)
|
21 |
+
|
22 |
+
def is_flash_v1_installed():
    """Return True if a flash-attn version 1.x (< 2.0.0) is installed.

    Returns:
        bool: True when ``flash_attn`` is importable and its
        ``__version__`` is below '2.0.0'; False when the package is not
        installed.
    """
    try:
        # Original used the redundant alias `import flash_attn as flash_attn`.
        import flash_attn
    # Bare `except:` would also swallow KeyboardInterrupt/SystemExit;
    # only a failed import should mean "not installed".
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) < version.parse('2.0.0')
|
28 |
+
|
29 |
+
|
30 |
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
|
31 |
if original_is_causal and num_query_tokens != num_key_tokens:
|
32 |
if num_query_tokens != 1:
|