Update modeling_qwen2_rm.py
Browse files- modeling_qwen2_rm.py +2 -2
modeling_qwen2_rm.py
CHANGED
@@ -48,8 +48,8 @@ from transformers.utils import (
|
|
48 |
from .configuration_qwen2_rm import Qwen2RMConfig as Qwen2Config
|
49 |
|
50 |
|
51 |
-
|
52 |
-
|
53 |
|
54 |
|
55 |
logger = logging.get_logger(__name__)
|
|
|
48 |
from .configuration_qwen2_rm import Qwen2RMConfig as Qwen2Config
|
49 |
|
50 |
|
51 |
+
if is_flash_attn_2_available():
|
52 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
53 |
|
54 |
|
55 |
logger = logging.get_logger(__name__)
|