reference the flash attention GitHub
Files changed:
- bert_padding.py +5 -0
- block.py +5 -0
- embedding.py +5 -0
- mha.py +9 -0
- mlp.py +5 -0
bert_padding.py CHANGED
@@ -1,5 +1,10 @@
 # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

+"""
+The implementation was further adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+"""
+
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
block.py CHANGED
@@ -1,5 +1,10 @@
 # Copyright (c) 2024, Tri Dao.

+"""
+The implementation was adopted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+"""
+
 from functools import partial
 from typing import Optional

embedding.py CHANGED
@@ -1,5 +1,10 @@
 # Copyright (c) 2022, Tri Dao.

+"""
+The implementation was adopted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0/flash_attn/models/bert.py
+"""
+
 import torch
 import torch.nn as nn
 from torch import Tensor
mha.py CHANGED
@@ -1,5 +1,14 @@
 # Copyright (c) 2023, Tri Dao.

+"""
+The implementation was adopted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+and made modifications to
+- support QK normalization
+- make ALiBi run with MHA (needed to cast alibi slopes to fp32)
+- make ALiBi run on CPU
+"""
+
 import math
 from functools import partial

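The mha.py docstring lists three behavioral changes: QK normalization, casting the ALiBi slopes to fp32, and letting ALiBi run on CPU. The commit itself only adds the docstring, so as an illustration only, here is a minimal pure-PyTorch sketch of what the two numerical tweaks can look like. The names (attention_qknorm_alibi, alibi_bias) and the choice of L2 normalization for QK norm are assumptions, not the file's actual code.

# Minimal sketch (assumptions, not the repo's mha.py): QK normalization plus an
# ALiBi bias built in fp32 so it behaves under fp16/bf16 and also runs on CPU.
import math
import torch
import torch.nn.functional as F

def alibi_bias(slopes: torch.Tensor, seqlen: int) -> torch.Tensor:
    # Cast the per-head slopes to fp32 before building the bias; in half
    # precision the slope * distance products for long sequences lose precision.
    slopes = slopes.to(torch.float32)                               # (nheads,)
    pos = torch.arange(seqlen, dtype=torch.float32, device=slopes.device)
    rel = -(pos[None, :] - pos[:, None]).abs()                      # (seqlen, seqlen)
    return slopes[:, None, None] * rel[None, :, :]                  # (nheads, seqlen, seqlen)

def attention_qknorm_alibi(q, k, v, alibi_slopes=None, causal=True):
    # q, k, v: (batch, nheads, seqlen, head_dim); a pure-PyTorch reference path,
    # so it works on CPU as well as GPU.
    _, _, seqlen, head_dim = q.shape
    # "QK normalization" here means L2-normalizing queries and keys over
    # head_dim; the actual module may use LayerNorm instead (an assumption).
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)
    scores = torch.einsum("bhid,bhjd->bhij", q, k) / math.sqrt(head_dim)
    if alibi_slopes is not None:
        # Bias computed in fp32, cast back to the score dtype only at the end.
        scores = scores + alibi_bias(alibi_slopes, seqlen).to(scores.dtype)
    if causal:
        mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores.float(), dim=-1).to(v.dtype)
    return torch.einsum("bhij,bhjd->bhid", attn, v)

The point of the fp32 cast is that the bias is a product of a small slope and a potentially large position distance; doing that multiply in fp32 and casting the finished bias back keeps the attention scores consistent across dtypes and devices.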
mlp.py CHANGED
@@ -1,5 +1,10 @@
 # Copyright (c) 2023, Tri Dao.

+"""
+The implementation was adopted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+"""
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F