Spaces:
Running
Running
Upload 116 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- mamba/.github/workflows/publish.yaml +209 -0
- mamba/.gitignore +6 -0
- mamba/.gitmodules +3 -0
- mamba/AUTHORS +2 -0
- mamba/LICENSE +201 -0
- mamba/README.md +243 -0
- mamba/assets/selection.png +0 -0
- mamba/assets/ssd_algorithm.png +3 -0
- mamba/benchmarks/benchmark_generation_mamba_simple.py +92 -0
- mamba/build/lib/mamba_ssm/__init__.py +6 -0
- mamba/build/lib/mamba_ssm/distributed/__init__.py +0 -0
- mamba/build/lib/mamba_ssm/distributed/distributed_utils.py +144 -0
- mamba/build/lib/mamba_ssm/distributed/tensor_parallel.py +296 -0
- mamba/build/lib/mamba_ssm/models/__init__.py +0 -0
- mamba/build/lib/mamba_ssm/models/config_mamba.py +18 -0
- mamba/build/lib/mamba_ssm/models/mixer_seq_simple.py +309 -0
- mamba/build/lib/mamba_ssm/modules/__init__.py +0 -0
- mamba/build/lib/mamba_ssm/modules/block.py +91 -0
- mamba/build/lib/mamba_ssm/modules/mamba2.py +383 -0
- mamba/build/lib/mamba_ssm/modules/mamba2_simple.py +200 -0
- mamba/build/lib/mamba_ssm/modules/mamba_simple.py +294 -0
- mamba/build/lib/mamba_ssm/modules/mha.py +294 -0
- mamba/build/lib/mamba_ssm/modules/mlp.py +34 -0
- mamba/build/lib/mamba_ssm/modules/ssd_minimal.py +103 -0
- mamba/build/lib/mamba_ssm/ops/__init__.py +0 -0
- mamba/build/lib/mamba_ssm/ops/selective_scan_interface.py +357 -0
- mamba/build/lib/mamba_ssm/ops/triton/__init__.py +0 -0
- mamba/build/lib/mamba_ssm/ops/triton/k_activations.py +169 -0
- mamba/build/lib/mamba_ssm/ops/triton/layer_norm.py +1113 -0
- mamba/build/lib/mamba_ssm/ops/triton/layernorm_gated.py +437 -0
- mamba/build/lib/mamba_ssm/ops/triton/selective_state_update.py +265 -0
- mamba/build/lib/mamba_ssm/ops/triton/softplus.py +17 -0
- mamba/build/lib/mamba_ssm/ops/triton/ssd_bmm.py +262 -0
- mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
- mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_state.py +988 -0
- mamba/build/lib/mamba_ssm/ops/triton/ssd_combined.py +981 -0
- mamba/build/lib/mamba_ssm/ops/triton/ssd_state_passing.py +348 -0
- mamba/build/lib/mamba_ssm/utils/__init__.py +0 -0
- mamba/build/lib/mamba_ssm/utils/generation.py +387 -0
- mamba/build/lib/mamba_ssm/utils/hf.py +23 -0
- mamba/csrc/selective_scan/reverse_scan.cuh +415 -0
- mamba/csrc/selective_scan/selective_scan.cpp +497 -0
- mamba/csrc/selective_scan/selective_scan.h +101 -0
- mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu +9 -0
- mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu +9 -0
- mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu +9 -0
- mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu +9 -0
- mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu +9 -0
- mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu +9 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
mamba/assets/ssd_algorithm.png filter=lfs diff=lfs merge=lfs -text
|
mamba/.github/workflows/publish.yaml
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This workflow will:
|
2 |
+
# - Create a new Github release
|
3 |
+
# - Build wheels for supported architectures
|
4 |
+
# - Deploy the wheels to the Github release
|
5 |
+
# - Release the static code to PyPi
|
6 |
+
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
7 |
+
|
8 |
+
name: Build wheels and deploy
|
9 |
+
|
10 |
+
on:
|
11 |
+
create:
|
12 |
+
tags:
|
13 |
+
- v*
|
14 |
+
|
15 |
+
jobs:
|
16 |
+
|
17 |
+
setup_release:
|
18 |
+
name: Create Release
|
19 |
+
runs-on: ubuntu-latest
|
20 |
+
steps:
|
21 |
+
- name: Get the tag version
|
22 |
+
id: extract_branch
|
23 |
+
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
|
24 |
+
shell: bash
|
25 |
+
|
26 |
+
- name: Create Release
|
27 |
+
id: create_release
|
28 |
+
uses: actions/create-release@v1
|
29 |
+
env:
|
30 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
31 |
+
with:
|
32 |
+
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
33 |
+
release_name: ${{ steps.extract_branch.outputs.branch }}
|
34 |
+
|
35 |
+
build_wheels:
|
36 |
+
name: Build Wheel
|
37 |
+
needs: setup_release
|
38 |
+
runs-on: ${{ matrix.os }}
|
39 |
+
|
40 |
+
strategy:
|
41 |
+
fail-fast: false
|
42 |
+
matrix:
|
43 |
+
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
|
44 |
+
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
|
45 |
+
os: [ubuntu-20.04]
|
46 |
+
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
|
47 |
+
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0']
|
48 |
+
cuda-version: ['11.8.0', '12.2.2']
|
49 |
+
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
|
50 |
+
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
|
51 |
+
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
|
52 |
+
# when building without C++11 ABI and using it on nvcr images.
|
53 |
+
cxx11_abi: ['FALSE', 'TRUE']
|
54 |
+
exclude:
|
55 |
+
# Pytorch < 2.2 does not support Python 3.12
|
56 |
+
- torch-version: '2.0.1'
|
57 |
+
python-version: '3.12'
|
58 |
+
- torch-version: '2.1.2'
|
59 |
+
python-version: '3.12'
|
60 |
+
# Pytorch <= 2.0 only supports CUDA <= 11.8
|
61 |
+
- torch-version: '2.0.1'
|
62 |
+
cuda-version: '12.2.2'
|
63 |
+
|
64 |
+
steps:
|
65 |
+
- name: Checkout
|
66 |
+
uses: actions/checkout@v3
|
67 |
+
|
68 |
+
- name: Set up Python
|
69 |
+
uses: actions/setup-python@v4
|
70 |
+
with:
|
71 |
+
python-version: ${{ matrix.python-version }}
|
72 |
+
|
73 |
+
- name: Set CUDA and PyTorch versions
|
74 |
+
run: |
|
75 |
+
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
|
76 |
+
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
|
77 |
+
|
78 |
+
- name: Free up disk space
|
79 |
+
if: ${{ runner.os == 'Linux' }}
|
80 |
+
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
|
81 |
+
# https://github.com/easimon/maximize-build-space/tree/test-report
|
82 |
+
run: |
|
83 |
+
sudo rm -rf /usr/share/dotnet
|
84 |
+
sudo rm -rf /opt/ghc
|
85 |
+
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
86 |
+
|
87 |
+
- name: Set up swap space
|
88 |
+
if: runner.os == 'Linux'
|
89 |
+
uses: pierotofy/[email protected]
|
90 |
+
with:
|
91 |
+
swap-size-gb: 10
|
92 |
+
|
93 |
+
- name: Install CUDA ${{ matrix.cuda-version }}
|
94 |
+
if: ${{ matrix.cuda-version != 'cpu' }}
|
95 |
+
uses: Jimver/[email protected]
|
96 |
+
id: cuda-toolkit
|
97 |
+
with:
|
98 |
+
cuda: ${{ matrix.cuda-version }}
|
99 |
+
linux-local-args: '["--toolkit"]'
|
100 |
+
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
|
101 |
+
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
|
102 |
+
method: 'network'
|
103 |
+
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
|
104 |
+
# not just nvcc
|
105 |
+
# sub-packages: '["nvcc"]'
|
106 |
+
|
107 |
+
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
|
108 |
+
run: |
|
109 |
+
pip install --upgrade pip
|
110 |
+
# If we don't install before installing Pytorch, we get error for torch 2.0.1
|
111 |
+
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
|
112 |
+
pip install lit
|
113 |
+
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
|
114 |
+
pip install setuptools
|
115 |
+
# We want to figure out the CUDA version to download pytorch
|
116 |
+
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
|
117 |
+
# This code is ugly, maybe there's a better way to do this.
|
118 |
+
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
|
119 |
+
minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
|
120 |
+
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124}[env['MATRIX_TORCH_VERSION']]; \
|
121 |
+
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
|
122 |
+
)
|
123 |
+
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
|
124 |
+
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
|
125 |
+
else
|
126 |
+
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
|
127 |
+
fi
|
128 |
+
nvcc --version
|
129 |
+
python --version
|
130 |
+
python -c "import torch; print('PyTorch:', torch.__version__)"
|
131 |
+
python -c "import torch; print('CUDA:', torch.version.cuda)"
|
132 |
+
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
|
133 |
+
shell:
|
134 |
+
bash
|
135 |
+
|
136 |
+
- name: Build wheel
|
137 |
+
run: |
|
138 |
+
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
|
139 |
+
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
|
140 |
+
# However this still fails so I'm using a newer version of setuptools
|
141 |
+
pip install setuptools==68.0.0
|
142 |
+
pip install ninja packaging wheel
|
143 |
+
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
|
144 |
+
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
145 |
+
# Limit MAX_JOBS otherwise the github runner goes OOM
|
146 |
+
MAX_JOBS=2 MAMBA_FORCE_BUILD="TRUE" MAMBA_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
|
147 |
+
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
|
148 |
+
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
|
149 |
+
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
|
150 |
+
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
|
151 |
+
|
152 |
+
- name: Log Built Wheels
|
153 |
+
run: |
|
154 |
+
ls dist
|
155 |
+
|
156 |
+
- name: Get the tag version
|
157 |
+
id: extract_branch
|
158 |
+
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
|
159 |
+
|
160 |
+
- name: Get Release with tag
|
161 |
+
id: get_current_release
|
162 |
+
uses: joutvhu/get-release@v1
|
163 |
+
with:
|
164 |
+
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
165 |
+
env:
|
166 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
167 |
+
|
168 |
+
- name: Upload Release Asset
|
169 |
+
id: upload_release_asset
|
170 |
+
uses: actions/upload-release-asset@v1
|
171 |
+
env:
|
172 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
173 |
+
with:
|
174 |
+
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
|
175 |
+
asset_path: ./dist/${{env.wheel_name}}
|
176 |
+
asset_name: ${{env.wheel_name}}
|
177 |
+
asset_content_type: application/*
|
178 |
+
|
179 |
+
publish_package:
|
180 |
+
name: Publish package
|
181 |
+
needs: [build_wheels]
|
182 |
+
|
183 |
+
runs-on: ubuntu-latest
|
184 |
+
|
185 |
+
steps:
|
186 |
+
- uses: actions/checkout@v3
|
187 |
+
|
188 |
+
- uses: actions/setup-python@v4
|
189 |
+
with:
|
190 |
+
python-version: '3.10'
|
191 |
+
|
192 |
+
- name: Install dependencies
|
193 |
+
run: |
|
194 |
+
pip install ninja packaging setuptools wheel twine
|
195 |
+
# We don't want to download anything CUDA-related here
|
196 |
+
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
197 |
+
|
198 |
+
- name: Build core package
|
199 |
+
env:
|
200 |
+
MAMBA_SKIP_CUDA_BUILD: "TRUE"
|
201 |
+
run: |
|
202 |
+
python setup.py sdist --dist-dir=dist
|
203 |
+
|
204 |
+
- name: Deploy
|
205 |
+
env:
|
206 |
+
TWINE_USERNAME: "__token__"
|
207 |
+
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
208 |
+
run: |
|
209 |
+
python -m twine upload dist/*
|
mamba/.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*__pycache__/
|
2 |
+
*.egg-info/
|
3 |
+
build/
|
4 |
+
**.so
|
5 |
+
*.hip
|
6 |
+
*_hip.*
|
mamba/.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "3rdparty/lm-evaluation-harness"]
|
2 |
+
path = 3rdparty/lm-evaluation-harness
|
3 |
+
url = https://github.com/EleutherAI/lm-evaluation-harness/
|
mamba/AUTHORS
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
Tri Dao, [email protected]
|
2 |
+
Albert Gu, [email protected]
|
mamba/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2023 Tri Dao, Albert Gu
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
mamba/README.md
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Mamba
|
2 |
+
|
3 |
+
![Mamba](assets/selection.png "Selective State Space")
|
4 |
+
> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
|
5 |
+
> Albert Gu*, Tri Dao*\
|
6 |
+
> Paper: https://arxiv.org/abs/2312.00752
|
7 |
+
|
8 |
+
![Mamba-2](assets/ssd_algorithm.png "State Space Dual Model")
|
9 |
+
> **Transformers are SSMs: Generalized Models and Efficient Algorithms**\
|
10 |
+
> **Through Structured State Space Duality**\
|
11 |
+
> Tri Dao*, Albert Gu*\
|
12 |
+
> Paper: https://arxiv.org/abs/2405.21060
|
13 |
+
|
14 |
+
## About
|
15 |
+
|
16 |
+
Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
|
17 |
+
It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
|
18 |
+
with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
|
19 |
+
|
20 |
+
## Installation
|
21 |
+
|
22 |
+
- [Option] `pip install causal-conv1d>=1.4.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
|
23 |
+
- `pip install mamba-ssm`: the core Mamba package.
|
24 |
+
- `pip install mamba-ssm[causal-conv1d]`: To install core Mamba package and causal-conv1d.
|
25 |
+
- `pip install mamba-ssm[dev]`: To install core Mamba package and dev depdencies.
|
26 |
+
|
27 |
+
It can also be built from source with `pip install .` from this repository.
|
28 |
+
|
29 |
+
If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
|
30 |
+
|
31 |
+
Other requirements:
|
32 |
+
- Linux
|
33 |
+
- NVIDIA GPU
|
34 |
+
- PyTorch 1.12+
|
35 |
+
- CUDA 11.6+
|
36 |
+
|
37 |
+
For AMD cards, see additional prerequisites below.
|
38 |
+
|
39 |
+
## Usage
|
40 |
+
|
41 |
+
We expose several levels of interface with the Mamba model.
|
42 |
+
|
43 |
+
### Selective SSM
|
44 |
+
|
45 |
+
Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
|
46 |
+
|
47 |
+
Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
|
48 |
+
|
49 |
+
### Mamba Block
|
50 |
+
|
51 |
+
The main module of this repository is the Mamba architecture block wrapping the selective SSM.
|
52 |
+
|
53 |
+
Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
|
54 |
+
|
55 |
+
Usage:
|
56 |
+
``` python
|
57 |
+
import torch
|
58 |
+
from mamba_ssm import Mamba
|
59 |
+
|
60 |
+
batch, length, dim = 2, 64, 16
|
61 |
+
x = torch.randn(batch, length, dim).to("cuda")
|
62 |
+
model = Mamba(
|
63 |
+
# This module uses roughly 3 * expand * d_model^2 parameters
|
64 |
+
d_model=dim, # Model dimension d_model
|
65 |
+
d_state=16, # SSM state expansion factor
|
66 |
+
d_conv=4, # Local convolution width
|
67 |
+
expand=2, # Block expansion factor
|
68 |
+
).to("cuda")
|
69 |
+
y = model(x)
|
70 |
+
assert y.shape == x.shape
|
71 |
+
```
|
72 |
+
|
73 |
+
### Mamba-2
|
74 |
+
|
75 |
+
The Mamba-2 block is implemented at [modules/mamba2.py](mamba_ssm/modules/mamba2.py).
|
76 |
+
|
77 |
+
A simpler version is at [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py)
|
78 |
+
|
79 |
+
The usage is similar to Mamba(-1):
|
80 |
+
``` python
|
81 |
+
from mamba_ssm import Mamba2
|
82 |
+
model = Mamba2(
|
83 |
+
# This module uses roughly 3 * expand * d_model^2 parameters
|
84 |
+
d_model=dim, # Model dimension d_model
|
85 |
+
d_state=64, # SSM state expansion factor, typically 64 or 128
|
86 |
+
d_conv=4, # Local convolution width
|
87 |
+
expand=2, # Block expansion factor
|
88 |
+
).to("cuda")
|
89 |
+
y = model(x)
|
90 |
+
assert y.shape == x.shape
|
91 |
+
```
|
92 |
+
|
93 |
+
#### SSD
|
94 |
+
|
95 |
+
A minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between "discrete" and "continuous" SSM versions
|
96 |
+
is at [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py).
|
97 |
+
|
98 |
+
### Mamba Language Model
|
99 |
+
|
100 |
+
Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
|
101 |
+
|
102 |
+
Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
|
103 |
+
|
104 |
+
This is an example of how to integrate Mamba into an end-to-end neural network.
|
105 |
+
This example is used in the generation scripts below.
|
106 |
+
|
107 |
+
|
108 |
+
## Pretrained Models
|
109 |
+
|
110 |
+
Pretrained models are uploaded to
|
111 |
+
[Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
|
112 |
+
`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`,
|
113 |
+
`mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`, `transformerpp-2.7b`, `mamba2attn-2.7b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
|
114 |
+
(trained on 600B tokens on the SlimPajama dataset).
|
115 |
+
|
116 |
+
|
117 |
+
The models will be autodownloaded by the generation script below.
|
118 |
+
|
119 |
+
These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
|
120 |
+
|
121 |
+
| Parameters | Layers | Model dim. |
|
122 |
+
|------------|--------|------------|
|
123 |
+
| 130M | 24 | 768 |
|
124 |
+
| 370M | 48 | 1024 |
|
125 |
+
| 790M | 48 | 1536 |
|
126 |
+
| 1.4B | 48 | 2048 |
|
127 |
+
| 2.8B | 64 | 2560 |
|
128 |
+
|
129 |
+
(The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
|
130 |
+
|
131 |
+
Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
|
132 |
+
Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
|
133 |
+
|
134 |
+
|
135 |
+
## Evaluations
|
136 |
+
|
137 |
+
To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
|
138 |
+
we use the
|
139 |
+
[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
|
140 |
+
library.
|
141 |
+
|
142 |
+
1. Install `lm-evaluation-harness` by `pip install lm-eval==0.4.2`.
|
143 |
+
2. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
|
144 |
+
``` sh
|
145 |
+
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
|
146 |
+
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
|
147 |
+
```
|
148 |
+
|
149 |
+
To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts:
|
150 |
+
``` sh
|
151 |
+
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
|
152 |
+
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
|
153 |
+
```
|
154 |
+
|
155 |
+
To run evaluations on Mamba-2 models, simply replace the model names:
|
156 |
+
``` sh
|
157 |
+
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
|
158 |
+
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
|
159 |
+
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
|
160 |
+
```
|
161 |
+
|
162 |
+
Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
|
163 |
+
|
164 |
+
## Inference
|
165 |
+
|
166 |
+
The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
|
167 |
+
1. autoloads a model from the Hugging Face Hub,
|
168 |
+
2. generates completions of a user-specified prompt,
|
169 |
+
3. benchmarks the inference speed of this generation.
|
170 |
+
|
171 |
+
Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
|
172 |
+
|
173 |
+
### Examples
|
174 |
+
|
175 |
+
To test generation latency (e.g. batch size = 1) with different sampling strategies:
|
176 |
+
|
177 |
+
``` sh
|
178 |
+
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
|
179 |
+
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
|
180 |
+
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
|
181 |
+
```
|
182 |
+
|
183 |
+
To test generation throughput with random prompts (e.g. large batch size):
|
184 |
+
``` sh
|
185 |
+
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64
|
186 |
+
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64
|
187 |
+
```
|
188 |
+
|
189 |
+
With Mamba-2, you just need to change the model name:
|
190 |
+
``` sh
|
191 |
+
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
|
192 |
+
```
|
193 |
+
|
194 |
+
|
195 |
+
## Troubleshooting
|
196 |
+
|
197 |
+
### Precision
|
198 |
+
Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary.
|
199 |
+
On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcasts when necessary (e.g. for optimizer accumulation).
|
200 |
+
|
201 |
+
We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities,
|
202 |
+
as a first step please try a framework storing parameters in fp32 (such as AMP).
|
203 |
+
|
204 |
+
### Initialization
|
205 |
+
Some parts of the model have initializations inherited from prior work on S4 models.
|
206 |
+
For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter has a targeted range by initializing the bias of its linear projection.
|
207 |
+
However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero).
|
208 |
+
If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework)
|
209 |
+
that is specific to the training framework.
|
210 |
+
|
211 |
+
## Additional Prerequisites for AMD cards
|
212 |
+
|
213 |
+
### Patching ROCm
|
214 |
+
|
215 |
+
If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
|
216 |
+
|
217 |
+
1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
|
218 |
+
|
219 |
+
2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
|
220 |
+
```bash
|
221 |
+
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
|
222 |
+
```
|
223 |
+
|
224 |
+
|
225 |
+
## Citation
|
226 |
+
|
227 |
+
If you use this codebase, or otherwise find our work valuable, please cite Mamba:
|
228 |
+
```
|
229 |
+
@article{mamba,
|
230 |
+
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
|
231 |
+
author={Gu, Albert and Dao, Tri},
|
232 |
+
journal={arXiv preprint arXiv:2312.00752},
|
233 |
+
year={2023}
|
234 |
+
}
|
235 |
+
|
236 |
+
@inproceedings{mamba2,
|
237 |
+
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
|
238 |
+
author={Dao, Tri and Gu, Albert},
|
239 |
+
booktitle={International Conference on Machine Learning (ICML)},
|
240 |
+
year={2024}
|
241 |
+
}
|
242 |
+
|
243 |
+
```
|
mamba/assets/selection.png
ADDED
mamba/assets/ssd_algorithm.png
ADDED
Git LFS Details
|
mamba/benchmarks/benchmark_generation_mamba_simple.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import time
|
5 |
+
import json
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from einops import rearrange
|
11 |
+
|
12 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
13 |
+
|
14 |
+
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
15 |
+
|
16 |
+
|
17 |
+
parser = argparse.ArgumentParser(description="Generation benchmarking")
|
18 |
+
parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
|
19 |
+
parser.add_argument("--prompt", type=str, default=None)
|
20 |
+
parser.add_argument("--promptlen", type=int, default=100)
|
21 |
+
parser.add_argument("--genlen", type=int, default=100)
|
22 |
+
parser.add_argument("--temperature", type=float, default=1.0)
|
23 |
+
parser.add_argument("--topk", type=int, default=1)
|
24 |
+
parser.add_argument("--topp", type=float, default=1.0)
|
25 |
+
parser.add_argument("--minp", type=float, default=0.0)
|
26 |
+
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
27 |
+
parser.add_argument("--batch", type=int, default=1)
|
28 |
+
args = parser.parse_args()
|
29 |
+
|
30 |
+
repeats = 3
|
31 |
+
device = "cuda"
|
32 |
+
dtype = torch.float16
|
33 |
+
|
34 |
+
print(f"Loading model {args.model_name}")
|
35 |
+
is_mamba = args.model_name.startswith("state-spaces/mamba") or args.model_name.startswith("state-spaces/transformerpp")
|
36 |
+
if is_mamba:
|
37 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
38 |
+
model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
|
39 |
+
else:
|
40 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
41 |
+
model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
|
42 |
+
model.eval()
|
43 |
+
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
44 |
+
|
45 |
+
torch.random.manual_seed(0)
|
46 |
+
if args.prompt is None:
|
47 |
+
input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
|
48 |
+
attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
|
49 |
+
else:
|
50 |
+
tokens = tokenizer(args.prompt, return_tensors="pt")
|
51 |
+
input_ids = tokens.input_ids.to(device=device)
|
52 |
+
attn_mask = tokens.attention_mask.to(device=device)
|
53 |
+
max_length = input_ids.shape[1] + args.genlen
|
54 |
+
|
55 |
+
if is_mamba:
|
56 |
+
fn = lambda: model.generate(
|
57 |
+
input_ids=input_ids,
|
58 |
+
max_length=max_length,
|
59 |
+
cg=True,
|
60 |
+
return_dict_in_generate=True,
|
61 |
+
output_scores=True,
|
62 |
+
enable_timing=False,
|
63 |
+
temperature=args.temperature,
|
64 |
+
top_k=args.topk,
|
65 |
+
top_p=args.topp,
|
66 |
+
min_p=args.minp,
|
67 |
+
repetition_penalty=args.repetition_penalty,
|
68 |
+
)
|
69 |
+
else:
|
70 |
+
fn = lambda: model.generate(
|
71 |
+
input_ids=input_ids,
|
72 |
+
attention_mask=attn_mask,
|
73 |
+
max_length=max_length,
|
74 |
+
return_dict_in_generate=True,
|
75 |
+
pad_token_id=tokenizer.eos_token_id,
|
76 |
+
do_sample=True,
|
77 |
+
temperature=args.temperature,
|
78 |
+
top_k=args.topk,
|
79 |
+
top_p=args.topp,
|
80 |
+
repetition_penalty=args.repetition_penalty,
|
81 |
+
)
|
82 |
+
out = fn()
|
83 |
+
if args.prompt is not None:
|
84 |
+
print(tokenizer.batch_decode(out.sequences.tolist()))
|
85 |
+
|
86 |
+
torch.cuda.synchronize()
|
87 |
+
start = time.time()
|
88 |
+
for _ in range(repeats):
|
89 |
+
fn()
|
90 |
+
torch.cuda.synchronize()
|
91 |
+
print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
|
92 |
+
print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
|
mamba/build/lib/mamba_ssm/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "2.2.2"
|
2 |
+
|
3 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
4 |
+
from mamba_ssm.modules.mamba_simple import Mamba
|
5 |
+
from mamba_ssm.modules.mamba2 import Mamba2
|
6 |
+
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
mamba/build/lib/mamba_ssm/distributed/__init__.py
ADDED
File without changes
|
mamba/build/lib/mamba_ssm/distributed/distributed_utils.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor
|
5 |
+
from torch.distributed import ProcessGroup
|
6 |
+
|
7 |
+
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
8 |
+
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
|
9 |
+
# version of PyTorch. The following 4 lines are for backward compatibility with
|
10 |
+
# older PyTorch.
|
11 |
+
if "all_gather_into_tensor" not in dir(torch.distributed):
|
12 |
+
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
13 |
+
if "reduce_scatter_tensor" not in dir(torch.distributed):
|
14 |
+
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
|
15 |
+
|
16 |
+
|
17 |
+
# Raw operation, does not support autograd, but does support async
|
18 |
+
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
19 |
+
world_size = torch.distributed.get_world_size(process_group)
|
20 |
+
output = torch.empty(
|
21 |
+
world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
|
22 |
+
)
|
23 |
+
handle = torch.distributed.all_gather_into_tensor(
|
24 |
+
output, input_.contiguous(), group=process_group, async_op=async_op
|
25 |
+
)
|
26 |
+
return output, handle
|
27 |
+
|
28 |
+
|
29 |
+
# Raw operation, does not support autograd, but does support async
|
30 |
+
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
31 |
+
world_size = torch.distributed.get_world_size(process_group)
|
32 |
+
assert input_.shape[0] % world_size == 0
|
33 |
+
output = torch.empty(
|
34 |
+
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
|
35 |
+
)
|
36 |
+
handle = torch.distributed.reduce_scatter_tensor(
|
37 |
+
output, input_.contiguous(), group=process_group, async_op=async_op
|
38 |
+
)
|
39 |
+
return output, handle
|
40 |
+
|
41 |
+
|
42 |
+
# Raw operation, does not support autograd, but does support async
|
43 |
+
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
44 |
+
input_ = input_.contiguous()
|
45 |
+
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
|
46 |
+
return input_, handle
|
47 |
+
|
48 |
+
|
49 |
+
class AllGatherFunc(torch.autograd.Function):
|
50 |
+
"""Gather the input from sequence parallel region and concatenate."""
|
51 |
+
|
52 |
+
@staticmethod
|
53 |
+
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
54 |
+
ctx.process_group = process_group
|
55 |
+
output, _ = all_gather_raw(input_, process_group)
|
56 |
+
return output
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def backward(ctx, grad_output: Tensor):
|
60 |
+
grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
|
61 |
+
return grad_input, None
|
62 |
+
|
63 |
+
|
64 |
+
# Supports autograd, but does not support async
|
65 |
+
all_gather = AllGatherFunc.apply
|
66 |
+
|
67 |
+
|
68 |
+
class ReduceScatterFunc(torch.autograd.Function):
|
69 |
+
"""Reduce scatter the input from the sequence parallel region and concatenate."""
|
70 |
+
|
71 |
+
@staticmethod
|
72 |
+
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
73 |
+
ctx.process_group = process_group
|
74 |
+
output, _ = reduce_scatter_raw(input_, process_group)
|
75 |
+
return output
|
76 |
+
|
77 |
+
@staticmethod
|
78 |
+
def backward(ctx, grad_output: Tensor):
|
79 |
+
grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
|
80 |
+
return grad_input, None
|
81 |
+
|
82 |
+
|
83 |
+
# Supports autograd, but does not support async
|
84 |
+
reduce_scatter = ReduceScatterFunc.apply
|
85 |
+
|
86 |
+
|
87 |
+
class AllReduceFunc(torch.autograd.Function):
|
88 |
+
"""Gather the input from sequence parallel region and concatenate."""
|
89 |
+
|
90 |
+
@staticmethod
|
91 |
+
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
92 |
+
ctx.process_group = process_group
|
93 |
+
output, _ = all_reduce_raw(input_, process_group)
|
94 |
+
return output
|
95 |
+
|
96 |
+
@staticmethod
|
97 |
+
def backward(ctx, grad_output: Tensor):
|
98 |
+
return grad_output, None
|
99 |
+
|
100 |
+
|
101 |
+
# Supports autograd, but does not support async
|
102 |
+
all_reduce = AllReduceFunc.apply
|
103 |
+
|
104 |
+
|
105 |
+
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
|
106 |
+
# We want to iterate over parameters with _shared_params=True in the same order,
|
107 |
+
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
|
108 |
+
pamams_shared = {
|
109 |
+
name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
|
110 |
+
}
|
111 |
+
for _, p in sorted(pamams_shared.items()):
|
112 |
+
with torch.no_grad():
|
113 |
+
# Broadcast needs src to be global rank, not group rank
|
114 |
+
torch.distributed.broadcast(
|
115 |
+
p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
|
120 |
+
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
|
121 |
+
# We want to iterate over parameters with _sequence_parallel=True in the same order,
|
122 |
+
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
|
123 |
+
params_seqparallel = {
|
124 |
+
name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
|
125 |
+
}
|
126 |
+
grads = [p.grad for _, p in sorted(params_seqparallel.items())]
|
127 |
+
if grads:
|
128 |
+
with torch.no_grad():
|
129 |
+
coalesced = torch._utils._flatten_dense_tensors(grads)
|
130 |
+
torch.distributed.all_reduce(coalesced, group=process_group)
|
131 |
+
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
|
132 |
+
buf.copy_(synced)
|
133 |
+
|
134 |
+
|
135 |
+
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
|
136 |
+
"""Get the dim for the local rank derived from splitting dim on world_size processes.
|
137 |
+
|
138 |
+
The split may not be even across the world_size processes.
|
139 |
+
"""
|
140 |
+
multiple = dim // multiple_of
|
141 |
+
div = multiple // world_size
|
142 |
+
mod = multiple % world_size
|
143 |
+
local_multiple = div + int(local_rank < mod)
|
144 |
+
return local_multiple * multiple_of
|
mamba/build/lib/mamba_ssm/distributed/tensor_parallel.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import Tensor
|
9 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
10 |
+
from torch.distributed import ProcessGroup
|
11 |
+
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
from mamba_ssm.distributed.distributed_utils import (
|
15 |
+
all_gather_raw,
|
16 |
+
all_reduce,
|
17 |
+
all_reduce_raw,
|
18 |
+
reduce_scatter,
|
19 |
+
reduce_scatter_raw,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
class ParallelLinearFunc(torch.autograd.Function):
|
24 |
+
@staticmethod
|
25 |
+
@custom_fwd
|
26 |
+
def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
|
27 |
+
"""
|
28 |
+
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
|
29 |
+
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
|
30 |
+
"""
|
31 |
+
ctx.compute_weight_gradient = weight.requires_grad
|
32 |
+
ctx.process_group = process_group
|
33 |
+
ctx.sequence_parallel = sequence_parallel
|
34 |
+
|
35 |
+
if torch.is_autocast_enabled():
|
36 |
+
x = x.to(dtype=torch.get_autocast_gpu_dtype())
|
37 |
+
x = x.contiguous()
|
38 |
+
if process_group is not None and sequence_parallel:
|
39 |
+
# We want to kick off the all_gather early, before weight dtype conversion
|
40 |
+
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
41 |
+
else:
|
42 |
+
total_x = x
|
43 |
+
|
44 |
+
if torch.is_autocast_enabled():
|
45 |
+
weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
|
46 |
+
bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
|
47 |
+
weight = weight.contiguous()
|
48 |
+
if process_group is not None and sequence_parallel:
|
49 |
+
handle_x.wait()
|
50 |
+
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
|
51 |
+
batch_dim = batch_shape.numel()
|
52 |
+
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
|
53 |
+
output = F.linear(total_x, weight, bias)
|
54 |
+
if ctx.compute_weight_gradient:
|
55 |
+
ctx.save_for_backward(x, weight)
|
56 |
+
else:
|
57 |
+
ctx.save_for_backward(weight)
|
58 |
+
return output
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
@custom_bwd
|
62 |
+
def backward(ctx, grad_output):
|
63 |
+
grad_output = grad_output.contiguous()
|
64 |
+
process_group = ctx.process_group
|
65 |
+
sequence_parallel = ctx.sequence_parallel
|
66 |
+
if ctx.compute_weight_gradient:
|
67 |
+
x, weight = ctx.saved_tensors
|
68 |
+
if process_group is not None and sequence_parallel:
|
69 |
+
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
70 |
+
else:
|
71 |
+
total_x = x
|
72 |
+
else:
|
73 |
+
(weight,) = ctx.saved_tensors
|
74 |
+
total_x = None
|
75 |
+
batch_shape = grad_output.shape[:-1]
|
76 |
+
batch_dim = batch_shape.numel()
|
77 |
+
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
78 |
+
if ctx.needs_input_grad[0]:
|
79 |
+
grad_input = F.linear(grad_output, weight.t())
|
80 |
+
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
|
81 |
+
if process_group is not None:
|
82 |
+
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
|
83 |
+
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
|
84 |
+
else:
|
85 |
+
grad_input = None
|
86 |
+
if ctx.needs_input_grad[1]:
|
87 |
+
assert ctx.compute_weight_gradient
|
88 |
+
if process_group is not None and sequence_parallel:
|
89 |
+
handle_x.wait()
|
90 |
+
grad_weight = torch.einsum(
|
91 |
+
"bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
grad_weight = None
|
95 |
+
grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
|
96 |
+
if process_group is not None and ctx.needs_input_grad[0]:
|
97 |
+
handle_grad_input.wait()
|
98 |
+
return grad_input, grad_weight, grad_bias, None, None
|
99 |
+
|
100 |
+
|
101 |
+
def parallel_linear_func(
|
102 |
+
x: Tensor,
|
103 |
+
weight: Tensor,
|
104 |
+
bias: Optional[Tensor] = None,
|
105 |
+
process_group: Optional[ProcessGroup] = None,
|
106 |
+
sequence_parallel: bool = True,
|
107 |
+
):
|
108 |
+
return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
|
109 |
+
|
110 |
+
|
111 |
+
class ColumnParallelLinear(nn.Linear):
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
in_features: int,
|
115 |
+
out_features: int,
|
116 |
+
process_group: ProcessGroup,
|
117 |
+
bias: bool = True,
|
118 |
+
sequence_parallel=True,
|
119 |
+
multiple_of=1,
|
120 |
+
device=None,
|
121 |
+
dtype=None,
|
122 |
+
) -> None:
|
123 |
+
world_size = torch.distributed.get_world_size(process_group)
|
124 |
+
if out_features % multiple_of:
|
125 |
+
raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
|
126 |
+
multiple = out_features // multiple_of
|
127 |
+
# We want to split @multiple across world_size, but it could be an uneven split
|
128 |
+
div = multiple // world_size
|
129 |
+
mod = multiple % world_size
|
130 |
+
# The first @mod ranks get @div + 1 copies, the rest get @div copies
|
131 |
+
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
|
132 |
+
super().__init__(
|
133 |
+
in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
|
134 |
+
)
|
135 |
+
self.process_group = process_group
|
136 |
+
self.sequence_parallel = sequence_parallel
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
|
140 |
+
# we do an all_gather of x before doing the matmul.
|
141 |
+
# If not, then the input is already gathered.
|
142 |
+
return parallel_linear_func(
|
143 |
+
x,
|
144 |
+
self.weight,
|
145 |
+
self.bias,
|
146 |
+
process_group=self.process_group,
|
147 |
+
sequence_parallel=self.sequence_parallel,
|
148 |
+
)
|
149 |
+
|
150 |
+
|
151 |
+
class RowParallelLinear(nn.Linear):
|
152 |
+
def __init__(
|
153 |
+
self,
|
154 |
+
in_features: int,
|
155 |
+
out_features: int,
|
156 |
+
process_group: ProcessGroup,
|
157 |
+
bias: bool = True,
|
158 |
+
sequence_parallel=True,
|
159 |
+
multiple_of=1,
|
160 |
+
device=None,
|
161 |
+
dtype=None,
|
162 |
+
) -> None:
|
163 |
+
world_size = torch.distributed.get_world_size(process_group)
|
164 |
+
rank = torch.distributed.get_rank(process_group)
|
165 |
+
if in_features % multiple_of:
|
166 |
+
raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
|
167 |
+
multiple = in_features // multiple_of
|
168 |
+
# We want to split @multiple across world_size, but it could be an uneven split
|
169 |
+
div = multiple // world_size
|
170 |
+
mod = multiple % world_size
|
171 |
+
# The first @mod ranks get @div + 1 copies, the rest get @div copies
|
172 |
+
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
|
173 |
+
# Only rank 0 will have bias
|
174 |
+
super().__init__(
|
175 |
+
local_multiple * multiple_of,
|
176 |
+
out_features,
|
177 |
+
bias=bias and rank == 0,
|
178 |
+
device=device,
|
179 |
+
dtype=dtype,
|
180 |
+
)
|
181 |
+
self.process_group = process_group
|
182 |
+
self.sequence_parallel = sequence_parallel
|
183 |
+
|
184 |
+
def forward(self, x):
|
185 |
+
"""
|
186 |
+
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
|
187 |
+
a reduce_scatter of the result.
|
188 |
+
"""
|
189 |
+
out = parallel_linear_func(x, self.weight, self.bias)
|
190 |
+
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
191 |
+
return reduce_fn(out, self.process_group)
|
192 |
+
|
193 |
+
|
194 |
+
class VocabParallelEmbedding(nn.Embedding):
|
195 |
+
def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
|
196 |
+
self.process_group = process_group
|
197 |
+
if process_group is not None:
|
198 |
+
world_size = torch.distributed.get_world_size(process_group)
|
199 |
+
if num_embeddings % world_size != 0:
|
200 |
+
raise ValueError(
|
201 |
+
f"num_embeddings ({num_embeddings}) must be divisible by "
|
202 |
+
f"world_size ({world_size})"
|
203 |
+
)
|
204 |
+
if world_size > 1 and padding_idx is not None:
|
205 |
+
raise RuntimeError("ParallelEmbedding does not support padding_idx")
|
206 |
+
else:
|
207 |
+
world_size = 1
|
208 |
+
super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
|
209 |
+
|
210 |
+
def forward(self, input: Tensor) -> Tensor:
|
211 |
+
if self.process_group is None:
|
212 |
+
return super().forward(input)
|
213 |
+
else:
|
214 |
+
rank = torch.distributed.get_rank(self.process_group)
|
215 |
+
vocab_size = self.num_embeddings
|
216 |
+
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
|
217 |
+
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
218 |
+
input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
|
219 |
+
input = input - vocab_start_index
|
220 |
+
input[input_ids_mask] = 0
|
221 |
+
embeddings = super().forward(input)
|
222 |
+
embeddings[input_ids_mask] = 0.0
|
223 |
+
return embeddings
|
224 |
+
|
225 |
+
|
226 |
+
class ColumnParallelEmbedding(nn.Embedding):
|
227 |
+
def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
|
228 |
+
self.process_group = process_group
|
229 |
+
if process_group is not None:
|
230 |
+
world_size = torch.distributed.get_world_size(process_group)
|
231 |
+
if embedding_dim % world_size != 0:
|
232 |
+
raise ValueError(
|
233 |
+
f"embedding_dim ({embedding_dim}) must be divisible by "
|
234 |
+
f"world_size ({world_size})"
|
235 |
+
)
|
236 |
+
else:
|
237 |
+
world_size = 1
|
238 |
+
super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
|
239 |
+
|
240 |
+
|
241 |
+
class ParallelEmbeddings(nn.Module):
|
242 |
+
def __init__(
|
243 |
+
self,
|
244 |
+
embed_dim,
|
245 |
+
vocab_size,
|
246 |
+
max_position_embeddings,
|
247 |
+
process_group,
|
248 |
+
padding_idx=None,
|
249 |
+
sequence_parallel=True,
|
250 |
+
device=None,
|
251 |
+
dtype=None,
|
252 |
+
):
|
253 |
+
"""
|
254 |
+
If max_position_embeddings <= 0, there's no position embeddings
|
255 |
+
"""
|
256 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
257 |
+
super().__init__()
|
258 |
+
self.process_group = process_group
|
259 |
+
self.sequence_parallel = sequence_parallel
|
260 |
+
self.word_embeddings = VocabParallelEmbedding(
|
261 |
+
vocab_size,
|
262 |
+
embed_dim,
|
263 |
+
padding_idx=padding_idx,
|
264 |
+
process_group=process_group,
|
265 |
+
**factory_kwargs,
|
266 |
+
)
|
267 |
+
self.max_position_embeddings = max_position_embeddings
|
268 |
+
if self.max_position_embeddings > 0:
|
269 |
+
self.position_embeddings = ColumnParallelEmbedding(
|
270 |
+
max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
|
271 |
+
)
|
272 |
+
|
273 |
+
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
|
274 |
+
"""
|
275 |
+
input_ids: (batch, seqlen)
|
276 |
+
position_ids: (batch, seqlen)
|
277 |
+
"""
|
278 |
+
batch_size, seqlen = input_ids.shape
|
279 |
+
world_size = torch.distributed.get_world_size(self.process_group)
|
280 |
+
embeddings = self.word_embeddings(input_ids)
|
281 |
+
if self.max_position_embeddings > 0:
|
282 |
+
if position_ids is None:
|
283 |
+
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
|
284 |
+
position_embeddings = self.position_embeddings(position_ids)
|
285 |
+
if world_size <= 1:
|
286 |
+
embeddings = embeddings + position_embeddings
|
287 |
+
else:
|
288 |
+
partition_dim = self.position_embeddings.embedding_dim
|
289 |
+
rank = torch.distributed.get_rank(self.process_group)
|
290 |
+
embeddings[
|
291 |
+
..., rank * partition_dim : (rank + 1) * partition_dim
|
292 |
+
] += position_embeddings
|
293 |
+
if combine_batch_seqlen_dim:
|
294 |
+
embeddings = rearrange(embeddings, "b s d -> (b s) d")
|
295 |
+
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
296 |
+
return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
|
mamba/build/lib/mamba_ssm/models/__init__.py
ADDED
File without changes
|
mamba/build/lib/mamba_ssm/models/config_mamba.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class MambaConfig:
|
6 |
+
|
7 |
+
d_model: int = 2560
|
8 |
+
d_intermediate: int = 0
|
9 |
+
n_layer: int = 64
|
10 |
+
vocab_size: int = 50277
|
11 |
+
ssm_cfg: dict = field(default_factory=dict)
|
12 |
+
attn_layer_idx: list = field(default_factory=list)
|
13 |
+
attn_cfg: dict = field(default_factory=dict)
|
14 |
+
rms_norm: bool = True
|
15 |
+
residual_in_fp32: bool = True
|
16 |
+
fused_add_norm: bool = True
|
17 |
+
pad_vocab_size_multiple: int = 8
|
18 |
+
tie_embeddings: bool = True
|
mamba/build/lib/mamba_ssm/models/mixer_seq_simple.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
2 |
+
|
3 |
+
import math
|
4 |
+
from functools import partial
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import copy
|
8 |
+
|
9 |
+
from collections import namedtuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
from mamba_ssm.models.config_mamba import MambaConfig
|
15 |
+
from mamba_ssm.modules.mamba_simple import Mamba
|
16 |
+
from mamba_ssm.modules.mamba2 import Mamba2
|
17 |
+
from mamba_ssm.modules.mha import MHA
|
18 |
+
from mamba_ssm.modules.mlp import GatedMLP
|
19 |
+
from mamba_ssm.modules.block import Block
|
20 |
+
from mamba_ssm.utils.generation import GenerationMixin
|
21 |
+
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
22 |
+
|
23 |
+
try:
|
24 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
25 |
+
except ImportError:
|
26 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
27 |
+
|
28 |
+
|
29 |
+
def create_block(
|
30 |
+
d_model,
|
31 |
+
d_intermediate,
|
32 |
+
ssm_cfg=None,
|
33 |
+
attn_layer_idx=None,
|
34 |
+
attn_cfg=None,
|
35 |
+
norm_epsilon=1e-5,
|
36 |
+
rms_norm=False,
|
37 |
+
residual_in_fp32=False,
|
38 |
+
fused_add_norm=False,
|
39 |
+
layer_idx=None,
|
40 |
+
device=None,
|
41 |
+
dtype=None,
|
42 |
+
):
|
43 |
+
if ssm_cfg is None:
|
44 |
+
ssm_cfg = {}
|
45 |
+
if attn_layer_idx is None:
|
46 |
+
attn_layer_idx = []
|
47 |
+
if attn_cfg is None:
|
48 |
+
attn_cfg = {}
|
49 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
50 |
+
if layer_idx not in attn_layer_idx:
|
51 |
+
# Create a copy of the config to modify
|
52 |
+
ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
|
53 |
+
ssm_layer = ssm_cfg.pop("layer", "Mamba1")
|
54 |
+
if ssm_layer not in ["Mamba1", "Mamba2"]:
|
55 |
+
raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
|
56 |
+
mixer_cls = partial(
|
57 |
+
Mamba2 if ssm_layer == "Mamba2" else Mamba,
|
58 |
+
layer_idx=layer_idx,
|
59 |
+
**ssm_cfg,
|
60 |
+
**factory_kwargs
|
61 |
+
)
|
62 |
+
else:
|
63 |
+
mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
|
64 |
+
norm_cls = partial(
|
65 |
+
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
66 |
+
)
|
67 |
+
if d_intermediate == 0:
|
68 |
+
mlp_cls = nn.Identity
|
69 |
+
else:
|
70 |
+
mlp_cls = partial(
|
71 |
+
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
|
72 |
+
)
|
73 |
+
block = Block(
|
74 |
+
d_model,
|
75 |
+
mixer_cls,
|
76 |
+
mlp_cls,
|
77 |
+
norm_cls=norm_cls,
|
78 |
+
fused_add_norm=fused_add_norm,
|
79 |
+
residual_in_fp32=residual_in_fp32,
|
80 |
+
)
|
81 |
+
block.layer_idx = layer_idx
|
82 |
+
return block
|
83 |
+
|
84 |
+
|
85 |
+
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
86 |
+
def _init_weights(
|
87 |
+
module,
|
88 |
+
n_layer,
|
89 |
+
initializer_range=0.02, # Now only used for embedding layer.
|
90 |
+
rescale_prenorm_residual=True,
|
91 |
+
n_residuals_per_layer=1, # Change to 2 if we have MLP
|
92 |
+
):
|
93 |
+
if isinstance(module, nn.Linear):
|
94 |
+
if module.bias is not None:
|
95 |
+
if not getattr(module.bias, "_no_reinit", False):
|
96 |
+
nn.init.zeros_(module.bias)
|
97 |
+
elif isinstance(module, nn.Embedding):
|
98 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
99 |
+
|
100 |
+
if rescale_prenorm_residual:
|
101 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
102 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
103 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
104 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
105 |
+
#
|
106 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
107 |
+
for name, p in module.named_parameters():
|
108 |
+
if name in ["out_proj.weight", "fc2.weight"]:
|
109 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
110 |
+
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
111 |
+
# We need to reinit p since this code could be called multiple times
|
112 |
+
# Having just p *= scale would repeatedly scale it down
|
113 |
+
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
114 |
+
with torch.no_grad():
|
115 |
+
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
116 |
+
|
117 |
+
|
118 |
+
class MixerModel(nn.Module):
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
d_model: int,
|
122 |
+
n_layer: int,
|
123 |
+
d_intermediate: int,
|
124 |
+
vocab_size: int,
|
125 |
+
ssm_cfg=None,
|
126 |
+
attn_layer_idx=None,
|
127 |
+
attn_cfg=None,
|
128 |
+
norm_epsilon: float = 1e-5,
|
129 |
+
rms_norm: bool = False,
|
130 |
+
initializer_cfg=None,
|
131 |
+
fused_add_norm=False,
|
132 |
+
residual_in_fp32=False,
|
133 |
+
device=None,
|
134 |
+
dtype=None,
|
135 |
+
) -> None:
|
136 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
137 |
+
super().__init__()
|
138 |
+
self.residual_in_fp32 = residual_in_fp32
|
139 |
+
|
140 |
+
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
141 |
+
|
142 |
+
# We change the order of residual and layer norm:
|
143 |
+
# Instead of LN -> Attn / MLP -> Add, we do:
|
144 |
+
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
145 |
+
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
146 |
+
# This is for performance reason: we can fuse add + layer_norm.
|
147 |
+
self.fused_add_norm = fused_add_norm
|
148 |
+
if self.fused_add_norm:
|
149 |
+
if layer_norm_fn is None or rms_norm_fn is None:
|
150 |
+
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
151 |
+
|
152 |
+
self.layers = nn.ModuleList(
|
153 |
+
[
|
154 |
+
create_block(
|
155 |
+
d_model,
|
156 |
+
d_intermediate=d_intermediate,
|
157 |
+
ssm_cfg=ssm_cfg,
|
158 |
+
attn_layer_idx=attn_layer_idx,
|
159 |
+
attn_cfg=attn_cfg,
|
160 |
+
norm_epsilon=norm_epsilon,
|
161 |
+
rms_norm=rms_norm,
|
162 |
+
residual_in_fp32=residual_in_fp32,
|
163 |
+
fused_add_norm=fused_add_norm,
|
164 |
+
layer_idx=i,
|
165 |
+
**factory_kwargs,
|
166 |
+
)
|
167 |
+
for i in range(n_layer)
|
168 |
+
]
|
169 |
+
)
|
170 |
+
|
171 |
+
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
172 |
+
d_model, eps=norm_epsilon, **factory_kwargs
|
173 |
+
)
|
174 |
+
|
175 |
+
self.apply(
|
176 |
+
partial(
|
177 |
+
_init_weights,
|
178 |
+
n_layer=n_layer,
|
179 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
180 |
+
n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP
|
181 |
+
)
|
182 |
+
)
|
183 |
+
|
184 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
185 |
+
return {
|
186 |
+
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
187 |
+
for i, layer in enumerate(self.layers)
|
188 |
+
}
|
189 |
+
|
190 |
+
def forward(self, input_ids, inference_params=None, **mixer_kwargs):
|
191 |
+
hidden_states = self.embedding(input_ids)
|
192 |
+
residual = None
|
193 |
+
for layer in self.layers:
|
194 |
+
hidden_states, residual = layer(
|
195 |
+
hidden_states, residual, inference_params=inference_params, **mixer_kwargs
|
196 |
+
)
|
197 |
+
if not self.fused_add_norm:
|
198 |
+
residual = (hidden_states + residual) if residual is not None else hidden_states
|
199 |
+
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
200 |
+
else:
|
201 |
+
# Set prenorm=False here since we don't need the residual
|
202 |
+
hidden_states = layer_norm_fn(
|
203 |
+
hidden_states,
|
204 |
+
self.norm_f.weight,
|
205 |
+
self.norm_f.bias,
|
206 |
+
eps=self.norm_f.eps,
|
207 |
+
residual=residual,
|
208 |
+
prenorm=False,
|
209 |
+
residual_in_fp32=self.residual_in_fp32,
|
210 |
+
is_rms_norm=isinstance(self.norm_f, RMSNorm)
|
211 |
+
)
|
212 |
+
return hidden_states
|
213 |
+
|
214 |
+
|
215 |
+
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
216 |
+
|
217 |
+
def __init__(
|
218 |
+
self,
|
219 |
+
config: MambaConfig,
|
220 |
+
initializer_cfg=None,
|
221 |
+
device=None,
|
222 |
+
dtype=None,
|
223 |
+
) -> None:
|
224 |
+
self.config = config
|
225 |
+
d_model = config.d_model
|
226 |
+
n_layer = config.n_layer
|
227 |
+
d_intermediate = config.d_intermediate
|
228 |
+
vocab_size = config.vocab_size
|
229 |
+
ssm_cfg = config.ssm_cfg
|
230 |
+
attn_layer_idx = config.attn_layer_idx
|
231 |
+
attn_cfg = config.attn_cfg
|
232 |
+
rms_norm = config.rms_norm
|
233 |
+
residual_in_fp32 = config.residual_in_fp32
|
234 |
+
fused_add_norm = config.fused_add_norm
|
235 |
+
pad_vocab_size_multiple = config.pad_vocab_size_multiple
|
236 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
237 |
+
|
238 |
+
super().__init__()
|
239 |
+
if vocab_size % pad_vocab_size_multiple != 0:
|
240 |
+
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
241 |
+
self.backbone = MixerModel(
|
242 |
+
d_model=d_model,
|
243 |
+
n_layer=n_layer,
|
244 |
+
d_intermediate=d_intermediate,
|
245 |
+
vocab_size=vocab_size,
|
246 |
+
ssm_cfg=ssm_cfg,
|
247 |
+
attn_layer_idx=attn_layer_idx,
|
248 |
+
attn_cfg=attn_cfg,
|
249 |
+
rms_norm=rms_norm,
|
250 |
+
initializer_cfg=initializer_cfg,
|
251 |
+
fused_add_norm=fused_add_norm,
|
252 |
+
residual_in_fp32=residual_in_fp32,
|
253 |
+
**factory_kwargs,
|
254 |
+
)
|
255 |
+
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
256 |
+
|
257 |
+
# Initialize weights and apply final processing
|
258 |
+
self.apply(
|
259 |
+
partial(
|
260 |
+
_init_weights,
|
261 |
+
n_layer=n_layer,
|
262 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
263 |
+
)
|
264 |
+
)
|
265 |
+
self.tie_weights()
|
266 |
+
|
267 |
+
def tie_weights(self):
|
268 |
+
if self.config.tie_embeddings:
|
269 |
+
self.lm_head.weight = self.backbone.embedding.weight
|
270 |
+
|
271 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
272 |
+
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
273 |
+
|
274 |
+
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
|
275 |
+
"""
|
276 |
+
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
277 |
+
num_last_tokens: if > 0, only return the logits for the last n tokens
|
278 |
+
"""
|
279 |
+
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
|
280 |
+
if num_last_tokens > 0:
|
281 |
+
hidden_states = hidden_states[:, -num_last_tokens:]
|
282 |
+
lm_logits = self.lm_head(hidden_states)
|
283 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
284 |
+
return CausalLMOutput(logits=lm_logits)
|
285 |
+
|
286 |
+
@classmethod
|
287 |
+
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
288 |
+
config_data = load_config_hf(pretrained_model_name)
|
289 |
+
config = MambaConfig(**config_data)
|
290 |
+
model = cls(config, device=device, dtype=dtype, **kwargs)
|
291 |
+
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
|
292 |
+
return model
|
293 |
+
|
294 |
+
def save_pretrained(self, save_directory):
|
295 |
+
"""
|
296 |
+
Minimal implementation of save_pretrained for MambaLMHeadModel.
|
297 |
+
Save the model and its configuration file to a directory.
|
298 |
+
"""
|
299 |
+
# Ensure save_directory exists
|
300 |
+
os.makedirs(save_directory, exist_ok=True)
|
301 |
+
|
302 |
+
# Save the model's state_dict
|
303 |
+
model_path = os.path.join(save_directory, 'pytorch_model.bin')
|
304 |
+
torch.save(self.state_dict(), model_path)
|
305 |
+
|
306 |
+
# Save the configuration of the model
|
307 |
+
config_path = os.path.join(save_directory, 'config.json')
|
308 |
+
with open(config_path, 'w') as f:
|
309 |
+
json.dump(self.config.__dict__, f, indent=4)
|
mamba/build/lib/mamba_ssm/modules/__init__.py
ADDED
File without changes
|
mamba/build/lib/mamba_ssm/modules/block.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn, Tensor
|
6 |
+
|
7 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
|
8 |
+
|
9 |
+
|
10 |
+
class Block(nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
|
13 |
+
):
|
14 |
+
"""
|
15 |
+
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
|
16 |
+
|
17 |
+
This Block has a slightly different structure compared to a regular
|
18 |
+
prenorm Transformer block.
|
19 |
+
The standard block is: LN -> MHA/MLP -> Add.
|
20 |
+
[Ref: https://arxiv.org/abs/2002.04745]
|
21 |
+
Here we have: Add -> LN -> Mixer, returning both
|
22 |
+
the hidden_states (output of the mixer) and the residual.
|
23 |
+
This is purely for performance reasons, as we can fuse add and LayerNorm.
|
24 |
+
The residual needs to be provided (except for the very first block).
|
25 |
+
"""
|
26 |
+
super().__init__()
|
27 |
+
self.residual_in_fp32 = residual_in_fp32
|
28 |
+
self.fused_add_norm = fused_add_norm
|
29 |
+
self.norm = norm_cls(dim)
|
30 |
+
self.mixer = mixer_cls(dim)
|
31 |
+
if mlp_cls is not nn.Identity:
|
32 |
+
self.norm2 = norm_cls(dim)
|
33 |
+
self.mlp = mlp_cls(dim)
|
34 |
+
else:
|
35 |
+
self.mlp = None
|
36 |
+
if self.fused_add_norm:
|
37 |
+
assert RMSNorm is not None, "RMSNorm import fails"
|
38 |
+
assert isinstance(
|
39 |
+
self.norm, (nn.LayerNorm, RMSNorm)
|
40 |
+
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
|
41 |
+
|
42 |
+
def forward(
|
43 |
+
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
|
44 |
+
):
|
45 |
+
r"""Pass the input through the encoder layer.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
hidden_states: the sequence to the encoder layer (required).
|
49 |
+
residual: hidden_states = Mixer(LN(residual))
|
50 |
+
"""
|
51 |
+
if not self.fused_add_norm:
|
52 |
+
residual = (hidden_states + residual) if residual is not None else hidden_states
|
53 |
+
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
|
54 |
+
if self.residual_in_fp32:
|
55 |
+
residual = residual.to(torch.float32)
|
56 |
+
else:
|
57 |
+
hidden_states, residual = layer_norm_fn(
|
58 |
+
hidden_states,
|
59 |
+
self.norm.weight,
|
60 |
+
self.norm.bias,
|
61 |
+
residual=residual,
|
62 |
+
prenorm=True,
|
63 |
+
residual_in_fp32=self.residual_in_fp32,
|
64 |
+
eps=self.norm.eps,
|
65 |
+
is_rms_norm=isinstance(self.norm, RMSNorm)
|
66 |
+
)
|
67 |
+
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
|
68 |
+
|
69 |
+
if self.mlp is not None:
|
70 |
+
if not self.fused_add_norm:
|
71 |
+
residual = hidden_states + residual
|
72 |
+
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
73 |
+
if self.residual_in_fp32:
|
74 |
+
residual = residual.to(torch.float32)
|
75 |
+
else:
|
76 |
+
hidden_states, residual = layer_norm_fn(
|
77 |
+
hidden_states,
|
78 |
+
self.norm2.weight,
|
79 |
+
self.norm2.bias,
|
80 |
+
residual=residual,
|
81 |
+
prenorm=True,
|
82 |
+
residual_in_fp32=self.residual_in_fp32,
|
83 |
+
eps=self.norm2.eps,
|
84 |
+
is_rms_norm=isinstance(self.norm2, RMSNorm)
|
85 |
+
)
|
86 |
+
hidden_states = self.mlp(hidden_states)
|
87 |
+
|
88 |
+
return hidden_states, residual
|
89 |
+
|
90 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
91 |
+
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
mamba/build/lib/mamba_ssm/modules/mamba2.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
try:
|
12 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
13 |
+
except ImportError:
|
14 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
15 |
+
|
16 |
+
try:
|
17 |
+
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
|
18 |
+
except ImportError:
|
19 |
+
causal_conv1d_varlen_states = None
|
20 |
+
|
21 |
+
try:
|
22 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
23 |
+
except ImportError:
|
24 |
+
selective_state_update = None
|
25 |
+
|
26 |
+
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
|
27 |
+
|
28 |
+
from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
|
29 |
+
from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter
|
30 |
+
|
31 |
+
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
32 |
+
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
33 |
+
|
34 |
+
from huggingface_hub import PyTorchModelHubMixin
|
35 |
+
|
36 |
+
|
37 |
+
class Mamba2(nn.Module, PyTorchModelHubMixin):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
d_model,
|
41 |
+
d_state=128,
|
42 |
+
d_conv=4,
|
43 |
+
conv_init=None,
|
44 |
+
expand=2,
|
45 |
+
headdim=64,
|
46 |
+
d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
|
47 |
+
ngroups=1,
|
48 |
+
A_init_range=(1, 16),
|
49 |
+
D_has_hdim=False,
|
50 |
+
rmsnorm=True,
|
51 |
+
norm_before_gate=False,
|
52 |
+
dt_min=0.001,
|
53 |
+
dt_max=0.1,
|
54 |
+
dt_init_floor=1e-4,
|
55 |
+
dt_limit=(0.0, float("inf")),
|
56 |
+
bias=False,
|
57 |
+
conv_bias=True,
|
58 |
+
# Fused kernel and sharding options
|
59 |
+
chunk_size=256,
|
60 |
+
use_mem_eff_path=True,
|
61 |
+
layer_idx=None, # Absorb kwarg for general module
|
62 |
+
process_group=None,
|
63 |
+
sequence_parallel=True,
|
64 |
+
device=None,
|
65 |
+
dtype=None,
|
66 |
+
):
|
67 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
68 |
+
super().__init__()
|
69 |
+
self.d_model = d_model
|
70 |
+
self.d_state = d_state
|
71 |
+
self.d_conv = d_conv
|
72 |
+
self.conv_init = conv_init
|
73 |
+
self.expand = expand
|
74 |
+
self.process_group = process_group
|
75 |
+
self.sequence_parallel = sequence_parallel
|
76 |
+
self.world_size = 1 if process_group is None else process_group.size()
|
77 |
+
self.local_rank = 0 if process_group is None else process_group.rank()
|
78 |
+
self.d_inner = (self.expand * self.d_model) // self.world_size
|
79 |
+
assert self.d_inner * self.world_size == self.expand * self.d_model
|
80 |
+
self.headdim = headdim
|
81 |
+
self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
|
82 |
+
assert ngroups % self.world_size == 0
|
83 |
+
self.ngroups = ngroups // self.world_size
|
84 |
+
assert self.d_ssm % self.headdim == 0
|
85 |
+
self.nheads = self.d_ssm // self.headdim
|
86 |
+
self.D_has_hdim = D_has_hdim
|
87 |
+
self.rmsnorm = rmsnorm
|
88 |
+
self.norm_before_gate = norm_before_gate
|
89 |
+
self.dt_limit = dt_limit
|
90 |
+
self.activation = "silu"
|
91 |
+
self.chunk_size = chunk_size
|
92 |
+
self.use_mem_eff_path = use_mem_eff_path
|
93 |
+
self.layer_idx = layer_idx
|
94 |
+
|
95 |
+
# Order: [z, x, B, C, dt]
|
96 |
+
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
97 |
+
if self.process_group is None:
|
98 |
+
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
|
99 |
+
else:
|
100 |
+
self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
|
101 |
+
process_group=self.process_group, sequence_parallel=self.sequence_parallel,
|
102 |
+
**factory_kwargs)
|
103 |
+
|
104 |
+
conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
|
105 |
+
self.conv1d = nn.Conv1d(
|
106 |
+
in_channels=conv_dim,
|
107 |
+
out_channels=conv_dim,
|
108 |
+
bias=conv_bias,
|
109 |
+
kernel_size=d_conv,
|
110 |
+
groups=conv_dim,
|
111 |
+
padding=d_conv - 1,
|
112 |
+
**factory_kwargs,
|
113 |
+
)
|
114 |
+
if self.conv_init is not None:
|
115 |
+
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
|
116 |
+
|
117 |
+
self.act = nn.SiLU()
|
118 |
+
|
119 |
+
# Initialize log dt bias
|
120 |
+
dt = torch.exp(
|
121 |
+
torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
122 |
+
+ math.log(dt_min)
|
123 |
+
)
|
124 |
+
dt = torch.clamp(dt, min=dt_init_floor)
|
125 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
126 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
127 |
+
self.dt_bias = nn.Parameter(inv_dt)
|
128 |
+
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
|
129 |
+
# name.endswith("bias") in param_grouping.py
|
130 |
+
self.dt_bias._no_weight_decay = True
|
131 |
+
|
132 |
+
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
133 |
+
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
|
134 |
+
A_log = torch.log(A).to(dtype=dtype)
|
135 |
+
self.A_log = nn.Parameter(A_log)
|
136 |
+
self.A_log._no_weight_decay = True
|
137 |
+
|
138 |
+
# D "skip" parameter
|
139 |
+
self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
|
140 |
+
self.D._no_weight_decay = True
|
141 |
+
|
142 |
+
if self.rmsnorm:
|
143 |
+
assert RMSNormGated is not None
|
144 |
+
self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
|
145 |
+
group_size=self.d_ssm // ngroups, **factory_kwargs)
|
146 |
+
|
147 |
+
if self.process_group is None:
|
148 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
149 |
+
else:
|
150 |
+
self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
|
151 |
+
process_group=self.process_group, sequence_parallel=self.sequence_parallel,
|
152 |
+
**factory_kwargs)
|
153 |
+
|
154 |
+
def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
|
155 |
+
"""
|
156 |
+
u: (batch, seqlen, hidden_dim) if seqlen=None.
|
157 |
+
If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
|
158 |
+
split u during sequence parallel, we split the batch * seqlen dimension
|
159 |
+
(in case batch is small).
|
160 |
+
Returns: same shape as u
|
161 |
+
"""
|
162 |
+
seqlen_og = seqlen
|
163 |
+
if seqlen is None:
|
164 |
+
batch, seqlen, dim = u.shape
|
165 |
+
else:
|
166 |
+
batch_seqlen, dim = u.shape
|
167 |
+
batch = batch_seqlen // seqlen
|
168 |
+
|
169 |
+
conv_state, ssm_state = None, None
|
170 |
+
if inference_params is not None:
|
171 |
+
inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
|
172 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
|
173 |
+
if inference_params.seqlen_offset > 0:
|
174 |
+
# The states are updated inplace
|
175 |
+
out, _, _ = self.step(u, conv_state, ssm_state)
|
176 |
+
return out
|
177 |
+
|
178 |
+
zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
|
179 |
+
if seqlen_og is not None:
|
180 |
+
zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
|
181 |
+
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
182 |
+
A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
|
183 |
+
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
|
184 |
+
if self.use_mem_eff_path and inference_params is None:
|
185 |
+
out = mamba_split_conv1d_scan_combined(
|
186 |
+
zxbcdt,
|
187 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
188 |
+
self.conv1d.bias,
|
189 |
+
self.dt_bias,
|
190 |
+
A,
|
191 |
+
D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
|
192 |
+
chunk_size=self.chunk_size,
|
193 |
+
seq_idx=seq_idx,
|
194 |
+
activation=self.activation,
|
195 |
+
rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
|
196 |
+
rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
|
197 |
+
outproj_weight=self.out_proj.weight,
|
198 |
+
outproj_bias=self.out_proj.bias,
|
199 |
+
headdim=None if self.D_has_hdim else self.headdim,
|
200 |
+
ngroups=self.ngroups,
|
201 |
+
norm_before_gate=self.norm_before_gate,
|
202 |
+
**dt_limit_kwargs,
|
203 |
+
)
|
204 |
+
if seqlen_og is not None:
|
205 |
+
out = rearrange(out, "b l d -> (b l) d")
|
206 |
+
if self.process_group is not None:
|
207 |
+
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
208 |
+
out = reduce_fn(out, self.process_group)
|
209 |
+
else:
|
210 |
+
d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
|
211 |
+
z0, x0, z, xBC, dt = torch.split(
|
212 |
+
zxbcdt,
|
213 |
+
[d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
|
214 |
+
dim=-1
|
215 |
+
)
|
216 |
+
if conv_state is not None:
|
217 |
+
if cu_seqlens is None:
|
218 |
+
# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
219 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
220 |
+
xBC_t = rearrange(xBC, "b l d -> b d l")
|
221 |
+
conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
|
222 |
+
else:
|
223 |
+
assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
|
224 |
+
assert batch == 1, "varlen inference only supports batch dimension 1"
|
225 |
+
conv_varlen_states = causal_conv1d_varlen_states(
|
226 |
+
xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
|
227 |
+
)
|
228 |
+
conv_state.copy_(conv_varlen_states)
|
229 |
+
assert self.activation in ["silu", "swish"]
|
230 |
+
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
231 |
+
assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
|
232 |
+
xBC = self.act(
|
233 |
+
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):]
|
234 |
+
) # (B, L, self.d_ssm + 2 * ngroups * d_state)
|
235 |
+
else:
|
236 |
+
xBC = causal_conv1d_fn(
|
237 |
+
xBC.transpose(1, 2),
|
238 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
239 |
+
bias=self.conv1d.bias,
|
240 |
+
activation=self.activation,
|
241 |
+
seq_idx=seq_idx,
|
242 |
+
).transpose(1, 2)
|
243 |
+
x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
|
244 |
+
y = mamba_chunk_scan_combined(
|
245 |
+
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
|
246 |
+
dt,
|
247 |
+
A,
|
248 |
+
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
|
249 |
+
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
|
250 |
+
chunk_size=self.chunk_size,
|
251 |
+
D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
|
252 |
+
z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
|
253 |
+
dt_bias=self.dt_bias,
|
254 |
+
dt_softplus=True,
|
255 |
+
seq_idx=seq_idx,
|
256 |
+
cu_seqlens=cu_seqlens,
|
257 |
+
**dt_limit_kwargs,
|
258 |
+
return_final_states=ssm_state is not None,
|
259 |
+
return_varlen_states=cu_seqlens is not None and inference_params is not None,
|
260 |
+
)
|
261 |
+
if ssm_state is not None:
|
262 |
+
y, last_state, *rest = y
|
263 |
+
if cu_seqlens is None:
|
264 |
+
ssm_state.copy_(last_state)
|
265 |
+
else:
|
266 |
+
varlen_states = rest[0]
|
267 |
+
ssm_state.copy_(varlen_states)
|
268 |
+
y = rearrange(y, "b l h p -> b l (h p)")
|
269 |
+
if self.rmsnorm:
|
270 |
+
y = self.norm(y, z)
|
271 |
+
if d_mlp > 0:
|
272 |
+
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
|
273 |
+
if seqlen_og is not None:
|
274 |
+
y = rearrange(y, "b l d -> (b l) d")
|
275 |
+
out = self.out_proj(y)
|
276 |
+
return out
|
277 |
+
|
278 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
279 |
+
dtype = hidden_states.dtype
|
280 |
+
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
281 |
+
zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
282 |
+
d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
|
283 |
+
z0, x0, z, xBC, dt = torch.split(
|
284 |
+
zxbcdt,
|
285 |
+
[d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
|
286 |
+
dim=-1
|
287 |
+
)
|
288 |
+
|
289 |
+
# Conv step
|
290 |
+
if causal_conv1d_update is None:
|
291 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
292 |
+
conv_state[:, :, -1] = xBC
|
293 |
+
xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
294 |
+
if self.conv1d.bias is not None:
|
295 |
+
xBC = xBC + self.conv1d.bias
|
296 |
+
xBC = self.act(xBC).to(dtype=dtype)
|
297 |
+
else:
|
298 |
+
xBC = causal_conv1d_update(
|
299 |
+
xBC,
|
300 |
+
conv_state,
|
301 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
302 |
+
self.conv1d.bias,
|
303 |
+
self.activation,
|
304 |
+
)
|
305 |
+
|
306 |
+
x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
|
307 |
+
A = -torch.exp(self.A_log.float()) # (nheads,)
|
308 |
+
|
309 |
+
# SSM step
|
310 |
+
if selective_state_update is None:
|
311 |
+
assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
|
312 |
+
# Discretize A and B
|
313 |
+
dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
|
314 |
+
dA = torch.exp(dt * A) # (batch, nheads)
|
315 |
+
x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
|
316 |
+
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
|
317 |
+
ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
|
318 |
+
y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
|
319 |
+
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
|
320 |
+
y = rearrange(y, "b h p -> b (h p)")
|
321 |
+
if not self.rmsnorm:
|
322 |
+
y = y * self.act(z) # (B D)
|
323 |
+
else:
|
324 |
+
A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
|
325 |
+
dt = repeat(dt, "b h -> b h p", p=self.headdim)
|
326 |
+
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
|
327 |
+
D = repeat(self.D, "h -> h p", p=self.headdim)
|
328 |
+
B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
|
329 |
+
C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
|
330 |
+
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
|
331 |
+
if not self.rmsnorm:
|
332 |
+
z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
|
333 |
+
y = selective_state_update(
|
334 |
+
ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
|
335 |
+
dt_bias=dt_bias, dt_softplus=True
|
336 |
+
)
|
337 |
+
y = rearrange(y, "b h p -> b (h p)")
|
338 |
+
if self.rmsnorm:
|
339 |
+
y = self.norm(y, z)
|
340 |
+
if d_mlp > 0:
|
341 |
+
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
|
342 |
+
out = self.out_proj(y)
|
343 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
344 |
+
|
345 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
346 |
+
device = self.out_proj.weight.device
|
347 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
348 |
+
conv_state = torch.zeros(
|
349 |
+
batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
|
350 |
+
).transpose(1, 2)
|
351 |
+
ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
|
352 |
+
ssm_state = torch.zeros(
|
353 |
+
batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
|
354 |
+
)
|
355 |
+
return conv_state, ssm_state
|
356 |
+
|
357 |
+
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
358 |
+
assert self.layer_idx is not None
|
359 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
360 |
+
batch_shape = (batch_size,)
|
361 |
+
conv_state = torch.zeros(
|
362 |
+
batch_size,
|
363 |
+
self.d_conv,
|
364 |
+
self.conv1d.weight.shape[0],
|
365 |
+
device=self.conv1d.weight.device,
|
366 |
+
dtype=self.conv1d.weight.dtype,
|
367 |
+
).transpose(1, 2)
|
368 |
+
ssm_state = torch.zeros(
|
369 |
+
batch_size,
|
370 |
+
self.nheads,
|
371 |
+
self.headdim,
|
372 |
+
self.d_state,
|
373 |
+
device=self.in_proj.weight.device,
|
374 |
+
dtype=self.in_proj.weight.dtype,
|
375 |
+
)
|
376 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
377 |
+
else:
|
378 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
379 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
380 |
+
if initialize_states:
|
381 |
+
conv_state.zero_()
|
382 |
+
ssm_state.zero_()
|
383 |
+
return conv_state, ssm_state
|
mamba/build/lib/mamba_ssm/modules/mamba2_simple.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
|
10 |
+
try:
|
11 |
+
from causal_conv1d import causal_conv1d_fn
|
12 |
+
except ImportError:
|
13 |
+
causal_conv1d_fn = None
|
14 |
+
|
15 |
+
try:
|
16 |
+
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
|
17 |
+
except ImportError:
|
18 |
+
RMSNormGated, LayerNorm = None, None
|
19 |
+
|
20 |
+
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
21 |
+
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
22 |
+
|
23 |
+
|
24 |
+
class Mamba2Simple(nn.Module):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
d_model,
|
28 |
+
d_state=64,
|
29 |
+
d_conv=4,
|
30 |
+
conv_init=None,
|
31 |
+
expand=2,
|
32 |
+
headdim=128,
|
33 |
+
ngroups=1,
|
34 |
+
A_init_range=(1, 16),
|
35 |
+
dt_min=0.001,
|
36 |
+
dt_max=0.1,
|
37 |
+
dt_init_floor=1e-4,
|
38 |
+
dt_limit=(0.0, float("inf")),
|
39 |
+
learnable_init_states=False,
|
40 |
+
activation="swish",
|
41 |
+
bias=False,
|
42 |
+
conv_bias=True,
|
43 |
+
# Fused kernel and sharding options
|
44 |
+
chunk_size=256,
|
45 |
+
use_mem_eff_path=True,
|
46 |
+
layer_idx=None, # Absorb kwarg for general module
|
47 |
+
device=None,
|
48 |
+
dtype=None,
|
49 |
+
):
|
50 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
51 |
+
super().__init__()
|
52 |
+
self.d_model = d_model
|
53 |
+
self.d_state = d_state
|
54 |
+
self.d_conv = d_conv
|
55 |
+
self.conv_init = conv_init
|
56 |
+
self.expand = expand
|
57 |
+
self.d_inner = self.expand * self.d_model
|
58 |
+
self.headdim = headdim
|
59 |
+
self.ngroups = ngroups
|
60 |
+
assert self.d_inner % self.headdim == 0
|
61 |
+
self.nheads = self.d_inner // self.headdim
|
62 |
+
self.dt_limit = dt_limit
|
63 |
+
self.learnable_init_states = learnable_init_states
|
64 |
+
self.activation = activation
|
65 |
+
self.chunk_size = chunk_size
|
66 |
+
self.use_mem_eff_path = use_mem_eff_path
|
67 |
+
self.layer_idx = layer_idx
|
68 |
+
|
69 |
+
# Order: [z, x, B, C, dt]
|
70 |
+
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
71 |
+
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
|
72 |
+
|
73 |
+
conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
|
74 |
+
self.conv1d = nn.Conv1d(
|
75 |
+
in_channels=conv_dim,
|
76 |
+
out_channels=conv_dim,
|
77 |
+
bias=conv_bias,
|
78 |
+
kernel_size=d_conv,
|
79 |
+
groups=conv_dim,
|
80 |
+
padding=d_conv - 1,
|
81 |
+
**factory_kwargs,
|
82 |
+
)
|
83 |
+
if self.conv_init is not None:
|
84 |
+
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
|
85 |
+
# self.conv1d.weight._no_weight_decay = True
|
86 |
+
|
87 |
+
if self.learnable_init_states:
|
88 |
+
self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
|
89 |
+
self.init_states._no_weight_decay = True
|
90 |
+
|
91 |
+
self.act = nn.SiLU()
|
92 |
+
|
93 |
+
# Initialize log dt bias
|
94 |
+
dt = torch.exp(
|
95 |
+
torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
96 |
+
+ math.log(dt_min)
|
97 |
+
)
|
98 |
+
dt = torch.clamp(dt, min=dt_init_floor)
|
99 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
100 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
101 |
+
self.dt_bias = nn.Parameter(inv_dt)
|
102 |
+
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
|
103 |
+
# name.endswith("bias") in param_grouping.py
|
104 |
+
self.dt_bias._no_weight_decay = True
|
105 |
+
|
106 |
+
# A parameter
|
107 |
+
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
108 |
+
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
|
109 |
+
A_log = torch.log(A).to(dtype=dtype)
|
110 |
+
self.A_log = nn.Parameter(A_log)
|
111 |
+
# self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
|
112 |
+
self.A_log._no_weight_decay = True
|
113 |
+
|
114 |
+
# D "skip" parameter
|
115 |
+
self.D = nn.Parameter(torch.ones(self.nheads, device=device))
|
116 |
+
self.D._no_weight_decay = True
|
117 |
+
|
118 |
+
# Extra normalization layer right before output projection
|
119 |
+
assert RMSNormGated is not None
|
120 |
+
self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)
|
121 |
+
|
122 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
123 |
+
|
124 |
+
def forward(self, u, seq_idx=None):
|
125 |
+
"""
|
126 |
+
u: (B, L, D)
|
127 |
+
Returns: same shape as u
|
128 |
+
"""
|
129 |
+
batch, seqlen, dim = u.shape
|
130 |
+
|
131 |
+
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
|
132 |
+
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
|
133 |
+
initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
|
134 |
+
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
|
135 |
+
|
136 |
+
if self.use_mem_eff_path:
|
137 |
+
# Fully fused path
|
138 |
+
out = mamba_split_conv1d_scan_combined(
|
139 |
+
zxbcdt,
|
140 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
141 |
+
self.conv1d.bias,
|
142 |
+
self.dt_bias,
|
143 |
+
A,
|
144 |
+
D=self.D,
|
145 |
+
chunk_size=self.chunk_size,
|
146 |
+
seq_idx=seq_idx,
|
147 |
+
activation=self.activation,
|
148 |
+
rmsnorm_weight=self.norm.weight,
|
149 |
+
rmsnorm_eps=self.norm.eps,
|
150 |
+
outproj_weight=self.out_proj.weight,
|
151 |
+
outproj_bias=self.out_proj.bias,
|
152 |
+
headdim=self.headdim,
|
153 |
+
ngroups=self.ngroups,
|
154 |
+
norm_before_gate=False,
|
155 |
+
initial_states=initial_states,
|
156 |
+
**dt_limit_kwargs,
|
157 |
+
)
|
158 |
+
else:
|
159 |
+
z, xBC, dt = torch.split(
|
160 |
+
zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
|
161 |
+
)
|
162 |
+
dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
|
163 |
+
assert self.activation in ["silu", "swish"]
|
164 |
+
|
165 |
+
# 1D Convolution
|
166 |
+
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
167 |
+
xBC = self.act(
|
168 |
+
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
|
169 |
+
) # (B, L, self.d_inner + 2 * ngroups * d_state)
|
170 |
+
xBC = xBC[:, :seqlen, :]
|
171 |
+
else:
|
172 |
+
xBC = causal_conv1d_fn(
|
173 |
+
x=xBC.transpose(1, 2),
|
174 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
175 |
+
bias=self.conv1d.bias,
|
176 |
+
activation=self.activation,
|
177 |
+
).transpose(1, 2)
|
178 |
+
|
179 |
+
# Split into 3 main branches: X, B, C
|
180 |
+
# These correspond to V, K, Q respectively in the SSM/attention duality
|
181 |
+
x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
|
182 |
+
y = mamba_chunk_scan_combined(
|
183 |
+
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
|
184 |
+
dt,
|
185 |
+
A,
|
186 |
+
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
|
187 |
+
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
|
188 |
+
chunk_size=self.chunk_size,
|
189 |
+
D=self.D,
|
190 |
+
z=None,
|
191 |
+
seq_idx=seq_idx,
|
192 |
+
initial_states=initial_states,
|
193 |
+
**dt_limit_kwargs,
|
194 |
+
)
|
195 |
+
y = rearrange(y, "b l h p -> b l (h p)")
|
196 |
+
|
197 |
+
# Multiply "gate" branch and apply extra normalization layer
|
198 |
+
y = self.norm(y, z)
|
199 |
+
out = self.out_proj(y)
|
200 |
+
return out
|
mamba/build/lib/mamba_ssm/modules/mamba_simple.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import Tensor
|
10 |
+
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
|
13 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
14 |
+
|
15 |
+
try:
|
16 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
17 |
+
except ImportError:
|
18 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
19 |
+
|
20 |
+
try:
|
21 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
22 |
+
except ImportError:
|
23 |
+
selective_state_update = None
|
24 |
+
|
25 |
+
try:
|
26 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
27 |
+
except ImportError:
|
28 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
29 |
+
|
30 |
+
|
31 |
+
class Mamba(nn.Module):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
d_model,
|
35 |
+
d_state=16,
|
36 |
+
d_conv=4,
|
37 |
+
expand=2,
|
38 |
+
dt_rank="auto",
|
39 |
+
dt_min=0.001,
|
40 |
+
dt_max=0.1,
|
41 |
+
dt_init="random",
|
42 |
+
dt_scale=1.0,
|
43 |
+
dt_init_floor=1e-4,
|
44 |
+
conv_bias=True,
|
45 |
+
bias=False,
|
46 |
+
use_fast_path=True, # Fused kernel options
|
47 |
+
layer_idx=None,
|
48 |
+
device=None,
|
49 |
+
dtype=None,
|
50 |
+
):
|
51 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
52 |
+
super().__init__()
|
53 |
+
self.d_model = d_model
|
54 |
+
self.d_state = d_state
|
55 |
+
self.d_conv = d_conv
|
56 |
+
self.expand = expand
|
57 |
+
self.d_inner = int(self.expand * self.d_model)
|
58 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
59 |
+
self.use_fast_path = use_fast_path
|
60 |
+
self.layer_idx = layer_idx
|
61 |
+
|
62 |
+
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
|
63 |
+
|
64 |
+
self.conv1d = nn.Conv1d(
|
65 |
+
in_channels=self.d_inner,
|
66 |
+
out_channels=self.d_inner,
|
67 |
+
bias=conv_bias,
|
68 |
+
kernel_size=d_conv,
|
69 |
+
groups=self.d_inner,
|
70 |
+
padding=d_conv - 1,
|
71 |
+
**factory_kwargs,
|
72 |
+
)
|
73 |
+
|
74 |
+
self.activation = "silu"
|
75 |
+
self.act = nn.SiLU()
|
76 |
+
|
77 |
+
self.x_proj = nn.Linear(
|
78 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
79 |
+
)
|
80 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
81 |
+
|
82 |
+
# Initialize special dt projection to preserve variance at initialization
|
83 |
+
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
84 |
+
if dt_init == "constant":
|
85 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
86 |
+
elif dt_init == "random":
|
87 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
88 |
+
else:
|
89 |
+
raise NotImplementedError
|
90 |
+
|
91 |
+
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
92 |
+
dt = torch.exp(
|
93 |
+
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
94 |
+
+ math.log(dt_min)
|
95 |
+
).clamp(min=dt_init_floor)
|
96 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
97 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
98 |
+
with torch.no_grad():
|
99 |
+
self.dt_proj.bias.copy_(inv_dt)
|
100 |
+
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
101 |
+
self.dt_proj.bias._no_reinit = True
|
102 |
+
|
103 |
+
# S4D real initialization
|
104 |
+
A = repeat(
|
105 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
106 |
+
"n -> d n",
|
107 |
+
d=self.d_inner,
|
108 |
+
).contiguous()
|
109 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
110 |
+
self.A_log = nn.Parameter(A_log)
|
111 |
+
self.A_log._no_weight_decay = True
|
112 |
+
|
113 |
+
# D "skip" parameter
|
114 |
+
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
115 |
+
self.D._no_weight_decay = True
|
116 |
+
|
117 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
118 |
+
|
119 |
+
def forward(self, hidden_states, inference_params=None):
|
120 |
+
"""
|
121 |
+
hidden_states: (B, L, D)
|
122 |
+
Returns: same shape as hidden_states
|
123 |
+
"""
|
124 |
+
batch, seqlen, dim = hidden_states.shape
|
125 |
+
|
126 |
+
conv_state, ssm_state = None, None
|
127 |
+
if inference_params is not None:
|
128 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
129 |
+
if inference_params.seqlen_offset > 0:
|
130 |
+
# The states are updated inplace
|
131 |
+
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
132 |
+
return out
|
133 |
+
|
134 |
+
# We do matmul and transpose BLH -> HBL at the same time
|
135 |
+
xz = rearrange(
|
136 |
+
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
137 |
+
"d (b l) -> b d l",
|
138 |
+
l=seqlen,
|
139 |
+
)
|
140 |
+
if self.in_proj.bias is not None:
|
141 |
+
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
142 |
+
|
143 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
144 |
+
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
145 |
+
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
|
146 |
+
out = mamba_inner_fn(
|
147 |
+
xz,
|
148 |
+
self.conv1d.weight,
|
149 |
+
self.conv1d.bias,
|
150 |
+
self.x_proj.weight,
|
151 |
+
self.dt_proj.weight,
|
152 |
+
self.out_proj.weight,
|
153 |
+
self.out_proj.bias,
|
154 |
+
A,
|
155 |
+
None, # input-dependent B
|
156 |
+
None, # input-dependent C
|
157 |
+
self.D.float(),
|
158 |
+
delta_bias=self.dt_proj.bias.float(),
|
159 |
+
delta_softplus=True,
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
x, z = xz.chunk(2, dim=1)
|
163 |
+
# Compute short convolution
|
164 |
+
if conv_state is not None:
|
165 |
+
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
166 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
167 |
+
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
|
168 |
+
if causal_conv1d_fn is None:
|
169 |
+
x = self.act(self.conv1d(x)[..., :seqlen])
|
170 |
+
else:
|
171 |
+
assert self.activation in ["silu", "swish"]
|
172 |
+
x = causal_conv1d_fn(
|
173 |
+
x=x,
|
174 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
175 |
+
bias=self.conv1d.bias,
|
176 |
+
activation=self.activation,
|
177 |
+
)
|
178 |
+
|
179 |
+
# We're careful here about the layout, to avoid extra transposes.
|
180 |
+
# We want dt to have d as the slowest moving dimension
|
181 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
182 |
+
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
183 |
+
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
184 |
+
dt = self.dt_proj.weight @ dt.t()
|
185 |
+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
186 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
187 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
188 |
+
assert self.activation in ["silu", "swish"]
|
189 |
+
y = selective_scan_fn(
|
190 |
+
x,
|
191 |
+
dt,
|
192 |
+
A,
|
193 |
+
B,
|
194 |
+
C,
|
195 |
+
self.D.float(),
|
196 |
+
z=z,
|
197 |
+
delta_bias=self.dt_proj.bias.float(),
|
198 |
+
delta_softplus=True,
|
199 |
+
return_last_state=ssm_state is not None,
|
200 |
+
)
|
201 |
+
if ssm_state is not None:
|
202 |
+
y, last_state = y
|
203 |
+
ssm_state.copy_(last_state)
|
204 |
+
y = rearrange(y, "b d l -> b l d")
|
205 |
+
out = self.out_proj(y)
|
206 |
+
return out
|
207 |
+
|
208 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
209 |
+
dtype = hidden_states.dtype
|
210 |
+
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
211 |
+
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
212 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
213 |
+
|
214 |
+
# Conv step
|
215 |
+
if causal_conv1d_update is None:
|
216 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
217 |
+
conv_state[:, :, -1] = x
|
218 |
+
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
219 |
+
if self.conv1d.bias is not None:
|
220 |
+
x = x + self.conv1d.bias
|
221 |
+
x = self.act(x).to(dtype=dtype)
|
222 |
+
else:
|
223 |
+
x = causal_conv1d_update(
|
224 |
+
x,
|
225 |
+
conv_state,
|
226 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
227 |
+
self.conv1d.bias,
|
228 |
+
self.activation,
|
229 |
+
)
|
230 |
+
|
231 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
232 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
233 |
+
# Don't add dt_bias here
|
234 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
235 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
236 |
+
|
237 |
+
# SSM step
|
238 |
+
if selective_state_update is None:
|
239 |
+
# Discretize A and B
|
240 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
241 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
242 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
243 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
244 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
245 |
+
y = y + self.D.to(dtype) * x
|
246 |
+
y = y * self.act(z) # (B D)
|
247 |
+
else:
|
248 |
+
y = selective_state_update(
|
249 |
+
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
250 |
+
)
|
251 |
+
|
252 |
+
out = self.out_proj(y)
|
253 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
254 |
+
|
255 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
256 |
+
device = self.out_proj.weight.device
|
257 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
258 |
+
conv_state = torch.zeros(
|
259 |
+
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
|
260 |
+
)
|
261 |
+
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
262 |
+
# ssm_dtype = torch.float32
|
263 |
+
ssm_state = torch.zeros(
|
264 |
+
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
|
265 |
+
)
|
266 |
+
return conv_state, ssm_state
|
267 |
+
|
268 |
+
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
269 |
+
assert self.layer_idx is not None
|
270 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
271 |
+
batch_shape = (batch_size,)
|
272 |
+
conv_state = torch.zeros(
|
273 |
+
batch_size,
|
274 |
+
self.d_model * self.expand,
|
275 |
+
self.d_conv,
|
276 |
+
device=self.conv1d.weight.device,
|
277 |
+
dtype=self.conv1d.weight.dtype,
|
278 |
+
)
|
279 |
+
ssm_state = torch.zeros(
|
280 |
+
batch_size,
|
281 |
+
self.d_model * self.expand,
|
282 |
+
self.d_state,
|
283 |
+
device=self.dt_proj.weight.device,
|
284 |
+
dtype=self.dt_proj.weight.dtype,
|
285 |
+
# dtype=torch.float32,
|
286 |
+
)
|
287 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
288 |
+
else:
|
289 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
290 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
291 |
+
if initialize_states:
|
292 |
+
conv_state.zero_()
|
293 |
+
ssm_state.zero_()
|
294 |
+
return conv_state, ssm_state
|
mamba/build/lib/mamba_ssm/modules/mha.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
try:
|
11 |
+
from flash_attn import flash_attn_with_kvcache
|
12 |
+
except ImportError:
|
13 |
+
flash_attn_with_kvcache = None
|
14 |
+
|
15 |
+
try:
|
16 |
+
from flash_attn.layers.rotary import RotaryEmbedding
|
17 |
+
except ImportError:
|
18 |
+
RotaryEmbedding = None
|
19 |
+
|
20 |
+
try:
|
21 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
22 |
+
except ImportError:
|
23 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
24 |
+
|
25 |
+
|
26 |
+
def _update_kv_cache(kv, inference_params, layer_idx):
|
27 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
28 |
+
# Pre-allocate memory for key-values for inference.
|
29 |
+
num_heads, head_dim = kv.shape[-2:]
|
30 |
+
assert layer_idx in inference_params.key_value_memory_dict
|
31 |
+
kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
|
32 |
+
# Adjust key and value for inference
|
33 |
+
batch_start = inference_params.batch_size_offset
|
34 |
+
batch_end = batch_start + kv.shape[0]
|
35 |
+
sequence_start = inference_params.seqlen_offset
|
36 |
+
sequence_end = sequence_start + kv.shape[1]
|
37 |
+
assert batch_end <= kv_cache.shape[0]
|
38 |
+
assert sequence_end <= kv_cache.shape[1]
|
39 |
+
assert kv_cache is not None
|
40 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
41 |
+
return kv_cache[batch_start:batch_end, :sequence_end, ...]
|
42 |
+
|
43 |
+
|
44 |
+
class MHA(nn.Module):
|
45 |
+
"""Multi-head self-attention and cross-attention"""
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
embed_dim,
|
50 |
+
num_heads,
|
51 |
+
num_heads_kv=None,
|
52 |
+
head_dim=None, # If None, use embed_dim // num_heads
|
53 |
+
mlp_dim=0,
|
54 |
+
qkv_proj_bias=True,
|
55 |
+
out_proj_bias=True,
|
56 |
+
softmax_scale=None,
|
57 |
+
causal=False,
|
58 |
+
layer_idx=None,
|
59 |
+
d_conv=0,
|
60 |
+
rotary_emb_dim=0,
|
61 |
+
rotary_emb_base=10000.0,
|
62 |
+
rotary_emb_interleaved=False,
|
63 |
+
device=None,
|
64 |
+
dtype=None,
|
65 |
+
) -> None:
|
66 |
+
"""
|
67 |
+
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
|
68 |
+
return_residual: whether to return the input x along with the output. This is for
|
69 |
+
performance reason: for post-norm architecture, returning the input allows us
|
70 |
+
to fuse the backward of nn.Linear with the residual connection.
|
71 |
+
"""
|
72 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
73 |
+
super().__init__()
|
74 |
+
self.embed_dim = embed_dim
|
75 |
+
self.layer_idx = layer_idx
|
76 |
+
self.d_conv = d_conv
|
77 |
+
self.rotary_emb_dim = rotary_emb_dim
|
78 |
+
self.softmax_scale = softmax_scale
|
79 |
+
self.causal = causal
|
80 |
+
|
81 |
+
self.num_heads = num_heads
|
82 |
+
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
83 |
+
assert (
|
84 |
+
self.num_heads % self.num_heads_kv == 0
|
85 |
+
), "num_heads must be divisible by num_heads_kv"
|
86 |
+
if head_dim is None:
|
87 |
+
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
88 |
+
self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
|
89 |
+
self.mlp_dim = math.ceil(mlp_dim / 256) * 256
|
90 |
+
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
91 |
+
out_dim = self.head_dim * self.num_heads
|
92 |
+
|
93 |
+
if self.rotary_emb_dim > 0:
|
94 |
+
assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
|
95 |
+
self.rotary_emb = RotaryEmbedding(
|
96 |
+
self.rotary_emb_dim,
|
97 |
+
base=rotary_emb_base,
|
98 |
+
interleaved=rotary_emb_interleaved,
|
99 |
+
device=device,
|
100 |
+
)
|
101 |
+
|
102 |
+
self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
|
103 |
+
if self.d_conv > 0:
|
104 |
+
self.conv1d = nn.Conv1d(
|
105 |
+
qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
|
106 |
+
**factory_kwargs
|
107 |
+
)
|
108 |
+
self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
|
109 |
+
|
110 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
|
111 |
+
dtype = self.out_proj.weight.dtype if dtype is None else dtype
|
112 |
+
device = self.out_proj.weight.device
|
113 |
+
if self.d_conv > 0:
|
114 |
+
conv_state = torch.zeros(
|
115 |
+
batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
conv_state = None
|
119 |
+
kv_cache = torch.empty(
|
120 |
+
batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
|
121 |
+
)
|
122 |
+
return kv_cache, conv_state
|
123 |
+
|
124 |
+
def _update_kv_cache(self, kv, inference_params):
|
125 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
126 |
+
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
127 |
+
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
128 |
+
|
129 |
+
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
|
130 |
+
"""
|
131 |
+
Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
|
132 |
+
q: (batch_size, seqlen_q, nheads, head_dim)
|
133 |
+
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
|
134 |
+
"""
|
135 |
+
assert inference_params is not None and inference_params.seqlen_offset > 0
|
136 |
+
if self.rotary_emb_dim > 0:
|
137 |
+
self.rotary_emb._update_cos_sin_cache(
|
138 |
+
inference_params.max_seqlen, device=q.device, dtype=q.dtype
|
139 |
+
)
|
140 |
+
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
|
141 |
+
else:
|
142 |
+
rotary_cos, rotary_sin = None, None
|
143 |
+
batch = q.shape[0]
|
144 |
+
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
|
145 |
+
kv_cache = kv_cache[:batch]
|
146 |
+
cache_seqlens = (
|
147 |
+
inference_params.lengths_per_sample[:batch]
|
148 |
+
if inference_params.lengths_per_sample is not None
|
149 |
+
else inference_params.seqlen_offset
|
150 |
+
)
|
151 |
+
assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
|
152 |
+
context = flash_attn_with_kvcache(
|
153 |
+
q,
|
154 |
+
kv_cache[:, :, 0],
|
155 |
+
kv_cache[:, :, 1],
|
156 |
+
kv[:, :, 0],
|
157 |
+
kv[:, :, 1],
|
158 |
+
rotary_cos=rotary_cos,
|
159 |
+
rotary_sin=rotary_sin,
|
160 |
+
cache_seqlens=cache_seqlens,
|
161 |
+
softmax_scale=self.softmax_scale,
|
162 |
+
causal=self.causal,
|
163 |
+
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
|
164 |
+
)
|
165 |
+
return context
|
166 |
+
|
167 |
+
def _update_kvcache_attention(self, q, kv, inference_params):
|
168 |
+
"""Write kv to inference_params, then do attention"""
|
169 |
+
if (
|
170 |
+
inference_params.seqlen_offset == 0
|
171 |
+
or flash_attn_with_kvcache is None
|
172 |
+
):
|
173 |
+
# TODO: this only uses seqlen_offset and not lengths_per_sample.
|
174 |
+
kv = self._update_kv_cache(kv, inference_params)
|
175 |
+
k, v = kv.unbind(dim=-3)
|
176 |
+
k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
|
177 |
+
v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
|
178 |
+
return F.scaled_dot_product_attention(
|
179 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
|
180 |
+
).transpose(1, 2)
|
181 |
+
else:
|
182 |
+
batch = q.shape[0]
|
183 |
+
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
|
184 |
+
kv_cache = kv_cache[:batch]
|
185 |
+
cache_seqlens = (
|
186 |
+
inference_params.lengths_per_sample[:batch]
|
187 |
+
if inference_params.lengths_per_sample is not None
|
188 |
+
else inference_params.seqlen_offset
|
189 |
+
)
|
190 |
+
return flash_attn_with_kvcache(
|
191 |
+
q,
|
192 |
+
kv_cache[:, :, 0],
|
193 |
+
kv_cache[:, :, 1],
|
194 |
+
kv[:, :, 0],
|
195 |
+
kv[:, :, 1],
|
196 |
+
cache_seqlens=cache_seqlens,
|
197 |
+
softmax_scale=self.softmax_scale,
|
198 |
+
causal=self.causal,
|
199 |
+
)
|
200 |
+
|
201 |
+
def forward(self, x, inference_params=None):
|
202 |
+
"""
|
203 |
+
Arguments:
|
204 |
+
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
|
205 |
+
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
|
206 |
+
is the is the sum of the sequence lengths in the batch.
|
207 |
+
inference_params: for generation. Adapted from Megatron-LM (and Apex)
|
208 |
+
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
|
209 |
+
"""
|
210 |
+
if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
|
211 |
+
inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
|
212 |
+
x.shape[0], inference_params.max_seqlen, dtype=x.dtype
|
213 |
+
)
|
214 |
+
seqlen_offset = (
|
215 |
+
0
|
216 |
+
if inference_params is None
|
217 |
+
else (
|
218 |
+
inference_params.lengths_per_sample
|
219 |
+
if inference_params.lengths_per_sample is not None
|
220 |
+
else inference_params.seqlen_offset
|
221 |
+
)
|
222 |
+
)
|
223 |
+
rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
|
224 |
+
qkv = self.in_proj(x)
|
225 |
+
if self.mlp_dim > 0:
|
226 |
+
qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
|
227 |
+
x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
|
228 |
+
x_mlp = x_mlp_up * F.silu(x_mlp_gate)
|
229 |
+
if self.d_conv > 0:
|
230 |
+
# The inference code for conv1d is pretty messy, should clean it up
|
231 |
+
if (inference_params is None or inference_params.seqlen_offset == 0):
|
232 |
+
if causal_conv1d_fn is None:
|
233 |
+
qkv = rearrange(
|
234 |
+
self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
|
235 |
+
).contiguous()
|
236 |
+
else:
|
237 |
+
qkv = causal_conv1d_fn(
|
238 |
+
qkv.transpose(1, 2),
|
239 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
240 |
+
self.conv1d.bias
|
241 |
+
).transpose(1, 2)
|
242 |
+
if inference_params is not None:
|
243 |
+
_, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
|
244 |
+
# If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
245 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
246 |
+
qkv_t = rearrange(qkv, "b l d -> b d l")
|
247 |
+
conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
|
248 |
+
else:
|
249 |
+
_, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
|
250 |
+
assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
251 |
+
qkv = qkv.squeeze(1)
|
252 |
+
# Conv step
|
253 |
+
if causal_conv1d_update is None:
|
254 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
255 |
+
conv_state[:, :, -1] = qkv
|
256 |
+
qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
257 |
+
if self.conv1d.bias is not None:
|
258 |
+
qkv = qkv + self.conv1d.bias
|
259 |
+
else:
|
260 |
+
qkv = causal_conv1d_update(
|
261 |
+
qkv,
|
262 |
+
conv_state,
|
263 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
264 |
+
self.conv1d.bias
|
265 |
+
)
|
266 |
+
qkv = qkv.unsqueeze(1)
|
267 |
+
q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
|
268 |
+
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
269 |
+
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
270 |
+
if (
|
271 |
+
inference_params is None
|
272 |
+
or inference_params.seqlen_offset == 0
|
273 |
+
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
|
274 |
+
):
|
275 |
+
if self.rotary_emb_dim > 0:
|
276 |
+
q, kv = self.rotary_emb(
|
277 |
+
q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
|
278 |
+
)
|
279 |
+
if inference_params is None:
|
280 |
+
k, v = kv.unbind(dim=-3)
|
281 |
+
k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
|
282 |
+
v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
|
283 |
+
context = F.scaled_dot_product_attention(
|
284 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
|
285 |
+
).transpose(1, 2)
|
286 |
+
else:
|
287 |
+
context = self._update_kvcache_attention(q, kv, inference_params)
|
288 |
+
else:
|
289 |
+
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
290 |
+
context = rearrange(context, "... h d -> ... (h d)")
|
291 |
+
if self.mlp_dim > 0:
|
292 |
+
context = torch.cat([context, x_mlp], dim=-1)
|
293 |
+
out = self.out_proj(context)
|
294 |
+
return out
|
mamba/build/lib/mamba_ssm/modules/mlp.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class GatedMLP(nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
in_features,
|
10 |
+
hidden_features=None,
|
11 |
+
out_features=None,
|
12 |
+
activation=F.silu,
|
13 |
+
bias=False,
|
14 |
+
multiple_of=128,
|
15 |
+
device=None,
|
16 |
+
dtype=None,
|
17 |
+
):
|
18 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
19 |
+
super().__init__()
|
20 |
+
out_features = out_features if out_features is not None else in_features
|
21 |
+
hidden_features = (
|
22 |
+
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
23 |
+
)
|
24 |
+
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
25 |
+
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
|
26 |
+
self.activation = activation
|
27 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
y = self.fc1(x)
|
31 |
+
y, gate = y.chunk(2, dim=-1)
|
32 |
+
y = y * self.activation(gate)
|
33 |
+
y = self.fc2(y)
|
34 |
+
return y
|
mamba/build/lib/mamba_ssm/modules/ssd_minimal.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Albert Gu and Tri Dao.
|
2 |
+
"""Minimal implementation of SSD.
|
3 |
+
|
4 |
+
This is the same as Listing 1 from the paper.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
12 |
+
|
13 |
+
|
14 |
+
def segsum_unstable(x):
|
15 |
+
"""Naive segment sum calculation."""
|
16 |
+
T = x.size(-1)
|
17 |
+
x_cumsum = torch.cumsum(x, dim=-1)
|
18 |
+
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
|
19 |
+
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
|
20 |
+
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
21 |
+
return x_segsum
|
22 |
+
|
23 |
+
def segsum(x):
|
24 |
+
"""More stable segment sum calculation."""
|
25 |
+
T = x.size(-1)
|
26 |
+
x = repeat(x, "... d -> ... d e", e=T)
|
27 |
+
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
|
28 |
+
x = x.masked_fill(~mask, 0)
|
29 |
+
x_segsum = torch.cumsum(x, dim=-2)
|
30 |
+
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
|
31 |
+
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
32 |
+
return x_segsum
|
33 |
+
|
34 |
+
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
|
35 |
+
"""
|
36 |
+
Arguments:
|
37 |
+
X: (batch, length, n_heads, d_head)
|
38 |
+
A: (batch, length, n_heads)
|
39 |
+
B: (batch, length, n_heads, d_state)
|
40 |
+
C: (batch, length, n_heads, d_state)
|
41 |
+
Return:
|
42 |
+
Y: (batch, length, n_heads, d_head)
|
43 |
+
"""
|
44 |
+
assert X.dtype == A.dtype == B.dtype == C.dtype
|
45 |
+
assert X.shape[1] % block_len == 0
|
46 |
+
|
47 |
+
# Rearrange into blocks/chunks
|
48 |
+
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
|
49 |
+
|
50 |
+
A = rearrange(A, "b c l h -> b h c l")
|
51 |
+
A_cumsum = torch.cumsum(A, dim=-1)
|
52 |
+
|
53 |
+
# 1. Compute the output for each intra-chunk (diagonal blocks)
|
54 |
+
L = torch.exp(segsum(A))
|
55 |
+
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
|
56 |
+
|
57 |
+
# 2. Compute the state for each intra-chunk
|
58 |
+
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
59 |
+
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
|
60 |
+
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
|
61 |
+
|
62 |
+
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
|
63 |
+
# (middle term of factorization of off-diag blocks; A terms)
|
64 |
+
if initial_states is None:
|
65 |
+
initial_states = torch.zeros_like(states[:, :1])
|
66 |
+
states = torch.cat([initial_states, states], dim=1)
|
67 |
+
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
|
68 |
+
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
|
69 |
+
states, final_state = new_states[:, :-1], new_states[:, -1]
|
70 |
+
|
71 |
+
# 4. Compute state -> output conversion per chunk
|
72 |
+
# (left term of low-rank factorization of off-diagonal blocks; C terms)
|
73 |
+
state_decay_out = torch.exp(A_cumsum)
|
74 |
+
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
|
75 |
+
|
76 |
+
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
|
77 |
+
Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
|
78 |
+
return Y, final_state
|
79 |
+
|
80 |
+
|
81 |
+
# Simple test
|
82 |
+
def test_correctness():
|
83 |
+
torch.manual_seed(42)
|
84 |
+
|
85 |
+
## Dimensions
|
86 |
+
# Denoted (B, T, Q, D, P) in the paper
|
87 |
+
batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
|
88 |
+
nheads = dim // headdim # (H) in the paper
|
89 |
+
ngroups = 1 # (G) in the paper
|
90 |
+
dstate = 64 # (N) in the paper
|
91 |
+
dtype = torch.float32
|
92 |
+
device = "cuda"
|
93 |
+
|
94 |
+
x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
|
95 |
+
dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_()
|
96 |
+
A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_()
|
97 |
+
B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
|
98 |
+
C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
|
99 |
+
D = torch.randn(nheads, dtype=dtype, device=device)
|
100 |
+
|
101 |
+
# Comparing fused version and minimal version
|
102 |
+
y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
|
103 |
+
y_min, _ = ssd_minimal_discrete(x*dt.unsqueeze(-1), A*dt, B, C, chunk_size)
|
mamba/build/lib/mamba_ssm/ops/__init__.py
ADDED
File without changes
|
mamba/build/lib/mamba_ssm/ops/selective_scan_interface.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
6 |
+
|
7 |
+
from einops import rearrange, repeat
|
8 |
+
|
9 |
+
try:
|
10 |
+
from causal_conv1d import causal_conv1d_fn
|
11 |
+
import causal_conv1d_cuda
|
12 |
+
except ImportError:
|
13 |
+
causal_conv1d_fn = None
|
14 |
+
causal_conv1d_cuda = None
|
15 |
+
|
16 |
+
import selective_scan_cuda
|
17 |
+
|
18 |
+
|
19 |
+
class SelectiveScanFn(torch.autograd.Function):
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
23 |
+
return_last_state=False):
|
24 |
+
if u.stride(-1) != 1:
|
25 |
+
u = u.contiguous()
|
26 |
+
if delta.stride(-1) != 1:
|
27 |
+
delta = delta.contiguous()
|
28 |
+
if D is not None:
|
29 |
+
D = D.contiguous()
|
30 |
+
if B.stride(-1) != 1:
|
31 |
+
B = B.contiguous()
|
32 |
+
if C.stride(-1) != 1:
|
33 |
+
C = C.contiguous()
|
34 |
+
if z is not None and z.stride(-1) != 1:
|
35 |
+
z = z.contiguous()
|
36 |
+
if B.dim() == 3:
|
37 |
+
B = rearrange(B, "b dstate l -> b 1 dstate l")
|
38 |
+
ctx.squeeze_B = True
|
39 |
+
if C.dim() == 3:
|
40 |
+
C = rearrange(C, "b dstate l -> b 1 dstate l")
|
41 |
+
ctx.squeeze_C = True
|
42 |
+
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
|
43 |
+
ctx.delta_softplus = delta_softplus
|
44 |
+
ctx.has_z = z is not None
|
45 |
+
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
|
46 |
+
if not ctx.has_z:
|
47 |
+
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
48 |
+
return out if not return_last_state else (out, last_state)
|
49 |
+
else:
|
50 |
+
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
|
51 |
+
out_z = rest[0]
|
52 |
+
return out_z if not return_last_state else (out_z, last_state)
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def backward(ctx, dout, *args):
|
56 |
+
if not ctx.has_z:
|
57 |
+
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
58 |
+
z = None
|
59 |
+
out = None
|
60 |
+
else:
|
61 |
+
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
|
62 |
+
if dout.stride(-1) != 1:
|
63 |
+
dout = dout.contiguous()
|
64 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
65 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
66 |
+
# Here we just pass in None and dz will be allocated in the C++ code.
|
67 |
+
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
68 |
+
u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
|
69 |
+
False # option to recompute out_z, not used here
|
70 |
+
)
|
71 |
+
dz = rest[0] if ctx.has_z else None
|
72 |
+
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
73 |
+
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
74 |
+
return (du, ddelta, dA, dB, dC,
|
75 |
+
dD if D is not None else None,
|
76 |
+
dz,
|
77 |
+
ddelta_bias if delta_bias is not None else None,
|
78 |
+
None,
|
79 |
+
None)
|
80 |
+
|
81 |
+
|
82 |
+
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
83 |
+
return_last_state=False):
|
84 |
+
"""if return_last_state is True, returns (out, last_state)
|
85 |
+
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
|
86 |
+
not considered in the backward pass.
|
87 |
+
"""
|
88 |
+
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
|
89 |
+
|
90 |
+
|
91 |
+
def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
92 |
+
return_last_state=False):
|
93 |
+
"""
|
94 |
+
u: r(B D L)
|
95 |
+
delta: r(B D L)
|
96 |
+
A: c(D N) or r(D N)
|
97 |
+
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
98 |
+
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
99 |
+
D: r(D)
|
100 |
+
z: r(B D L)
|
101 |
+
delta_bias: r(D), fp32
|
102 |
+
|
103 |
+
out: r(B D L)
|
104 |
+
last_state (optional): r(B D dstate) or c(B D dstate)
|
105 |
+
"""
|
106 |
+
dtype_in = u.dtype
|
107 |
+
u = u.float()
|
108 |
+
delta = delta.float()
|
109 |
+
if delta_bias is not None:
|
110 |
+
delta = delta + delta_bias[..., None].float()
|
111 |
+
if delta_softplus:
|
112 |
+
delta = F.softplus(delta)
|
113 |
+
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
114 |
+
is_variable_B = B.dim() >= 3
|
115 |
+
is_variable_C = C.dim() >= 3
|
116 |
+
if A.is_complex():
|
117 |
+
if is_variable_B:
|
118 |
+
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
|
119 |
+
if is_variable_C:
|
120 |
+
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
|
121 |
+
else:
|
122 |
+
B = B.float()
|
123 |
+
C = C.float()
|
124 |
+
x = A.new_zeros((batch, dim, dstate))
|
125 |
+
ys = []
|
126 |
+
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
127 |
+
if not is_variable_B:
|
128 |
+
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
|
129 |
+
else:
|
130 |
+
if B.dim() == 3:
|
131 |
+
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
|
132 |
+
else:
|
133 |
+
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
134 |
+
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
|
135 |
+
if is_variable_C and C.dim() == 4:
|
136 |
+
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
137 |
+
last_state = None
|
138 |
+
for i in range(u.shape[2]):
|
139 |
+
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
140 |
+
if not is_variable_C:
|
141 |
+
y = torch.einsum('bdn,dn->bd', x, C)
|
142 |
+
else:
|
143 |
+
if C.dim() == 3:
|
144 |
+
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
|
145 |
+
else:
|
146 |
+
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
|
147 |
+
if i == u.shape[2] - 1:
|
148 |
+
last_state = x
|
149 |
+
if y.is_complex():
|
150 |
+
y = y.real * 2
|
151 |
+
ys.append(y)
|
152 |
+
y = torch.stack(ys, dim=2) # (batch dim L)
|
153 |
+
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
154 |
+
if z is not None:
|
155 |
+
out = out * F.silu(z)
|
156 |
+
out = out.to(dtype=dtype_in)
|
157 |
+
return out if not return_last_state else (out, last_state)
|
158 |
+
|
159 |
+
|
160 |
+
class MambaInnerFn(torch.autograd.Function):
|
161 |
+
|
162 |
+
@staticmethod
|
163 |
+
@custom_fwd
|
164 |
+
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
165 |
+
out_proj_weight, out_proj_bias,
|
166 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
167 |
+
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
|
168 |
+
"""
|
169 |
+
xz: (batch, dim, seqlen)
|
170 |
+
"""
|
171 |
+
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
172 |
+
assert checkpoint_lvl in [0, 1]
|
173 |
+
L = xz.shape[-1]
|
174 |
+
delta_rank = delta_proj_weight.shape[1]
|
175 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
176 |
+
if torch.is_autocast_enabled():
|
177 |
+
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
178 |
+
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
179 |
+
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
180 |
+
out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
|
181 |
+
if out_proj_bias is not None else None)
|
182 |
+
if xz.stride(-1) != 1:
|
183 |
+
xz = xz.contiguous()
|
184 |
+
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
185 |
+
x, z = xz.chunk(2, dim=1)
|
186 |
+
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
187 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
|
188 |
+
x, conv1d_weight, conv1d_bias, None, None, None, True
|
189 |
+
)
|
190 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
191 |
+
# We want delta to have d as the slowest moving dimension
|
192 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
193 |
+
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
194 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
|
195 |
+
ctx.is_variable_B = B is None
|
196 |
+
ctx.is_variable_C = C is None
|
197 |
+
ctx.B_proj_bias_is_None = B_proj_bias is None
|
198 |
+
ctx.C_proj_bias_is_None = C_proj_bias is None
|
199 |
+
if B is None: # variable B
|
200 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
|
201 |
+
if B_proj_bias is not None:
|
202 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
203 |
+
if not A.is_complex():
|
204 |
+
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
205 |
+
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
206 |
+
else:
|
207 |
+
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
208 |
+
else:
|
209 |
+
if B.stride(-1) != 1:
|
210 |
+
B = B.contiguous()
|
211 |
+
if C is None: # variable C
|
212 |
+
C = x_dbl[:, -d_state:] # (bl dstate)
|
213 |
+
if C_proj_bias is not None:
|
214 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
215 |
+
if not A.is_complex():
|
216 |
+
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
217 |
+
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
218 |
+
else:
|
219 |
+
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
220 |
+
else:
|
221 |
+
if C.stride(-1) != 1:
|
222 |
+
C = C.contiguous()
|
223 |
+
if D is not None:
|
224 |
+
D = D.contiguous()
|
225 |
+
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
|
226 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
227 |
+
)
|
228 |
+
ctx.delta_softplus = delta_softplus
|
229 |
+
ctx.out_proj_bias_is_None = out_proj_bias is None
|
230 |
+
ctx.checkpoint_lvl = checkpoint_lvl
|
231 |
+
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
232 |
+
conv1d_out, delta = None, None
|
233 |
+
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
|
234 |
+
delta_proj_weight, out_proj_weight, conv1d_out, delta,
|
235 |
+
A, B, C, D, delta_bias, scan_intermediates, out)
|
236 |
+
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
237 |
+
|
238 |
+
@staticmethod
|
239 |
+
@custom_bwd
|
240 |
+
def backward(ctx, dout):
|
241 |
+
# dout: (batch, seqlen, dim)
|
242 |
+
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
243 |
+
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
|
244 |
+
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
|
245 |
+
L = xz.shape[-1]
|
246 |
+
delta_rank = delta_proj_weight.shape[1]
|
247 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
248 |
+
x, z = xz.chunk(2, dim=1)
|
249 |
+
if dout.stride(-1) != 1:
|
250 |
+
dout = dout.contiguous()
|
251 |
+
if ctx.checkpoint_lvl == 1:
|
252 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
|
253 |
+
x, conv1d_weight, conv1d_bias, None, None, None, True
|
254 |
+
)
|
255 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
|
256 |
+
"d (b l) -> b d l", l = L)
|
257 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
258 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
259 |
+
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
260 |
+
dx, dz = dxz.chunk(2, dim=1)
|
261 |
+
dout = rearrange(dout, "b l e -> e (b l)")
|
262 |
+
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
263 |
+
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
|
264 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
|
265 |
+
ctx.delta_softplus,
|
266 |
+
True # option to recompute out_z
|
267 |
+
)
|
268 |
+
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
|
269 |
+
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
270 |
+
dD = dD if D is not None else None
|
271 |
+
dx_dbl = torch.empty_like(x_dbl)
|
272 |
+
dB_proj_bias = None
|
273 |
+
if ctx.is_variable_B:
|
274 |
+
if not A.is_complex():
|
275 |
+
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
276 |
+
else:
|
277 |
+
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
278 |
+
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
279 |
+
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
|
280 |
+
dB = None
|
281 |
+
dC_proj_bias = None
|
282 |
+
if ctx.is_variable_C:
|
283 |
+
if not A.is_complex():
|
284 |
+
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
285 |
+
else:
|
286 |
+
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
287 |
+
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
288 |
+
dx_dbl[:, -d_state:] = dC # (bl d)
|
289 |
+
dC = None
|
290 |
+
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
291 |
+
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
292 |
+
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
293 |
+
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
294 |
+
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
295 |
+
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
296 |
+
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
297 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
298 |
+
# backward of conv1d with the backward of chunk).
|
299 |
+
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
|
300 |
+
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
|
301 |
+
)
|
302 |
+
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
303 |
+
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
304 |
+
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
305 |
+
dout_proj_weight, dout_proj_bias,
|
306 |
+
dA, dB, dC, dD,
|
307 |
+
ddelta_bias if delta_bias is not None else None,
|
308 |
+
dB_proj_bias, dC_proj_bias, None)
|
309 |
+
|
310 |
+
|
311 |
+
def mamba_inner_fn(
|
312 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
313 |
+
out_proj_weight, out_proj_bias,
|
314 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
315 |
+
C_proj_bias=None, delta_softplus=True
|
316 |
+
):
|
317 |
+
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
318 |
+
out_proj_weight, out_proj_bias,
|
319 |
+
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
|
320 |
+
|
321 |
+
|
322 |
+
def mamba_inner_ref(
|
323 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
324 |
+
out_proj_weight, out_proj_bias,
|
325 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
326 |
+
C_proj_bias=None, delta_softplus=True
|
327 |
+
):
|
328 |
+
assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
|
329 |
+
L = xz.shape[-1]
|
330 |
+
delta_rank = delta_proj_weight.shape[1]
|
331 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
332 |
+
x, z = xz.chunk(2, dim=1)
|
333 |
+
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
|
334 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
335 |
+
# We want delta to have d as the slowest moving dimension
|
336 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
337 |
+
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
338 |
+
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
|
339 |
+
delta = rearrange(delta, "d (b l) -> b d l", l=L)
|
340 |
+
if B is None: # variable B
|
341 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
|
342 |
+
if B_proj_bias is not None:
|
343 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
344 |
+
if not A.is_complex():
|
345 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
346 |
+
else:
|
347 |
+
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
348 |
+
if C is None: # variable B
|
349 |
+
C = x_dbl[:, -d_state:] # (bl d)
|
350 |
+
if C_proj_bias is not None:
|
351 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
352 |
+
if not A.is_complex():
|
353 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
354 |
+
else:
|
355 |
+
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
356 |
+
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
|
357 |
+
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
mamba/build/lib/mamba_ssm/ops/triton/__init__.py
ADDED
File without changes
|
mamba/build/lib/mamba_ssm/ops/triton/k_activations.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import triton
|
6 |
+
import triton.language as tl
|
7 |
+
|
8 |
+
|
9 |
+
@triton.autotune(
|
10 |
+
configs=[
|
11 |
+
triton.Config({'BLOCK_N': 32}),
|
12 |
+
triton.Config({'BLOCK_N': 64}),
|
13 |
+
triton.Config({'BLOCK_N': 128}),
|
14 |
+
triton.Config({'BLOCK_N': 256}),
|
15 |
+
triton.Config({'BLOCK_N': 512}),
|
16 |
+
triton.Config({'BLOCK_N': 1024}),
|
17 |
+
],
|
18 |
+
key=['ncols'],
|
19 |
+
)
|
20 |
+
@triton.jit
|
21 |
+
def _swiglu_fwd_kernel(
|
22 |
+
X,
|
23 |
+
Y,
|
24 |
+
OUT,
|
25 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
26 |
+
stride_y_row,
|
27 |
+
stride_out_row,
|
28 |
+
ncols,
|
29 |
+
BLOCK_N: tl.constexpr,
|
30 |
+
):
|
31 |
+
# Map the program id to the row of X and Y it should compute.
|
32 |
+
row = tl.program_id(0)
|
33 |
+
start_col = tl.program_id(1) * BLOCK_N
|
34 |
+
X += row * stride_x_row
|
35 |
+
Y += row * stride_y_row
|
36 |
+
OUT += row * stride_out_row
|
37 |
+
cols = start_col + tl.arange(0, BLOCK_N)
|
38 |
+
x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
39 |
+
y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
40 |
+
out = x * tl.sigmoid(x) * y
|
41 |
+
tl.store(OUT + cols, out, mask=cols < ncols)
|
42 |
+
|
43 |
+
|
44 |
+
def _swiglu_fwd(xy, out=None):
|
45 |
+
if xy.stride(-1) != 1:
|
46 |
+
xy = xy.contiguous()
|
47 |
+
batch_shape = xy.shape[:-1]
|
48 |
+
xy = xy.reshape(-1, xy.shape[-1])
|
49 |
+
x, y = xy.chunk(2, dim=-1)
|
50 |
+
if out is None:
|
51 |
+
out = torch.empty_like(x)
|
52 |
+
else:
|
53 |
+
out = out.reshape(-1, out.shape[-1])
|
54 |
+
assert out.shape == x.shape
|
55 |
+
assert out.stride(-1) == 1
|
56 |
+
M, N = x.shape
|
57 |
+
grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
|
58 |
+
with torch.cuda.device(x.device.index):
|
59 |
+
_swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
|
60 |
+
return out.reshape(*batch_shape, out.shape[-1])
|
61 |
+
|
62 |
+
|
63 |
+
@triton.autotune(
|
64 |
+
configs=[
|
65 |
+
triton.Config({'BLOCK_N': 32}),
|
66 |
+
triton.Config({'BLOCK_N': 64}),
|
67 |
+
triton.Config({'BLOCK_N': 128}),
|
68 |
+
triton.Config({'BLOCK_N': 256}),
|
69 |
+
triton.Config({'BLOCK_N': 512}),
|
70 |
+
triton.Config({'BLOCK_N': 1024}),
|
71 |
+
],
|
72 |
+
key=['ncols'],
|
73 |
+
)
|
74 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
|
75 |
+
@triton.jit
|
76 |
+
def _swiglu_bwd_kernel(
|
77 |
+
X,
|
78 |
+
Y,
|
79 |
+
DOUT,
|
80 |
+
OUT,
|
81 |
+
DX,
|
82 |
+
DY,
|
83 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
84 |
+
stride_y_row,
|
85 |
+
stride_dout_row,
|
86 |
+
stride_out_row,
|
87 |
+
stride_dx_row,
|
88 |
+
stride_dy_row,
|
89 |
+
ncols,
|
90 |
+
BLOCK_N: tl.constexpr,
|
91 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
92 |
+
):
|
93 |
+
# Map the program id to the row of X and Y it should compute.
|
94 |
+
row = tl.program_id(0)
|
95 |
+
start_col = tl.program_id(1) * BLOCK_N
|
96 |
+
X += row * stride_x_row
|
97 |
+
Y += row * stride_y_row
|
98 |
+
DOUT += row * stride_dout_row
|
99 |
+
if RECOMPUTE_OUTPUT:
|
100 |
+
OUT += row * stride_out_row
|
101 |
+
DX += row * stride_dx_row
|
102 |
+
DY += row * stride_dy_row
|
103 |
+
cols = start_col + tl.arange(0, BLOCK_N)
|
104 |
+
x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
105 |
+
y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
106 |
+
dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
107 |
+
x_sigmoid = tl.sigmoid(x)
|
108 |
+
dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
|
109 |
+
dy = x * x_sigmoid * dout
|
110 |
+
tl.store(DX + cols, dx, mask=cols < ncols)
|
111 |
+
tl.store(DY + cols, dy, mask=cols < ncols)
|
112 |
+
if RECOMPUTE_OUTPUT:
|
113 |
+
out = x * x_sigmoid * y
|
114 |
+
tl.store(OUT + cols, out, mask=cols < ncols)
|
115 |
+
|
116 |
+
|
117 |
+
def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
|
118 |
+
if xy.stride(-1) != 1:
|
119 |
+
xy = xy.contiguous()
|
120 |
+
if dout.stride(-1) != 1:
|
121 |
+
dout = dout.contiguous()
|
122 |
+
batch_shape = xy.shape[:-1]
|
123 |
+
xy = xy.reshape(-1, xy.shape[-1])
|
124 |
+
x, y = xy.chunk(2, dim=-1)
|
125 |
+
dout = dout.reshape(-1, dout.shape[-1])
|
126 |
+
assert dout.shape == x.shape
|
127 |
+
if dxy is None:
|
128 |
+
dxy = torch.empty_like(xy)
|
129 |
+
else:
|
130 |
+
dxy = dxy.reshape(-1, dxy.shape[-1])
|
131 |
+
assert dxy.shape == xy.shape
|
132 |
+
dx, dy = dxy.chunk(2, dim=-1)
|
133 |
+
assert dx.stride(-1) == 1
|
134 |
+
assert dy.stride(-1) == 1
|
135 |
+
if recompute_output:
|
136 |
+
if out is None:
|
137 |
+
out = torch.empty_like(x)
|
138 |
+
else:
|
139 |
+
out = out.reshape(-1, out.shape[-1])
|
140 |
+
assert out.shape == x.shape
|
141 |
+
assert out.stride(-1) == 1
|
142 |
+
M, N = x.shape
|
143 |
+
grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
|
144 |
+
with torch.cuda.device(x.device.index):
|
145 |
+
_swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
|
146 |
+
x.stride(0), y.stride(0), dout.stride(0),
|
147 |
+
out.stride(0) if recompute_output else 0,
|
148 |
+
dx.stride(0), dy.stride(0),
|
149 |
+
N)
|
150 |
+
if not recompute_output:
|
151 |
+
return dxy.reshape(*batch_shape, dxy.shape[-1])
|
152 |
+
else:
|
153 |
+
return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
|
154 |
+
|
155 |
+
|
156 |
+
class SwiGLU(torch.autograd.Function):
|
157 |
+
|
158 |
+
@staticmethod
|
159 |
+
def forward(ctx, xy):
|
160 |
+
ctx.save_for_backward(xy)
|
161 |
+
return _swiglu_fwd(xy)
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def backward(ctx, dout):
|
165 |
+
xy, = ctx.saved_tensors
|
166 |
+
return _swiglu_bwd(xy, dout)
|
167 |
+
|
168 |
+
|
169 |
+
swiglu = SwiGLU.apply
|
mamba/build/lib/mamba_ssm/ops/triton/layer_norm.py
ADDED
@@ -0,0 +1,1113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
# Implement dropout + residual + layer_norm / rms_norm.
|
3 |
+
|
4 |
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
5 |
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
6 |
+
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
7 |
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
8 |
+
|
9 |
+
import math
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch.cuda.amp import custom_fwd, custom_bwd
|
15 |
+
|
16 |
+
import triton
|
17 |
+
import triton.language as tl
|
18 |
+
|
19 |
+
|
20 |
+
def layer_norm_ref(
|
21 |
+
x,
|
22 |
+
weight,
|
23 |
+
bias,
|
24 |
+
residual=None,
|
25 |
+
x1=None,
|
26 |
+
weight1=None,
|
27 |
+
bias1=None,
|
28 |
+
eps=1e-6,
|
29 |
+
dropout_p=0.0,
|
30 |
+
rowscale=None,
|
31 |
+
prenorm=False,
|
32 |
+
dropout_mask=None,
|
33 |
+
dropout_mask1=None,
|
34 |
+
upcast=False,
|
35 |
+
):
|
36 |
+
dtype = x.dtype
|
37 |
+
if upcast:
|
38 |
+
x = x.float()
|
39 |
+
weight = weight.float()
|
40 |
+
bias = bias.float() if bias is not None else None
|
41 |
+
residual = residual.float() if residual is not None else residual
|
42 |
+
x1 = x1.float() if x1 is not None else None
|
43 |
+
weight1 = weight1.float() if weight1 is not None else None
|
44 |
+
bias1 = bias1.float() if bias1 is not None else None
|
45 |
+
if x1 is not None:
|
46 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
47 |
+
if rowscale is not None:
|
48 |
+
x = x * rowscale[..., None]
|
49 |
+
if dropout_p > 0.0:
|
50 |
+
if dropout_mask is not None:
|
51 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
52 |
+
else:
|
53 |
+
x = F.dropout(x, p=dropout_p)
|
54 |
+
if x1 is not None:
|
55 |
+
if dropout_mask1 is not None:
|
56 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
57 |
+
else:
|
58 |
+
x1 = F.dropout(x1, p=dropout_p)
|
59 |
+
if x1 is not None:
|
60 |
+
x = x + x1
|
61 |
+
if residual is not None:
|
62 |
+
x = (x + residual).to(x.dtype)
|
63 |
+
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
|
64 |
+
dtype
|
65 |
+
)
|
66 |
+
if weight1 is None:
|
67 |
+
return out if not prenorm else (out, x)
|
68 |
+
else:
|
69 |
+
out1 = F.layer_norm(
|
70 |
+
x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
|
71 |
+
).to(dtype)
|
72 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
73 |
+
|
74 |
+
|
75 |
+
def rms_norm_ref(
|
76 |
+
x,
|
77 |
+
weight,
|
78 |
+
bias,
|
79 |
+
residual=None,
|
80 |
+
x1=None,
|
81 |
+
weight1=None,
|
82 |
+
bias1=None,
|
83 |
+
eps=1e-6,
|
84 |
+
dropout_p=0.0,
|
85 |
+
rowscale=None,
|
86 |
+
prenorm=False,
|
87 |
+
dropout_mask=None,
|
88 |
+
dropout_mask1=None,
|
89 |
+
upcast=False,
|
90 |
+
):
|
91 |
+
dtype = x.dtype
|
92 |
+
if upcast:
|
93 |
+
x = x.float()
|
94 |
+
weight = weight.float()
|
95 |
+
bias = bias.float() if bias is not None else None
|
96 |
+
residual = residual.float() if residual is not None else residual
|
97 |
+
x1 = x1.float() if x1 is not None else None
|
98 |
+
weight1 = weight1.float() if weight1 is not None else None
|
99 |
+
bias1 = bias1.float() if bias1 is not None else None
|
100 |
+
if x1 is not None:
|
101 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
102 |
+
if rowscale is not None:
|
103 |
+
x = x * rowscale[..., None]
|
104 |
+
if dropout_p > 0.0:
|
105 |
+
if dropout_mask is not None:
|
106 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
107 |
+
else:
|
108 |
+
x = F.dropout(x, p=dropout_p)
|
109 |
+
if x1 is not None:
|
110 |
+
if dropout_mask1 is not None:
|
111 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
112 |
+
else:
|
113 |
+
x1 = F.dropout(x1, p=dropout_p)
|
114 |
+
if x1 is not None:
|
115 |
+
x = x + x1
|
116 |
+
if residual is not None:
|
117 |
+
x = (x + residual).to(x.dtype)
|
118 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
119 |
+
out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
|
120 |
+
if weight1 is None:
|
121 |
+
return out if not prenorm else (out, x)
|
122 |
+
else:
|
123 |
+
out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
|
124 |
+
dtype
|
125 |
+
)
|
126 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
127 |
+
|
128 |
+
def config_prune(configs):
|
129 |
+
|
130 |
+
if torch.version.hip:
|
131 |
+
try:
|
132 |
+
# set warp size based on gcn architecure
|
133 |
+
gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
|
134 |
+
if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
|
135 |
+
# radeon
|
136 |
+
warp_size = 32
|
137 |
+
else:
|
138 |
+
# instinct
|
139 |
+
warp_size = 64
|
140 |
+
except AttributeError as e:
|
141 |
+
# fall back to crude method to set warp size
|
142 |
+
device_name = torch.cuda.get_device_properties(0).name
|
143 |
+
if 'instinct' in device_name.lower():
|
144 |
+
warp_size = 64
|
145 |
+
else:
|
146 |
+
warp_size = 32
|
147 |
+
warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning)
|
148 |
+
|
149 |
+
else:
|
150 |
+
# cuda
|
151 |
+
warp_size = 32
|
152 |
+
|
153 |
+
max_block_sz = 1024
|
154 |
+
max_num_warps = max_block_sz // warp_size
|
155 |
+
pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
|
156 |
+
return pruned_configs
|
157 |
+
|
158 |
+
configs_autotune = [
|
159 |
+
triton.Config({}, num_warps=1),
|
160 |
+
triton.Config({}, num_warps=2),
|
161 |
+
triton.Config({}, num_warps=4),
|
162 |
+
triton.Config({}, num_warps=8),
|
163 |
+
triton.Config({}, num_warps=16),
|
164 |
+
triton.Config({}, num_warps=32),
|
165 |
+
]
|
166 |
+
|
167 |
+
pruned_configs_autotune = config_prune(configs_autotune)
|
168 |
+
|
169 |
+
@triton.autotune(
|
170 |
+
configs = pruned_configs_autotune,
|
171 |
+
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
172 |
+
)
|
173 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
174 |
+
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
175 |
+
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
|
176 |
+
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
|
177 |
+
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
|
178 |
+
@triton.jit
|
179 |
+
def _layer_norm_fwd_1pass_kernel(
|
180 |
+
X, # pointer to the input
|
181 |
+
Y, # pointer to the output
|
182 |
+
W, # pointer to the weights
|
183 |
+
B, # pointer to the biases
|
184 |
+
RESIDUAL, # pointer to the residual
|
185 |
+
X1,
|
186 |
+
W1,
|
187 |
+
B1,
|
188 |
+
Y1,
|
189 |
+
RESIDUAL_OUT, # pointer to the residual
|
190 |
+
ROWSCALE,
|
191 |
+
SEEDS, # Dropout seeds for each row
|
192 |
+
DROPOUT_MASK,
|
193 |
+
Mean, # pointer to the mean
|
194 |
+
Rstd, # pointer to the 1/std
|
195 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
196 |
+
stride_y_row,
|
197 |
+
stride_res_row,
|
198 |
+
stride_res_out_row,
|
199 |
+
stride_x1_row,
|
200 |
+
stride_y1_row,
|
201 |
+
M, # number of rows in X
|
202 |
+
N, # number of columns in X
|
203 |
+
eps, # epsilon to avoid division by zero
|
204 |
+
dropout_p, # Dropout probability
|
205 |
+
IS_RMS_NORM: tl.constexpr,
|
206 |
+
BLOCK_N: tl.constexpr,
|
207 |
+
HAS_RESIDUAL: tl.constexpr,
|
208 |
+
STORE_RESIDUAL_OUT: tl.constexpr,
|
209 |
+
HAS_BIAS: tl.constexpr,
|
210 |
+
HAS_DROPOUT: tl.constexpr,
|
211 |
+
STORE_DROPOUT_MASK: tl.constexpr,
|
212 |
+
HAS_ROWSCALE: tl.constexpr,
|
213 |
+
HAS_X1: tl.constexpr,
|
214 |
+
HAS_W1: tl.constexpr,
|
215 |
+
HAS_B1: tl.constexpr,
|
216 |
+
):
|
217 |
+
# Map the program id to the row of X and Y it should compute.
|
218 |
+
row = tl.program_id(0)
|
219 |
+
X += row * stride_x_row
|
220 |
+
Y += row * stride_y_row
|
221 |
+
if HAS_RESIDUAL:
|
222 |
+
RESIDUAL += row * stride_res_row
|
223 |
+
if STORE_RESIDUAL_OUT:
|
224 |
+
RESIDUAL_OUT += row * stride_res_out_row
|
225 |
+
if HAS_X1:
|
226 |
+
X1 += row * stride_x1_row
|
227 |
+
if HAS_W1:
|
228 |
+
Y1 += row * stride_y1_row
|
229 |
+
# Compute mean and variance
|
230 |
+
cols = tl.arange(0, BLOCK_N)
|
231 |
+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
232 |
+
if HAS_ROWSCALE:
|
233 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
234 |
+
x *= rowscale
|
235 |
+
if HAS_DROPOUT:
|
236 |
+
# Compute dropout mask
|
237 |
+
# 7 rounds is good enough, and reduces register pressure
|
238 |
+
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
239 |
+
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
|
240 |
+
if STORE_DROPOUT_MASK:
|
241 |
+
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
|
242 |
+
if HAS_X1:
|
243 |
+
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
|
244 |
+
if HAS_ROWSCALE:
|
245 |
+
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
|
246 |
+
x1 *= rowscale
|
247 |
+
if HAS_DROPOUT:
|
248 |
+
# Compute dropout mask
|
249 |
+
# 7 rounds is good enough, and reduces register pressure
|
250 |
+
keep_mask = (
|
251 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
252 |
+
)
|
253 |
+
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
|
254 |
+
if STORE_DROPOUT_MASK:
|
255 |
+
tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
|
256 |
+
x += x1
|
257 |
+
if HAS_RESIDUAL:
|
258 |
+
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
|
259 |
+
x += residual
|
260 |
+
if STORE_RESIDUAL_OUT:
|
261 |
+
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
262 |
+
if not IS_RMS_NORM:
|
263 |
+
mean = tl.sum(x, axis=0) / N
|
264 |
+
tl.store(Mean + row, mean)
|
265 |
+
xbar = tl.where(cols < N, x - mean, 0.0)
|
266 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
267 |
+
else:
|
268 |
+
xbar = tl.where(cols < N, x, 0.0)
|
269 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
270 |
+
rstd = 1 / tl.sqrt(var + eps)
|
271 |
+
tl.store(Rstd + row, rstd)
|
272 |
+
# Normalize and apply linear transformation
|
273 |
+
mask = cols < N
|
274 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
275 |
+
if HAS_BIAS:
|
276 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
277 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
278 |
+
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
279 |
+
# Write output
|
280 |
+
tl.store(Y + cols, y, mask=mask)
|
281 |
+
if HAS_W1:
|
282 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
283 |
+
if HAS_B1:
|
284 |
+
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
|
285 |
+
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
|
286 |
+
tl.store(Y1 + cols, y1, mask=mask)
|
287 |
+
|
288 |
+
|
289 |
+
def _layer_norm_fwd(
|
290 |
+
x,
|
291 |
+
weight,
|
292 |
+
bias,
|
293 |
+
eps,
|
294 |
+
residual=None,
|
295 |
+
x1=None,
|
296 |
+
weight1=None,
|
297 |
+
bias1=None,
|
298 |
+
dropout_p=0.0,
|
299 |
+
rowscale=None,
|
300 |
+
out_dtype=None,
|
301 |
+
residual_dtype=None,
|
302 |
+
is_rms_norm=False,
|
303 |
+
return_dropout_mask=False,
|
304 |
+
):
|
305 |
+
if residual is not None:
|
306 |
+
residual_dtype = residual.dtype
|
307 |
+
M, N = x.shape
|
308 |
+
assert x.stride(-1) == 1
|
309 |
+
if residual is not None:
|
310 |
+
assert residual.stride(-1) == 1
|
311 |
+
assert residual.shape == (M, N)
|
312 |
+
assert weight.shape == (N,)
|
313 |
+
assert weight.stride(-1) == 1
|
314 |
+
if bias is not None:
|
315 |
+
assert bias.stride(-1) == 1
|
316 |
+
assert bias.shape == (N,)
|
317 |
+
if x1 is not None:
|
318 |
+
assert x1.shape == x.shape
|
319 |
+
assert rowscale is None
|
320 |
+
assert x1.stride(-1) == 1
|
321 |
+
if weight1 is not None:
|
322 |
+
assert weight1.shape == (N,)
|
323 |
+
assert weight1.stride(-1) == 1
|
324 |
+
if bias1 is not None:
|
325 |
+
assert bias1.shape == (N,)
|
326 |
+
assert bias1.stride(-1) == 1
|
327 |
+
if rowscale is not None:
|
328 |
+
assert rowscale.is_contiguous()
|
329 |
+
assert rowscale.shape == (M,)
|
330 |
+
# allocate output
|
331 |
+
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
332 |
+
assert y.stride(-1) == 1
|
333 |
+
if weight1 is not None:
|
334 |
+
y1 = torch.empty_like(y)
|
335 |
+
assert y1.stride(-1) == 1
|
336 |
+
else:
|
337 |
+
y1 = None
|
338 |
+
if (
|
339 |
+
residual is not None
|
340 |
+
or (residual_dtype is not None and residual_dtype != x.dtype)
|
341 |
+
or dropout_p > 0.0
|
342 |
+
or rowscale is not None
|
343 |
+
or x1 is not None
|
344 |
+
):
|
345 |
+
residual_out = torch.empty(
|
346 |
+
M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
|
347 |
+
)
|
348 |
+
assert residual_out.stride(-1) == 1
|
349 |
+
else:
|
350 |
+
residual_out = None
|
351 |
+
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
352 |
+
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
|
353 |
+
if dropout_p > 0.0:
|
354 |
+
seeds = torch.randint(
|
355 |
+
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
|
356 |
+
)
|
357 |
+
else:
|
358 |
+
seeds = None
|
359 |
+
if return_dropout_mask and dropout_p > 0.0:
|
360 |
+
dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
|
361 |
+
else:
|
362 |
+
dropout_mask = None
|
363 |
+
# Less than 64KB per feature: enqueue fused kernel
|
364 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
365 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
366 |
+
if N > BLOCK_N:
|
367 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
368 |
+
with torch.cuda.device(x.device.index):
|
369 |
+
_layer_norm_fwd_1pass_kernel[(M,)](
|
370 |
+
x,
|
371 |
+
y,
|
372 |
+
weight,
|
373 |
+
bias,
|
374 |
+
residual,
|
375 |
+
x1,
|
376 |
+
weight1,
|
377 |
+
bias1,
|
378 |
+
y1,
|
379 |
+
residual_out,
|
380 |
+
rowscale,
|
381 |
+
seeds,
|
382 |
+
dropout_mask,
|
383 |
+
mean,
|
384 |
+
rstd,
|
385 |
+
x.stride(0),
|
386 |
+
y.stride(0),
|
387 |
+
residual.stride(0) if residual is not None else 0,
|
388 |
+
residual_out.stride(0) if residual_out is not None else 0,
|
389 |
+
x1.stride(0) if x1 is not None else 0,
|
390 |
+
y1.stride(0) if y1 is not None else 0,
|
391 |
+
M,
|
392 |
+
N,
|
393 |
+
eps,
|
394 |
+
dropout_p,
|
395 |
+
is_rms_norm,
|
396 |
+
BLOCK_N,
|
397 |
+
residual is not None,
|
398 |
+
residual_out is not None,
|
399 |
+
bias is not None,
|
400 |
+
dropout_p > 0.0,
|
401 |
+
dropout_mask is not None,
|
402 |
+
rowscale is not None,
|
403 |
+
)
|
404 |
+
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
|
405 |
+
if dropout_mask is not None and x1 is not None:
|
406 |
+
dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
|
407 |
+
else:
|
408 |
+
dropout_mask1 = None
|
409 |
+
return (
|
410 |
+
y,
|
411 |
+
y1,
|
412 |
+
mean,
|
413 |
+
rstd,
|
414 |
+
residual_out if residual_out is not None else x,
|
415 |
+
seeds,
|
416 |
+
dropout_mask,
|
417 |
+
dropout_mask1,
|
418 |
+
)
|
419 |
+
|
420 |
+
|
421 |
+
@triton.autotune(
|
422 |
+
configs=pruned_configs_autotune,
|
423 |
+
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
|
424 |
+
)
|
425 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
426 |
+
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
427 |
+
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
428 |
+
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
|
429 |
+
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
|
430 |
+
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
|
431 |
+
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
|
432 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
433 |
+
@triton.jit
|
434 |
+
def _layer_norm_bwd_kernel(
|
435 |
+
X, # pointer to the input
|
436 |
+
W, # pointer to the weights
|
437 |
+
B, # pointer to the biases
|
438 |
+
Y, # pointer to the output to be recomputed
|
439 |
+
DY, # pointer to the output gradient
|
440 |
+
DX, # pointer to the input gradient
|
441 |
+
DW, # pointer to the partial sum of weights gradient
|
442 |
+
DB, # pointer to the partial sum of biases gradient
|
443 |
+
DRESIDUAL,
|
444 |
+
W1,
|
445 |
+
DY1,
|
446 |
+
DX1,
|
447 |
+
DW1,
|
448 |
+
DB1,
|
449 |
+
DRESIDUAL_IN,
|
450 |
+
ROWSCALE,
|
451 |
+
SEEDS,
|
452 |
+
Mean, # pointer to the mean
|
453 |
+
Rstd, # pointer to the 1/std
|
454 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
455 |
+
stride_y_row,
|
456 |
+
stride_dy_row,
|
457 |
+
stride_dx_row,
|
458 |
+
stride_dres_row,
|
459 |
+
stride_dy1_row,
|
460 |
+
stride_dx1_row,
|
461 |
+
stride_dres_in_row,
|
462 |
+
M, # number of rows in X
|
463 |
+
N, # number of columns in X
|
464 |
+
eps, # epsilon to avoid division by zero
|
465 |
+
dropout_p,
|
466 |
+
rows_per_program,
|
467 |
+
IS_RMS_NORM: tl.constexpr,
|
468 |
+
BLOCK_N: tl.constexpr,
|
469 |
+
HAS_DRESIDUAL: tl.constexpr,
|
470 |
+
STORE_DRESIDUAL: tl.constexpr,
|
471 |
+
HAS_BIAS: tl.constexpr,
|
472 |
+
HAS_DROPOUT: tl.constexpr,
|
473 |
+
HAS_ROWSCALE: tl.constexpr,
|
474 |
+
HAS_DY1: tl.constexpr,
|
475 |
+
HAS_DX1: tl.constexpr,
|
476 |
+
HAS_B1: tl.constexpr,
|
477 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
478 |
+
):
|
479 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
480 |
+
row_block_id = tl.program_id(0)
|
481 |
+
row_start = row_block_id * rows_per_program
|
482 |
+
# Do not early exit if row_start >= M, because we need to write DW and DB
|
483 |
+
cols = tl.arange(0, BLOCK_N)
|
484 |
+
mask = cols < N
|
485 |
+
X += row_start * stride_x_row
|
486 |
+
if HAS_DRESIDUAL:
|
487 |
+
DRESIDUAL += row_start * stride_dres_row
|
488 |
+
if STORE_DRESIDUAL:
|
489 |
+
DRESIDUAL_IN += row_start * stride_dres_in_row
|
490 |
+
DY += row_start * stride_dy_row
|
491 |
+
DX += row_start * stride_dx_row
|
492 |
+
if HAS_DY1:
|
493 |
+
DY1 += row_start * stride_dy1_row
|
494 |
+
if HAS_DX1:
|
495 |
+
DX1 += row_start * stride_dx1_row
|
496 |
+
if RECOMPUTE_OUTPUT:
|
497 |
+
Y += row_start * stride_y_row
|
498 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
499 |
+
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
500 |
+
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
501 |
+
if HAS_DY1:
|
502 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
503 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
504 |
+
if HAS_BIAS:
|
505 |
+
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
506 |
+
if HAS_DY1:
|
507 |
+
dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
508 |
+
if HAS_B1:
|
509 |
+
db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
510 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
511 |
+
for row in range(row_start, row_end):
|
512 |
+
# Load data to SRAM
|
513 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
514 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
515 |
+
if HAS_DY1:
|
516 |
+
dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
|
517 |
+
if not IS_RMS_NORM:
|
518 |
+
mean = tl.load(Mean + row)
|
519 |
+
rstd = tl.load(Rstd + row)
|
520 |
+
# Compute dx
|
521 |
+
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
522 |
+
xhat = tl.where(mask, xhat, 0.0)
|
523 |
+
if RECOMPUTE_OUTPUT:
|
524 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
525 |
+
tl.store(Y + cols, y, mask=mask)
|
526 |
+
wdy = w * dy
|
527 |
+
dw += dy * xhat
|
528 |
+
if HAS_BIAS:
|
529 |
+
db += dy
|
530 |
+
if HAS_DY1:
|
531 |
+
wdy += w1 * dy1
|
532 |
+
dw1 += dy1 * xhat
|
533 |
+
if HAS_B1:
|
534 |
+
db1 += dy1
|
535 |
+
if not IS_RMS_NORM:
|
536 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
537 |
+
c2 = tl.sum(wdy, axis=0) / N
|
538 |
+
dx = (wdy - (xhat * c1 + c2)) * rstd
|
539 |
+
else:
|
540 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
541 |
+
dx = (wdy - xhat * c1) * rstd
|
542 |
+
if HAS_DRESIDUAL:
|
543 |
+
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
544 |
+
dx += dres
|
545 |
+
# Write dx
|
546 |
+
if STORE_DRESIDUAL:
|
547 |
+
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
548 |
+
if HAS_DX1:
|
549 |
+
if HAS_DROPOUT:
|
550 |
+
keep_mask = (
|
551 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
552 |
+
)
|
553 |
+
dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
554 |
+
else:
|
555 |
+
dx1 = dx
|
556 |
+
tl.store(DX1 + cols, dx1, mask=mask)
|
557 |
+
if HAS_DROPOUT:
|
558 |
+
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
559 |
+
dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
560 |
+
if HAS_ROWSCALE:
|
561 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
562 |
+
dx *= rowscale
|
563 |
+
tl.store(DX + cols, dx, mask=mask)
|
564 |
+
|
565 |
+
X += stride_x_row
|
566 |
+
if HAS_DRESIDUAL:
|
567 |
+
DRESIDUAL += stride_dres_row
|
568 |
+
if STORE_DRESIDUAL:
|
569 |
+
DRESIDUAL_IN += stride_dres_in_row
|
570 |
+
if RECOMPUTE_OUTPUT:
|
571 |
+
Y += stride_y_row
|
572 |
+
DY += stride_dy_row
|
573 |
+
DX += stride_dx_row
|
574 |
+
if HAS_DY1:
|
575 |
+
DY1 += stride_dy1_row
|
576 |
+
if HAS_DX1:
|
577 |
+
DX1 += stride_dx1_row
|
578 |
+
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
579 |
+
if HAS_BIAS:
|
580 |
+
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
581 |
+
if HAS_DY1:
|
582 |
+
tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
|
583 |
+
if HAS_B1:
|
584 |
+
tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
|
585 |
+
|
586 |
+
|
587 |
+
def _layer_norm_bwd(
|
588 |
+
dy,
|
589 |
+
x,
|
590 |
+
weight,
|
591 |
+
bias,
|
592 |
+
eps,
|
593 |
+
mean,
|
594 |
+
rstd,
|
595 |
+
dresidual=None,
|
596 |
+
dy1=None,
|
597 |
+
weight1=None,
|
598 |
+
bias1=None,
|
599 |
+
seeds=None,
|
600 |
+
dropout_p=0.0,
|
601 |
+
rowscale=None,
|
602 |
+
has_residual=False,
|
603 |
+
has_x1=False,
|
604 |
+
is_rms_norm=False,
|
605 |
+
x_dtype=None,
|
606 |
+
recompute_output=False,
|
607 |
+
):
|
608 |
+
M, N = x.shape
|
609 |
+
assert x.stride(-1) == 1
|
610 |
+
assert dy.stride(-1) == 1
|
611 |
+
assert dy.shape == (M, N)
|
612 |
+
if dresidual is not None:
|
613 |
+
assert dresidual.stride(-1) == 1
|
614 |
+
assert dresidual.shape == (M, N)
|
615 |
+
assert weight.shape == (N,)
|
616 |
+
assert weight.stride(-1) == 1
|
617 |
+
if bias is not None:
|
618 |
+
assert bias.stride(-1) == 1
|
619 |
+
assert bias.shape == (N,)
|
620 |
+
if dy1 is not None:
|
621 |
+
assert weight1 is not None
|
622 |
+
assert dy1.shape == dy.shape
|
623 |
+
assert dy1.stride(-1) == 1
|
624 |
+
if weight1 is not None:
|
625 |
+
assert weight1.shape == (N,)
|
626 |
+
assert weight1.stride(-1) == 1
|
627 |
+
if bias1 is not None:
|
628 |
+
assert bias1.shape == (N,)
|
629 |
+
assert bias1.stride(-1) == 1
|
630 |
+
if seeds is not None:
|
631 |
+
assert seeds.is_contiguous()
|
632 |
+
assert seeds.shape == (M if not has_x1 else M * 2,)
|
633 |
+
if rowscale is not None:
|
634 |
+
assert rowscale.is_contiguous()
|
635 |
+
assert rowscale.shape == (M,)
|
636 |
+
# allocate output
|
637 |
+
dx = (
|
638 |
+
torch.empty_like(x)
|
639 |
+
if x_dtype is None
|
640 |
+
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
641 |
+
)
|
642 |
+
dresidual_in = (
|
643 |
+
torch.empty_like(x)
|
644 |
+
if has_residual
|
645 |
+
and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
|
646 |
+
else None
|
647 |
+
)
|
648 |
+
dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
|
649 |
+
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
|
650 |
+
if recompute_output:
|
651 |
+
assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
|
652 |
+
|
653 |
+
# Less than 64KB per feature: enqueue fused kernel
|
654 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
655 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
656 |
+
if N > BLOCK_N:
|
657 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
658 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
659 |
+
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
660 |
+
_db = (
|
661 |
+
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
662 |
+
if bias is not None
|
663 |
+
else None
|
664 |
+
)
|
665 |
+
_dw1 = torch.empty_like(_dw) if weight1 is not None else None
|
666 |
+
_db1 = torch.empty_like(_db) if bias1 is not None else None
|
667 |
+
rows_per_program = math.ceil(M / sm_count)
|
668 |
+
grid = (sm_count,)
|
669 |
+
with torch.cuda.device(x.device.index):
|
670 |
+
_layer_norm_bwd_kernel[grid](
|
671 |
+
x,
|
672 |
+
weight,
|
673 |
+
bias,
|
674 |
+
y,
|
675 |
+
dy,
|
676 |
+
dx,
|
677 |
+
_dw,
|
678 |
+
_db,
|
679 |
+
dresidual,
|
680 |
+
weight1,
|
681 |
+
dy1,
|
682 |
+
dx1,
|
683 |
+
_dw1,
|
684 |
+
_db1,
|
685 |
+
dresidual_in,
|
686 |
+
rowscale,
|
687 |
+
seeds,
|
688 |
+
mean,
|
689 |
+
rstd,
|
690 |
+
x.stride(0),
|
691 |
+
0 if not recompute_output else y.stride(0),
|
692 |
+
dy.stride(0),
|
693 |
+
dx.stride(0),
|
694 |
+
dresidual.stride(0) if dresidual is not None else 0,
|
695 |
+
dy1.stride(0) if dy1 is not None else 0,
|
696 |
+
dx1.stride(0) if dx1 is not None else 0,
|
697 |
+
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
698 |
+
M,
|
699 |
+
N,
|
700 |
+
eps,
|
701 |
+
dropout_p,
|
702 |
+
rows_per_program,
|
703 |
+
is_rms_norm,
|
704 |
+
BLOCK_N,
|
705 |
+
dresidual is not None,
|
706 |
+
dresidual_in is not None,
|
707 |
+
bias is not None,
|
708 |
+
dropout_p > 0.0,
|
709 |
+
)
|
710 |
+
dw = _dw.sum(0).to(weight.dtype)
|
711 |
+
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
712 |
+
dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
|
713 |
+
db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
|
714 |
+
# Don't need to compute dresidual_in separately in this case
|
715 |
+
if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
|
716 |
+
dresidual_in = dx
|
717 |
+
if has_x1 and dropout_p == 0.0:
|
718 |
+
dx1 = dx
|
719 |
+
return (
|
720 |
+
(dx, dw, db, dresidual_in, dx1, dw1, db1)
|
721 |
+
if not recompute_output
|
722 |
+
else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
|
723 |
+
)
|
724 |
+
|
725 |
+
|
726 |
+
class LayerNormFn(torch.autograd.Function):
|
727 |
+
@staticmethod
|
728 |
+
def forward(
|
729 |
+
ctx,
|
730 |
+
x,
|
731 |
+
weight,
|
732 |
+
bias,
|
733 |
+
residual=None,
|
734 |
+
x1=None,
|
735 |
+
weight1=None,
|
736 |
+
bias1=None,
|
737 |
+
eps=1e-6,
|
738 |
+
dropout_p=0.0,
|
739 |
+
rowscale=None,
|
740 |
+
prenorm=False,
|
741 |
+
residual_in_fp32=False,
|
742 |
+
is_rms_norm=False,
|
743 |
+
return_dropout_mask=False,
|
744 |
+
):
|
745 |
+
x_shape_og = x.shape
|
746 |
+
# reshape input data into 2D tensor
|
747 |
+
x = x.reshape(-1, x.shape[-1])
|
748 |
+
if x.stride(-1) != 1:
|
749 |
+
x = x.contiguous()
|
750 |
+
if residual is not None:
|
751 |
+
assert residual.shape == x_shape_og
|
752 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
753 |
+
if residual.stride(-1) != 1:
|
754 |
+
residual = residual.contiguous()
|
755 |
+
if x1 is not None:
|
756 |
+
assert x1.shape == x_shape_og
|
757 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
758 |
+
x1 = x1.reshape(-1, x1.shape[-1])
|
759 |
+
if x1.stride(-1) != 1:
|
760 |
+
x1 = x1.contiguous()
|
761 |
+
weight = weight.contiguous()
|
762 |
+
if bias is not None:
|
763 |
+
bias = bias.contiguous()
|
764 |
+
if weight1 is not None:
|
765 |
+
weight1 = weight1.contiguous()
|
766 |
+
if bias1 is not None:
|
767 |
+
bias1 = bias1.contiguous()
|
768 |
+
if rowscale is not None:
|
769 |
+
rowscale = rowscale.reshape(-1).contiguous()
|
770 |
+
residual_dtype = (
|
771 |
+
residual.dtype
|
772 |
+
if residual is not None
|
773 |
+
else (torch.float32 if residual_in_fp32 else None)
|
774 |
+
)
|
775 |
+
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
|
776 |
+
x,
|
777 |
+
weight,
|
778 |
+
bias,
|
779 |
+
eps,
|
780 |
+
residual,
|
781 |
+
x1,
|
782 |
+
weight1,
|
783 |
+
bias1,
|
784 |
+
dropout_p=dropout_p,
|
785 |
+
rowscale=rowscale,
|
786 |
+
residual_dtype=residual_dtype,
|
787 |
+
is_rms_norm=is_rms_norm,
|
788 |
+
return_dropout_mask=return_dropout_mask,
|
789 |
+
)
|
790 |
+
ctx.save_for_backward(
|
791 |
+
residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
|
792 |
+
)
|
793 |
+
ctx.x_shape_og = x_shape_og
|
794 |
+
ctx.eps = eps
|
795 |
+
ctx.dropout_p = dropout_p
|
796 |
+
ctx.is_rms_norm = is_rms_norm
|
797 |
+
ctx.has_residual = residual is not None
|
798 |
+
ctx.has_x1 = x1 is not None
|
799 |
+
ctx.prenorm = prenorm
|
800 |
+
ctx.x_dtype = x.dtype
|
801 |
+
y = y.reshape(x_shape_og)
|
802 |
+
y1 = y1.reshape(x_shape_og) if y1 is not None else None
|
803 |
+
residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
|
804 |
+
dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
|
805 |
+
dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
|
806 |
+
if not return_dropout_mask:
|
807 |
+
if weight1 is None:
|
808 |
+
return y if not prenorm else (y, residual_out)
|
809 |
+
else:
|
810 |
+
return (y, y1) if not prenorm else (y, y1, residual_out)
|
811 |
+
else:
|
812 |
+
if weight1 is None:
|
813 |
+
return (
|
814 |
+
(y, dropout_mask, dropout_mask1)
|
815 |
+
if not prenorm
|
816 |
+
else (y, residual_out, dropout_mask, dropout_mask1)
|
817 |
+
)
|
818 |
+
else:
|
819 |
+
return (
|
820 |
+
(y, y1, dropout_mask, dropout_mask1)
|
821 |
+
if not prenorm
|
822 |
+
else (y, y1, residual_out, dropout_mask, dropout_mask1)
|
823 |
+
)
|
824 |
+
|
825 |
+
@staticmethod
|
826 |
+
def backward(ctx, dy, *args):
|
827 |
+
x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
|
828 |
+
dy = dy.reshape(-1, dy.shape[-1])
|
829 |
+
if dy.stride(-1) != 1:
|
830 |
+
dy = dy.contiguous()
|
831 |
+
assert dy.shape == x.shape
|
832 |
+
if weight1 is not None:
|
833 |
+
dy1, args = args[0], args[1:]
|
834 |
+
dy1 = dy1.reshape(-1, dy1.shape[-1])
|
835 |
+
if dy1.stride(-1) != 1:
|
836 |
+
dy1 = dy1.contiguous()
|
837 |
+
assert dy1.shape == x.shape
|
838 |
+
else:
|
839 |
+
dy1 = None
|
840 |
+
if ctx.prenorm:
|
841 |
+
dresidual = args[0]
|
842 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
843 |
+
if dresidual.stride(-1) != 1:
|
844 |
+
dresidual = dresidual.contiguous()
|
845 |
+
assert dresidual.shape == x.shape
|
846 |
+
else:
|
847 |
+
dresidual = None
|
848 |
+
dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
|
849 |
+
dy,
|
850 |
+
x,
|
851 |
+
weight,
|
852 |
+
bias,
|
853 |
+
ctx.eps,
|
854 |
+
mean,
|
855 |
+
rstd,
|
856 |
+
dresidual,
|
857 |
+
dy1,
|
858 |
+
weight1,
|
859 |
+
bias1,
|
860 |
+
seeds,
|
861 |
+
ctx.dropout_p,
|
862 |
+
rowscale,
|
863 |
+
ctx.has_residual,
|
864 |
+
ctx.has_x1,
|
865 |
+
ctx.is_rms_norm,
|
866 |
+
x_dtype=ctx.x_dtype,
|
867 |
+
)
|
868 |
+
return (
|
869 |
+
dx.reshape(ctx.x_shape_og),
|
870 |
+
dw,
|
871 |
+
db,
|
872 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
873 |
+
dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
|
874 |
+
dw1,
|
875 |
+
db1,
|
876 |
+
None,
|
877 |
+
None,
|
878 |
+
None,
|
879 |
+
None,
|
880 |
+
None,
|
881 |
+
None,
|
882 |
+
None,
|
883 |
+
)
|
884 |
+
|
885 |
+
|
886 |
+
def layer_norm_fn(
|
887 |
+
x,
|
888 |
+
weight,
|
889 |
+
bias,
|
890 |
+
residual=None,
|
891 |
+
x1=None,
|
892 |
+
weight1=None,
|
893 |
+
bias1=None,
|
894 |
+
eps=1e-6,
|
895 |
+
dropout_p=0.0,
|
896 |
+
rowscale=None,
|
897 |
+
prenorm=False,
|
898 |
+
residual_in_fp32=False,
|
899 |
+
is_rms_norm=False,
|
900 |
+
return_dropout_mask=False,
|
901 |
+
):
|
902 |
+
return LayerNormFn.apply(
|
903 |
+
x,
|
904 |
+
weight,
|
905 |
+
bias,
|
906 |
+
residual,
|
907 |
+
x1,
|
908 |
+
weight1,
|
909 |
+
bias1,
|
910 |
+
eps,
|
911 |
+
dropout_p,
|
912 |
+
rowscale,
|
913 |
+
prenorm,
|
914 |
+
residual_in_fp32,
|
915 |
+
is_rms_norm,
|
916 |
+
return_dropout_mask,
|
917 |
+
)
|
918 |
+
|
919 |
+
|
920 |
+
def rms_norm_fn(
|
921 |
+
x,
|
922 |
+
weight,
|
923 |
+
bias,
|
924 |
+
residual=None,
|
925 |
+
x1=None,
|
926 |
+
weight1=None,
|
927 |
+
bias1=None,
|
928 |
+
eps=1e-6,
|
929 |
+
dropout_p=0.0,
|
930 |
+
rowscale=None,
|
931 |
+
prenorm=False,
|
932 |
+
residual_in_fp32=False,
|
933 |
+
return_dropout_mask=False,
|
934 |
+
):
|
935 |
+
return LayerNormFn.apply(
|
936 |
+
x,
|
937 |
+
weight,
|
938 |
+
bias,
|
939 |
+
residual,
|
940 |
+
x1,
|
941 |
+
weight1,
|
942 |
+
bias1,
|
943 |
+
eps,
|
944 |
+
dropout_p,
|
945 |
+
rowscale,
|
946 |
+
prenorm,
|
947 |
+
residual_in_fp32,
|
948 |
+
True,
|
949 |
+
return_dropout_mask,
|
950 |
+
)
|
951 |
+
|
952 |
+
|
953 |
+
class RMSNorm(torch.nn.Module):
|
954 |
+
|
955 |
+
def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
|
956 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
957 |
+
super().__init__()
|
958 |
+
self.eps = eps
|
959 |
+
if dropout_p > 0.0:
|
960 |
+
self.drop = torch.nn.Dropout(dropout_p)
|
961 |
+
else:
|
962 |
+
self.drop = None
|
963 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
964 |
+
self.register_parameter("bias", None)
|
965 |
+
self.reset_parameters()
|
966 |
+
|
967 |
+
def reset_parameters(self):
|
968 |
+
torch.nn.init.ones_(self.weight)
|
969 |
+
|
970 |
+
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
971 |
+
return rms_norm_fn(
|
972 |
+
x,
|
973 |
+
self.weight,
|
974 |
+
self.bias,
|
975 |
+
residual=residual,
|
976 |
+
eps=self.eps,
|
977 |
+
dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
|
978 |
+
prenorm=prenorm,
|
979 |
+
residual_in_fp32=residual_in_fp32,
|
980 |
+
)
|
981 |
+
|
982 |
+
|
983 |
+
class LayerNormLinearFn(torch.autograd.Function):
|
984 |
+
@staticmethod
|
985 |
+
@custom_fwd
|
986 |
+
def forward(
|
987 |
+
ctx,
|
988 |
+
x,
|
989 |
+
norm_weight,
|
990 |
+
norm_bias,
|
991 |
+
linear_weight,
|
992 |
+
linear_bias,
|
993 |
+
residual=None,
|
994 |
+
eps=1e-6,
|
995 |
+
prenorm=False,
|
996 |
+
residual_in_fp32=False,
|
997 |
+
is_rms_norm=False,
|
998 |
+
):
|
999 |
+
x_shape_og = x.shape
|
1000 |
+
# reshape input data into 2D tensor
|
1001 |
+
x = x.reshape(-1, x.shape[-1])
|
1002 |
+
if x.stride(-1) != 1:
|
1003 |
+
x = x.contiguous()
|
1004 |
+
if residual is not None:
|
1005 |
+
assert residual.shape == x_shape_og
|
1006 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
1007 |
+
if residual.stride(-1) != 1:
|
1008 |
+
residual = residual.contiguous()
|
1009 |
+
norm_weight = norm_weight.contiguous()
|
1010 |
+
if norm_bias is not None:
|
1011 |
+
norm_bias = norm_bias.contiguous()
|
1012 |
+
residual_dtype = (
|
1013 |
+
residual.dtype
|
1014 |
+
if residual is not None
|
1015 |
+
else (torch.float32 if residual_in_fp32 else None)
|
1016 |
+
)
|
1017 |
+
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
|
1018 |
+
x,
|
1019 |
+
norm_weight,
|
1020 |
+
norm_bias,
|
1021 |
+
eps,
|
1022 |
+
residual,
|
1023 |
+
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
|
1024 |
+
residual_dtype=residual_dtype,
|
1025 |
+
is_rms_norm=is_rms_norm,
|
1026 |
+
)
|
1027 |
+
y = y.reshape(x_shape_og)
|
1028 |
+
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
1029 |
+
linear_weight = linear_weight.to(dtype)
|
1030 |
+
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
1031 |
+
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
1032 |
+
# We don't store y, will be recomputed in the backward pass to save memory
|
1033 |
+
ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
|
1034 |
+
ctx.x_shape_og = x_shape_og
|
1035 |
+
ctx.eps = eps
|
1036 |
+
ctx.is_rms_norm = is_rms_norm
|
1037 |
+
ctx.has_residual = residual is not None
|
1038 |
+
ctx.prenorm = prenorm
|
1039 |
+
ctx.x_dtype = x.dtype
|
1040 |
+
ctx.linear_bias_is_none = linear_bias is None
|
1041 |
+
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
1042 |
+
|
1043 |
+
@staticmethod
|
1044 |
+
@custom_bwd
|
1045 |
+
def backward(ctx, dout, *args):
|
1046 |
+
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
1047 |
+
dout = dout.reshape(-1, dout.shape[-1])
|
1048 |
+
dy = F.linear(dout, linear_weight.t())
|
1049 |
+
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
1050 |
+
if dy.stride(-1) != 1:
|
1051 |
+
dy = dy.contiguous()
|
1052 |
+
assert dy.shape == x.shape
|
1053 |
+
if ctx.prenorm:
|
1054 |
+
dresidual = args[0]
|
1055 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
1056 |
+
if dresidual.stride(-1) != 1:
|
1057 |
+
dresidual = dresidual.contiguous()
|
1058 |
+
assert dresidual.shape == x.shape
|
1059 |
+
else:
|
1060 |
+
dresidual = None
|
1061 |
+
dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
|
1062 |
+
dy,
|
1063 |
+
x,
|
1064 |
+
norm_weight,
|
1065 |
+
norm_bias,
|
1066 |
+
ctx.eps,
|
1067 |
+
mean,
|
1068 |
+
rstd,
|
1069 |
+
dresidual=dresidual,
|
1070 |
+
has_residual=ctx.has_residual,
|
1071 |
+
is_rms_norm=ctx.is_rms_norm,
|
1072 |
+
x_dtype=ctx.x_dtype,
|
1073 |
+
recompute_output=True,
|
1074 |
+
)
|
1075 |
+
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
1076 |
+
return (
|
1077 |
+
dx.reshape(ctx.x_shape_og),
|
1078 |
+
dnorm_weight,
|
1079 |
+
dnorm_bias,
|
1080 |
+
dlinear_weight,
|
1081 |
+
dlinear_bias,
|
1082 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
1083 |
+
None,
|
1084 |
+
None,
|
1085 |
+
None,
|
1086 |
+
None,
|
1087 |
+
)
|
1088 |
+
|
1089 |
+
|
1090 |
+
def layer_norm_linear_fn(
|
1091 |
+
x,
|
1092 |
+
norm_weight,
|
1093 |
+
norm_bias,
|
1094 |
+
linear_weight,
|
1095 |
+
linear_bias,
|
1096 |
+
residual=None,
|
1097 |
+
eps=1e-6,
|
1098 |
+
prenorm=False,
|
1099 |
+
residual_in_fp32=False,
|
1100 |
+
is_rms_norm=False,
|
1101 |
+
):
|
1102 |
+
return LayerNormLinearFn.apply(
|
1103 |
+
x,
|
1104 |
+
norm_weight,
|
1105 |
+
norm_bias,
|
1106 |
+
linear_weight,
|
1107 |
+
linear_bias,
|
1108 |
+
residual,
|
1109 |
+
eps,
|
1110 |
+
prenorm,
|
1111 |
+
residual_in_fp32,
|
1112 |
+
is_rms_norm,
|
1113 |
+
)
|
mamba/build/lib/mamba_ssm/ops/triton/layernorm_gated.py
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
3 |
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
4 |
+
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
5 |
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
import triton
|
13 |
+
import triton.language as tl
|
14 |
+
|
15 |
+
from einops import rearrange
|
16 |
+
|
17 |
+
|
18 |
+
def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
|
19 |
+
dtype = x.dtype
|
20 |
+
N = x.shape[-1]
|
21 |
+
weight = weight.float()
|
22 |
+
bias = bias.float() if bias is not None else None
|
23 |
+
if upcast:
|
24 |
+
x = x.float()
|
25 |
+
z = z.float() if z is not None else z
|
26 |
+
if z is not None and not norm_before_gate:
|
27 |
+
x = x * F.silu(z)
|
28 |
+
if group_size is None:
|
29 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
30 |
+
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
31 |
+
else:
|
32 |
+
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
33 |
+
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
|
34 |
+
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
35 |
+
if bias is not None:
|
36 |
+
out = out + bias
|
37 |
+
if z is not None and norm_before_gate:
|
38 |
+
out *= F.silu(z)
|
39 |
+
return out.to(dtype)
|
40 |
+
|
41 |
+
|
42 |
+
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
43 |
+
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
44 |
+
@triton.jit
|
45 |
+
def _layer_norm_fwd_1pass_kernel(
|
46 |
+
X, # pointer to the input
|
47 |
+
Y, # pointer to the output
|
48 |
+
W, # pointer to the weights
|
49 |
+
B, # pointer to the biases
|
50 |
+
Z, # pointer to the other branch
|
51 |
+
Mean, # pointer to the mean
|
52 |
+
Rstd, # pointer to the 1/std
|
53 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
54 |
+
stride_y_row,
|
55 |
+
stride_z_row,
|
56 |
+
M, # number of rows in X
|
57 |
+
N, # number of columns in X
|
58 |
+
eps, # epsilon to avoid division by zero
|
59 |
+
BLOCK_N: tl.constexpr,
|
60 |
+
HAS_BIAS: tl.constexpr,
|
61 |
+
HAS_Z: tl.constexpr,
|
62 |
+
NORM_BEFORE_GATE: tl.constexpr,
|
63 |
+
IS_RMS_NORM: tl.constexpr,
|
64 |
+
):
|
65 |
+
# Map the program id to the row of X and Y it should compute.
|
66 |
+
row = tl.program_id(0)
|
67 |
+
group = tl.program_id(1)
|
68 |
+
X += row * stride_x_row + group * N
|
69 |
+
Y += row * stride_y_row + group * N
|
70 |
+
if HAS_Z:
|
71 |
+
Z += row * stride_z_row + group * N
|
72 |
+
if not IS_RMS_NORM:
|
73 |
+
Mean += group * M
|
74 |
+
Rstd += group * M
|
75 |
+
W += group * N
|
76 |
+
if HAS_BIAS:
|
77 |
+
B += group * N
|
78 |
+
# Compute mean and variance
|
79 |
+
cols = tl.arange(0, BLOCK_N)
|
80 |
+
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
81 |
+
if HAS_Z and not NORM_BEFORE_GATE:
|
82 |
+
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
83 |
+
x *= z * tl.sigmoid(z)
|
84 |
+
if not IS_RMS_NORM:
|
85 |
+
mean = tl.sum(x, axis=0) / N
|
86 |
+
tl.store(Mean + row, mean)
|
87 |
+
xbar = tl.where(cols < N, x - mean, 0.)
|
88 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
89 |
+
else:
|
90 |
+
xbar = tl.where(cols < N, x, 0.)
|
91 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
92 |
+
rstd = 1 / tl.sqrt(var + eps)
|
93 |
+
tl.store(Rstd + row, rstd)
|
94 |
+
# Normalize and apply linear transformation
|
95 |
+
mask = cols < N
|
96 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
97 |
+
if HAS_BIAS:
|
98 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
99 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
100 |
+
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
101 |
+
if HAS_Z and NORM_BEFORE_GATE:
|
102 |
+
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
103 |
+
y *= z * tl.sigmoid(z)
|
104 |
+
# Write output
|
105 |
+
tl.store(Y + cols, y, mask=mask)
|
106 |
+
|
107 |
+
|
108 |
+
def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
|
109 |
+
M, N = x.shape
|
110 |
+
if group_size is None:
|
111 |
+
group_size = N
|
112 |
+
assert N % group_size == 0
|
113 |
+
ngroups = N // group_size
|
114 |
+
assert x.stride(-1) == 1
|
115 |
+
if z is not None:
|
116 |
+
assert z.stride(-1) == 1
|
117 |
+
assert z.shape == (M, N)
|
118 |
+
assert weight.shape == (N,)
|
119 |
+
assert weight.stride(-1) == 1
|
120 |
+
if bias is not None:
|
121 |
+
assert bias.stride(-1) == 1
|
122 |
+
assert bias.shape == (N,)
|
123 |
+
# allocate output
|
124 |
+
if out is not None:
|
125 |
+
assert out.shape == x.shape
|
126 |
+
else:
|
127 |
+
out = torch.empty_like(x)
|
128 |
+
assert out.stride(-1) == 1
|
129 |
+
mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
130 |
+
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
131 |
+
# Less than 64KB per feature: enqueue fused kernel
|
132 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
133 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
134 |
+
if group_size > BLOCK_N:
|
135 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
136 |
+
# heuristics for number of warps
|
137 |
+
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
138 |
+
grid = (M, ngroups)
|
139 |
+
with torch.cuda.device(x.device.index):
|
140 |
+
_layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,
|
141 |
+
x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,
|
142 |
+
M, group_size, eps,
|
143 |
+
BLOCK_N=BLOCK_N,
|
144 |
+
NORM_BEFORE_GATE=norm_before_gate,
|
145 |
+
IS_RMS_NORM=is_rms_norm,
|
146 |
+
num_warps=num_warps)
|
147 |
+
return out, mean, rstd
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
152 |
+
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
153 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
154 |
+
@triton.jit
|
155 |
+
def _layer_norm_bwd_kernel(
|
156 |
+
X, # pointer to the input
|
157 |
+
W, # pointer to the weights
|
158 |
+
B, # pointer to the biases
|
159 |
+
Z, # pointer to the other branch
|
160 |
+
Y, # pointer to the output to be recomputed
|
161 |
+
DY, # pointer to the output gradient
|
162 |
+
DX, # pointer to the input gradient
|
163 |
+
DW, # pointer to the partial sum of weights gradient
|
164 |
+
DB, # pointer to the partial sum of biases gradient
|
165 |
+
DZ, # pointer to the other branch
|
166 |
+
Mean, # pointer to the mean
|
167 |
+
Rstd, # pointer to the 1/std
|
168 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
169 |
+
stride_z_row,
|
170 |
+
stride_y_row,
|
171 |
+
stride_dy_row,
|
172 |
+
stride_dx_row,
|
173 |
+
stride_dz_row,
|
174 |
+
stride_dw_row,
|
175 |
+
stride_db_row,
|
176 |
+
M, # number of rows in X
|
177 |
+
N, # number of columns in X
|
178 |
+
eps, # epsilon to avoid division by zero
|
179 |
+
rows_per_program,
|
180 |
+
NORM_BEFORE_GATE: tl.constexpr,
|
181 |
+
IS_RMS_NORM: tl.constexpr,
|
182 |
+
HAS_BIAS: tl.constexpr,
|
183 |
+
HAS_Z: tl.constexpr,
|
184 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
185 |
+
BLOCK_N: tl.constexpr,
|
186 |
+
):
|
187 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
188 |
+
row_block_id = tl.program_id(0)
|
189 |
+
group = tl.program_id(1)
|
190 |
+
row_start = row_block_id * rows_per_program
|
191 |
+
cols = tl.arange(0, BLOCK_N)
|
192 |
+
mask = cols < N
|
193 |
+
X += row_start * stride_x_row + group * N
|
194 |
+
if HAS_Z:
|
195 |
+
Z += row_start * stride_z_row + group * N
|
196 |
+
DZ += row_start * stride_dz_row + group * N
|
197 |
+
DY += row_start * stride_dy_row + group * N
|
198 |
+
DX += row_start * stride_dx_row + group * N
|
199 |
+
if RECOMPUTE_OUTPUT:
|
200 |
+
Y += row_start * stride_y_row + group * N
|
201 |
+
if not IS_RMS_NORM:
|
202 |
+
Mean += group * M
|
203 |
+
Rstd += group * M
|
204 |
+
W += group * N
|
205 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
206 |
+
if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
|
207 |
+
B += group * N
|
208 |
+
b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
|
209 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
210 |
+
if HAS_BIAS:
|
211 |
+
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
212 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
213 |
+
for row in range(row_start, row_end):
|
214 |
+
# Load data to SRAM
|
215 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
216 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
217 |
+
if not IS_RMS_NORM:
|
218 |
+
mean = tl.load(Mean + row)
|
219 |
+
if HAS_Z and not NORM_BEFORE_GATE:
|
220 |
+
z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
|
221 |
+
x_og = x
|
222 |
+
x = x_og * z * tl.sigmoid(z)
|
223 |
+
rstd = tl.load(Rstd + row)
|
224 |
+
# Compute dx
|
225 |
+
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
226 |
+
xhat = tl.where(mask, xhat, 0.)
|
227 |
+
if HAS_Z and NORM_BEFORE_GATE:
|
228 |
+
z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
|
229 |
+
z_sigmoid = tl.sigmoid(z)
|
230 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
231 |
+
if RECOMPUTE_OUTPUT:
|
232 |
+
tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
|
233 |
+
dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
|
234 |
+
tl.store(DZ + cols, dz, mask=mask)
|
235 |
+
dy *= z * z_sigmoid
|
236 |
+
else:
|
237 |
+
if RECOMPUTE_OUTPUT:
|
238 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
239 |
+
tl.store(Y + cols, y, mask=mask)
|
240 |
+
wdy = w * dy
|
241 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
242 |
+
if not IS_RMS_NORM:
|
243 |
+
c2 = tl.sum(wdy, axis=0) / N
|
244 |
+
dx = (wdy - (xhat * c1 + c2)) * rstd
|
245 |
+
else:
|
246 |
+
dx = (wdy - xhat * c1) * rstd
|
247 |
+
dw += dy * xhat
|
248 |
+
if HAS_BIAS:
|
249 |
+
db += dy
|
250 |
+
if HAS_Z and not NORM_BEFORE_GATE:
|
251 |
+
z_sigmoid = tl.sigmoid(z)
|
252 |
+
dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
|
253 |
+
tl.store(DZ + cols, dz, mask=mask)
|
254 |
+
dx *= z * z_sigmoid
|
255 |
+
# Write dx
|
256 |
+
tl.store(DX + cols, dx, mask=mask)
|
257 |
+
|
258 |
+
X += stride_x_row
|
259 |
+
if HAS_Z:
|
260 |
+
Z += stride_z_row
|
261 |
+
DZ += stride_dz_row
|
262 |
+
if RECOMPUTE_OUTPUT:
|
263 |
+
Y += stride_y_row
|
264 |
+
DY += stride_dy_row
|
265 |
+
DX += stride_dx_row
|
266 |
+
tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
|
267 |
+
if HAS_BIAS:
|
268 |
+
tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
|
269 |
+
|
270 |
+
|
271 |
+
def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,
|
272 |
+
norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):
|
273 |
+
M, N = x.shape
|
274 |
+
if group_size is None:
|
275 |
+
group_size = N
|
276 |
+
assert N % group_size == 0
|
277 |
+
ngroups = N // group_size
|
278 |
+
assert x.stride(-1) == 1
|
279 |
+
assert dy.stride(-1) == 1
|
280 |
+
assert dy.shape == (M, N)
|
281 |
+
if z is not None:
|
282 |
+
assert z.stride(-1) == 1
|
283 |
+
assert z.shape == (M, N)
|
284 |
+
assert weight.shape == (N,)
|
285 |
+
assert weight.stride(-1) == 1
|
286 |
+
if bias is not None:
|
287 |
+
assert bias.stride(-1) == 1
|
288 |
+
assert bias.shape == (N,)
|
289 |
+
# allocate output
|
290 |
+
dx = torch.empty_like(x)
|
291 |
+
if dz is not None:
|
292 |
+
assert z is not None
|
293 |
+
assert dz.shape == z.shape
|
294 |
+
assert dz.stride(-1) == 1
|
295 |
+
else:
|
296 |
+
dz = torch.empty_like(z) if z is not None else None
|
297 |
+
if recompute_output:
|
298 |
+
if out is None:
|
299 |
+
out = torch.empty_like(x)
|
300 |
+
assert out.shape == x.shape
|
301 |
+
|
302 |
+
# Less than 64KB per feature: enqueue fused kernel
|
303 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
304 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
305 |
+
if group_size > BLOCK_N:
|
306 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
307 |
+
# heuristics for number of warps
|
308 |
+
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
309 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
310 |
+
# If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
|
311 |
+
# would limit the occupancy.
|
312 |
+
nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
|
313 |
+
_dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
|
314 |
+
_db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
|
315 |
+
rows_per_program = math.ceil(M / nrow_groups)
|
316 |
+
grid = (nrow_groups, ngroups)
|
317 |
+
with torch.cuda.device(x.device.index):
|
318 |
+
_layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,
|
319 |
+
dy, dx, _dw, _db, dz, mean, rstd,
|
320 |
+
x.stride(0),
|
321 |
+
z.stride(0) if z is not None else 0,
|
322 |
+
0 if not recompute_output else out.stride(0),
|
323 |
+
dy.stride(0), dx.stride(0),
|
324 |
+
dz.stride(0) if dz is not None else 0,
|
325 |
+
_dw.stride(0),
|
326 |
+
_db.stride(0) if _db is not None else 0,
|
327 |
+
M, group_size, eps,
|
328 |
+
rows_per_program,
|
329 |
+
BLOCK_N=BLOCK_N,
|
330 |
+
NORM_BEFORE_GATE=norm_before_gate,
|
331 |
+
IS_RMS_NORM=is_rms_norm,
|
332 |
+
num_warps=num_warps)
|
333 |
+
dw = _dw.sum(0).to(weight.dtype)
|
334 |
+
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
335 |
+
return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
|
336 |
+
|
337 |
+
|
338 |
+
class LayerNormFn(torch.autograd.Function):
|
339 |
+
|
340 |
+
@staticmethod
|
341 |
+
def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
|
342 |
+
is_rms_norm=False):
|
343 |
+
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
344 |
+
"""
|
345 |
+
|
346 |
+
x_shape_og = x.shape
|
347 |
+
# reshape input data into 2D tensor
|
348 |
+
x = x.reshape(-1, x.shape[-1])
|
349 |
+
if x.stride(-1) != 1:
|
350 |
+
x = x.contiguous()
|
351 |
+
if z is not None:
|
352 |
+
assert z.shape == x_shape_og
|
353 |
+
z = z.reshape(-1, z.shape[-1])
|
354 |
+
if z.stride(-1) != 1:
|
355 |
+
z = z.contiguous()
|
356 |
+
weight = weight.contiguous()
|
357 |
+
if bias is not None:
|
358 |
+
bias = bias.contiguous()
|
359 |
+
y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
|
360 |
+
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
361 |
+
ctx.x_shape_og = x_shape_og
|
362 |
+
ctx.eps = eps
|
363 |
+
ctx.group_size = group_size
|
364 |
+
ctx.norm_before_gate = norm_before_gate
|
365 |
+
ctx.is_rms_norm = is_rms_norm
|
366 |
+
return y.reshape(x_shape_og)
|
367 |
+
|
368 |
+
@staticmethod
|
369 |
+
def backward(ctx, dy):
|
370 |
+
x, weight, bias, mean, rstd, z = ctx.saved_tensors
|
371 |
+
dy = dy.reshape(-1, dy.shape[-1])
|
372 |
+
if dy.stride(-1) != 1:
|
373 |
+
dy = dy.contiguous()
|
374 |
+
assert dy.shape == x.shape
|
375 |
+
dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
|
376 |
+
ctx.norm_before_gate, ctx.is_rms_norm)
|
377 |
+
return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
|
378 |
+
|
379 |
+
|
380 |
+
def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
|
381 |
+
return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
|
382 |
+
|
383 |
+
|
384 |
+
def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
|
385 |
+
return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
|
386 |
+
|
387 |
+
|
388 |
+
class LayerNorm(torch.nn.Module):
|
389 |
+
|
390 |
+
def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
|
391 |
+
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
392 |
+
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
393 |
+
"""
|
394 |
+
|
395 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
396 |
+
super().__init__()
|
397 |
+
self.eps = eps
|
398 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
399 |
+
self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
400 |
+
self.group_size = group_size
|
401 |
+
self.norm_before_gate = norm_before_gate
|
402 |
+
self.reset_parameters()
|
403 |
+
|
404 |
+
def reset_parameters(self):
|
405 |
+
torch.nn.init.ones_(self.weight)
|
406 |
+
torch.nn.init.zeros_(self.bias)
|
407 |
+
|
408 |
+
def forward(self, x, z=None):
|
409 |
+
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
410 |
+
"""
|
411 |
+
return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
|
412 |
+
norm_before_gate=self.norm_before_gate)
|
413 |
+
|
414 |
+
|
415 |
+
class RMSNorm(torch.nn.Module):
|
416 |
+
|
417 |
+
def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
|
418 |
+
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
419 |
+
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
420 |
+
"""
|
421 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
422 |
+
super().__init__()
|
423 |
+
self.eps = eps
|
424 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
425 |
+
self.register_parameter("bias", None)
|
426 |
+
self.group_size = group_size
|
427 |
+
self.norm_before_gate = norm_before_gate
|
428 |
+
self.reset_parameters()
|
429 |
+
|
430 |
+
def reset_parameters(self):
|
431 |
+
torch.nn.init.ones_(self.weight)
|
432 |
+
|
433 |
+
def forward(self, x, z=None):
|
434 |
+
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
435 |
+
"""
|
436 |
+
return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
|
437 |
+
norm_before_gate=self.norm_before_gate)
|
mamba/build/lib/mamba_ssm/ops/triton/selective_state_update.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import triton
|
11 |
+
import triton.language as tl
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
from mamba_ssm.ops.triton.softplus import softplus
|
16 |
+
|
17 |
+
|
18 |
+
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
19 |
+
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
20 |
+
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
21 |
+
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
22 |
+
@triton.jit
|
23 |
+
def _selective_scan_update_kernel(
|
24 |
+
# Pointers to matrices
|
25 |
+
state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
|
26 |
+
# Matrix dimensions
|
27 |
+
batch, nheads, dim, dstate, nheads_ngroups_ratio,
|
28 |
+
# Strides
|
29 |
+
stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
|
30 |
+
stride_x_batch, stride_x_head, stride_x_dim,
|
31 |
+
stride_dt_batch, stride_dt_head, stride_dt_dim,
|
32 |
+
stride_dt_bias_head, stride_dt_bias_dim,
|
33 |
+
stride_A_head, stride_A_dim, stride_A_dstate,
|
34 |
+
stride_B_batch, stride_B_group, stride_B_dstate,
|
35 |
+
stride_C_batch, stride_C_group, stride_C_dstate,
|
36 |
+
stride_D_head, stride_D_dim,
|
37 |
+
stride_z_batch, stride_z_head, stride_z_dim,
|
38 |
+
stride_out_batch, stride_out_head, stride_out_dim,
|
39 |
+
# Meta-parameters
|
40 |
+
DT_SOFTPLUS: tl.constexpr,
|
41 |
+
TIE_HDIM: tl.constexpr,
|
42 |
+
BLOCK_SIZE_M: tl.constexpr,
|
43 |
+
HAS_DT_BIAS: tl.constexpr,
|
44 |
+
HAS_D: tl.constexpr,
|
45 |
+
HAS_Z: tl.constexpr,
|
46 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
47 |
+
):
|
48 |
+
pid_m = tl.program_id(axis=0)
|
49 |
+
pid_b = tl.program_id(axis=1)
|
50 |
+
pid_h = tl.program_id(axis=2)
|
51 |
+
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
52 |
+
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
53 |
+
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
54 |
+
if HAS_DT_BIAS:
|
55 |
+
dt_bias_ptr += pid_h * stride_dt_bias_head
|
56 |
+
A_ptr += pid_h * stride_A_head
|
57 |
+
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
|
58 |
+
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
|
59 |
+
if HAS_Z:
|
60 |
+
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
|
61 |
+
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
62 |
+
|
63 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
64 |
+
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
65 |
+
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
|
66 |
+
x_ptrs = x_ptr + offs_m * stride_x_dim
|
67 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
68 |
+
if HAS_DT_BIAS:
|
69 |
+
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
70 |
+
if HAS_D:
|
71 |
+
D_ptr += pid_h * stride_D_head
|
72 |
+
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
|
73 |
+
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
74 |
+
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
75 |
+
if HAS_D:
|
76 |
+
D_ptrs = D_ptr + offs_m * stride_D_dim
|
77 |
+
if HAS_Z:
|
78 |
+
z_ptrs = z_ptr + offs_m * stride_z_dim
|
79 |
+
out_ptrs = out_ptr + offs_m * stride_out_dim
|
80 |
+
|
81 |
+
state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
|
82 |
+
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
83 |
+
if not TIE_HDIM:
|
84 |
+
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
85 |
+
if HAS_DT_BIAS:
|
86 |
+
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
87 |
+
if DT_SOFTPLUS:
|
88 |
+
dt = softplus(dt)
|
89 |
+
A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
|
90 |
+
dA = tl.exp(A * dt[:, None])
|
91 |
+
else:
|
92 |
+
dt = tl.load(dt_ptr).to(tl.float32)
|
93 |
+
if HAS_DT_BIAS:
|
94 |
+
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
95 |
+
if DT_SOFTPLUS:
|
96 |
+
dt = softplus(dt)
|
97 |
+
A = tl.load(A_ptr).to(tl.float32)
|
98 |
+
dA = tl.exp(A * dt) # scalar, not a matrix
|
99 |
+
|
100 |
+
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
101 |
+
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
102 |
+
if HAS_D:
|
103 |
+
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
104 |
+
if HAS_Z:
|
105 |
+
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
106 |
+
|
107 |
+
if not TIE_HDIM:
|
108 |
+
dB = B[None, :] * dt[:, None]
|
109 |
+
else:
|
110 |
+
dB = B * dt # vector of size (dstate,)
|
111 |
+
state = state * dA + dB * x[:, None]
|
112 |
+
tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
|
113 |
+
out = tl.sum(state * C[None, :], axis=1)
|
114 |
+
if HAS_D:
|
115 |
+
out += x * D
|
116 |
+
if HAS_Z:
|
117 |
+
out *= z * tl.sigmoid(z)
|
118 |
+
tl.store(out_ptrs, out, mask=offs_m < dim)
|
119 |
+
|
120 |
+
|
121 |
+
def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
|
122 |
+
"""
|
123 |
+
Argument:
|
124 |
+
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
125 |
+
x: (batch, dim) or (batch, nheads, dim)
|
126 |
+
dt: (batch, dim) or (batch, nheads, dim)
|
127 |
+
A: (dim, dstate) or (nheads, dim, dstate)
|
128 |
+
B: (batch, dstate) or (batch, ngroups, dstate)
|
129 |
+
C: (batch, dstate) or (batch, ngroups, dstate)
|
130 |
+
D: (dim,) or (nheads, dim)
|
131 |
+
z: (batch, dim) or (batch, nheads, dim)
|
132 |
+
dt_bias: (dim,) or (nheads, dim)
|
133 |
+
Return:
|
134 |
+
out: (batch, dim) or (batch, nheads, dim)
|
135 |
+
"""
|
136 |
+
has_heads = state.dim() > 3
|
137 |
+
if state.dim() == 3:
|
138 |
+
state = state.unsqueeze(1)
|
139 |
+
if x.dim() == 2:
|
140 |
+
x = x.unsqueeze(1)
|
141 |
+
if dt.dim() == 2:
|
142 |
+
dt = dt.unsqueeze(1)
|
143 |
+
if A.dim() == 2:
|
144 |
+
A = A.unsqueeze(0)
|
145 |
+
if B.dim() == 2:
|
146 |
+
B = B.unsqueeze(1)
|
147 |
+
if C.dim() == 2:
|
148 |
+
C = C.unsqueeze(1)
|
149 |
+
if D is not None and D.dim() == 1:
|
150 |
+
D = D.unsqueeze(0)
|
151 |
+
if z is not None and z.dim() == 2:
|
152 |
+
z = z.unsqueeze(1)
|
153 |
+
if dt_bias is not None and dt_bias.dim() == 1:
|
154 |
+
dt_bias = dt_bias.unsqueeze(0)
|
155 |
+
batch, nheads, dim, dstate = state.shape
|
156 |
+
assert x.shape == (batch, nheads, dim)
|
157 |
+
assert dt.shape == x.shape
|
158 |
+
assert A.shape == (nheads, dim, dstate)
|
159 |
+
ngroups = B.shape[1]
|
160 |
+
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
161 |
+
assert B.shape == (batch, ngroups, dstate)
|
162 |
+
assert C.shape == B.shape
|
163 |
+
if D is not None:
|
164 |
+
assert D.shape == (nheads, dim)
|
165 |
+
if z is not None:
|
166 |
+
assert z.shape == x.shape
|
167 |
+
if dt_bias is not None:
|
168 |
+
assert dt_bias.shape == (nheads, dim)
|
169 |
+
out = torch.empty_like(x)
|
170 |
+
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
171 |
+
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
|
172 |
+
# We don't want autotune since it will overwrite the state
|
173 |
+
# We instead tune by hand.
|
174 |
+
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
|
175 |
+
else ((16, 4) if dstate <= 32 else
|
176 |
+
((8, 4) if dstate <= 64 else
|
177 |
+
((4, 4) if dstate <= 128 else
|
178 |
+
((4, 8))))))
|
179 |
+
tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
|
180 |
+
with torch.cuda.device(x.device.index):
|
181 |
+
_selective_scan_update_kernel[grid](
|
182 |
+
state, x, dt, dt_bias, A, B, C, D, z, out,
|
183 |
+
batch, nheads, dim, dstate, nheads // ngroups,
|
184 |
+
state.stride(0), state.stride(1), state.stride(2), state.stride(3),
|
185 |
+
x.stride(0), x.stride(1), x.stride(2),
|
186 |
+
dt.stride(0), dt.stride(1), dt.stride(2),
|
187 |
+
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
|
188 |
+
A.stride(0), A.stride(1), A.stride(2),
|
189 |
+
B.stride(0), B.stride(1), B.stride(2),
|
190 |
+
C.stride(0), C.stride(1), C.stride(2),
|
191 |
+
*(D.stride(0), D.stride(1)) if D is not None else 0,
|
192 |
+
z_strides[0], z_strides[1], z_strides[2],
|
193 |
+
out.stride(0), out.stride(1), out.stride(2),
|
194 |
+
dt_softplus,
|
195 |
+
tie_hdim,
|
196 |
+
BLOCK_SIZE_M,
|
197 |
+
num_warps=num_warps,
|
198 |
+
)
|
199 |
+
if not has_heads:
|
200 |
+
out = out.squeeze(1)
|
201 |
+
return out
|
202 |
+
|
203 |
+
|
204 |
+
def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
|
205 |
+
"""
|
206 |
+
Argument:
|
207 |
+
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
208 |
+
x: (batch, dim) or (batch, nheads, dim)
|
209 |
+
dt: (batch, dim) or (batch, nheads, dim)
|
210 |
+
A: (dim, dstate) or (nheads, dim, dstate)
|
211 |
+
B: (batch, dstate) or (batch, ngroups, dstate)
|
212 |
+
C: (batch, dstate) or (batch, ngroups, dstate)
|
213 |
+
D: (dim,) or (nheads, dim)
|
214 |
+
z: (batch, dim) or (batch, nheads, dim)
|
215 |
+
dt_bias: (dim,) or (nheads, dim)
|
216 |
+
Return:
|
217 |
+
out: (batch, dim) or (batch, nheads, dim)
|
218 |
+
"""
|
219 |
+
has_heads = state.dim() > 3
|
220 |
+
if state.dim() == 3:
|
221 |
+
state = state.unsqueeze(1)
|
222 |
+
if x.dim() == 2:
|
223 |
+
x = x.unsqueeze(1)
|
224 |
+
if dt.dim() == 2:
|
225 |
+
dt = dt.unsqueeze(1)
|
226 |
+
if A.dim() == 2:
|
227 |
+
A = A.unsqueeze(0)
|
228 |
+
if B.dim() == 2:
|
229 |
+
B = B.unsqueeze(1)
|
230 |
+
if C.dim() == 2:
|
231 |
+
C = C.unsqueeze(1)
|
232 |
+
if D is not None and D.dim() == 1:
|
233 |
+
D = D.unsqueeze(0)
|
234 |
+
if z is not None and z.dim() == 2:
|
235 |
+
z = z.unsqueeze(1)
|
236 |
+
if dt_bias is not None and dt_bias.dim() == 1:
|
237 |
+
dt_bias = dt_bias.unsqueeze(0)
|
238 |
+
batch, nheads, dim, dstate = state.shape
|
239 |
+
assert x.shape == (batch, nheads, dim)
|
240 |
+
assert dt.shape == x.shape
|
241 |
+
assert A.shape == (nheads, dim, dstate)
|
242 |
+
ngroups = B.shape[1]
|
243 |
+
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
244 |
+
assert B.shape == (batch, ngroups, dstate)
|
245 |
+
assert C.shape == B.shape
|
246 |
+
if D is not None:
|
247 |
+
assert D.shape == (nheads, dim)
|
248 |
+
if z is not None:
|
249 |
+
assert z.shape == x.shape
|
250 |
+
if dt_bias is not None:
|
251 |
+
assert dt_bias.shape == (nheads, dim)
|
252 |
+
dt = dt + dt_bias
|
253 |
+
dt = F.softplus(dt) if dt_softplus else dt
|
254 |
+
dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate)
|
255 |
+
B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
|
256 |
+
C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
|
257 |
+
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
|
258 |
+
state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate
|
259 |
+
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
|
260 |
+
if D is not None:
|
261 |
+
out += (x * D).to(out.dtype)
|
262 |
+
out = (out if z is None else out * F.silu(z)).to(x.dtype)
|
263 |
+
if not has_heads:
|
264 |
+
out = out.squeeze(1)
|
265 |
+
return out
|
mamba/build/lib/mamba_ssm/ops/triton/softplus.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import triton
|
2 |
+
import triton.language as tl
|
3 |
+
from packaging import version
|
4 |
+
|
5 |
+
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
|
6 |
+
|
7 |
+
|
8 |
+
if TRITON3:
|
9 |
+
@triton.jit
|
10 |
+
def softplus(dt):
|
11 |
+
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
|
12 |
+
return dt
|
13 |
+
else:
|
14 |
+
@triton.jit
|
15 |
+
def softplus(dt):
|
16 |
+
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
17 |
+
return dt
|
mamba/build/lib/mamba_ssm/ops/triton/ssd_bmm.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
"""We want triton==2.1.0 or 2.2.0 for this
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import triton
|
11 |
+
import triton.language as tl
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
|
16 |
+
def init_to_zero(names):
|
17 |
+
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
|
18 |
+
|
19 |
+
|
20 |
+
@triton.autotune(
|
21 |
+
configs=[
|
22 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
|
23 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
24 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
25 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
26 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
27 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
28 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
29 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
30 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
|
31 |
+
],
|
32 |
+
key=['chunk_size', 'K', 'IS_CAUSAL'],
|
33 |
+
)
|
34 |
+
@triton.jit
|
35 |
+
def _bmm_chunk_fwd_kernel(
|
36 |
+
# Pointers to matrices
|
37 |
+
a_ptr, b_ptr, out_ptr, seq_idx_ptr,
|
38 |
+
# Matrix dimensions
|
39 |
+
seqlen, chunk_size, K, ngroups,
|
40 |
+
stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
|
41 |
+
stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
|
42 |
+
stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
|
43 |
+
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
44 |
+
# Meta-parameters
|
45 |
+
IS_CAUSAL: tl.constexpr,
|
46 |
+
dot_dtype: tl.constexpr,
|
47 |
+
HAS_SEQ_IDX: tl.constexpr,
|
48 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
49 |
+
):
|
50 |
+
pid_b = tl.program_id(axis=1)
|
51 |
+
pid_ch = tl.program_id(axis=2)
|
52 |
+
pid_c = pid_ch // ngroups
|
53 |
+
pid_h = pid_ch - pid_c * ngroups
|
54 |
+
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
|
55 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
56 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
57 |
+
if IS_CAUSAL:
|
58 |
+
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
|
59 |
+
return
|
60 |
+
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
|
61 |
+
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
|
62 |
+
if HAS_SEQ_IDX:
|
63 |
+
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
64 |
+
|
65 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
66 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
67 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
68 |
+
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
|
69 |
+
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
|
70 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
71 |
+
|
72 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
73 |
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
74 |
+
a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
|
75 |
+
b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
|
76 |
+
acc += tl.dot(a, b)
|
77 |
+
a_ptrs += BLOCK_SIZE_K * stride_ak
|
78 |
+
b_ptrs += BLOCK_SIZE_K * stride_bk
|
79 |
+
|
80 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
81 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
82 |
+
if HAS_SEQ_IDX:
|
83 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
84 |
+
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
85 |
+
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
|
86 |
+
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
|
87 |
+
out = acc.to(out_ptr.dtype.element_ty)
|
88 |
+
|
89 |
+
out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
|
90 |
+
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
|
91 |
+
tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
|
92 |
+
|
93 |
+
|
94 |
+
@triton.autotune(
|
95 |
+
configs=[
|
96 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
|
97 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
98 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
99 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
100 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
101 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
102 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
|
103 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
|
104 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
|
105 |
+
],
|
106 |
+
key=['chunk_size', 'K'],
|
107 |
+
)
|
108 |
+
@triton.jit
|
109 |
+
def _bmm_chunk_bwd_kernel(
|
110 |
+
# Pointers to matrices
|
111 |
+
a_ptr, dout_ptr, db_ptr, res_ptr,
|
112 |
+
# Matrix dimensions
|
113 |
+
seqlen, chunk_size, K, ngroups,
|
114 |
+
stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
|
115 |
+
stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
|
116 |
+
stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
|
117 |
+
stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
|
118 |
+
# Meta-parameters
|
119 |
+
dot_dtype: tl.constexpr,
|
120 |
+
HAS_RESIDUAL: tl.constexpr,
|
121 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
|
122 |
+
):
|
123 |
+
pid_b = tl.program_id(axis=1)
|
124 |
+
pid_ch = tl.program_id(axis=2)
|
125 |
+
pid_c = pid_ch // ngroups
|
126 |
+
pid_h = pid_ch - pid_c * ngroups
|
127 |
+
num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
|
128 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
129 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
130 |
+
|
131 |
+
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
|
132 |
+
dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head
|
133 |
+
|
134 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
135 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
136 |
+
offs_cs = tl.arange(0, BLOCK_SIZE_CS)
|
137 |
+
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
|
138 |
+
a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
|
139 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
140 |
+
|
141 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
142 |
+
for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
|
143 |
+
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
|
144 |
+
a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
|
145 |
+
acc += tl.dot(dout, a)
|
146 |
+
dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
|
147 |
+
a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen
|
148 |
+
|
149 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
150 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
151 |
+
if HAS_RESIDUAL:
|
152 |
+
res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
|
153 |
+
res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
|
154 |
+
res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
|
155 |
+
acc += res
|
156 |
+
db = acc.to(db_ptr.dtype.element_ty)
|
157 |
+
|
158 |
+
db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
|
159 |
+
db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
|
160 |
+
tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))
|
161 |
+
|
162 |
+
|
163 |
+
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
|
164 |
+
"""
|
165 |
+
Argument:
|
166 |
+
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
167 |
+
b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
168 |
+
seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
|
169 |
+
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
|
170 |
+
guaranteed to be correct.
|
171 |
+
Return:
|
172 |
+
out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
|
173 |
+
"""
|
174 |
+
# Check constraints.
|
175 |
+
has_groups = a.dim() == 4
|
176 |
+
if not has_groups:
|
177 |
+
batch, seqlen, k = a.shape
|
178 |
+
else:
|
179 |
+
batch, seqlen, ngroups, k = a.shape
|
180 |
+
assert b.shape == a.shape
|
181 |
+
if seq_idx is not None:
|
182 |
+
assert seq_idx.shape == (batch, seqlen)
|
183 |
+
if a.stride(-1) != 1 and a.stride(1) != 1:
|
184 |
+
a = a.contiguous()
|
185 |
+
if b.stride(-1) != 1 and b.stride(1) != 1:
|
186 |
+
b = b.contiguous()
|
187 |
+
nchunks = math.ceil(seqlen / chunk_size)
|
188 |
+
# Allocates output.
|
189 |
+
out_dtype = a.dtype if output_dtype is None else output_dtype
|
190 |
+
out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
|
191 |
+
device=a.device, dtype=out_dtype)
|
192 |
+
dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
|
193 |
+
(tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
|
194 |
+
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
|
195 |
+
batch, nchunks if not has_groups else nchunks * ngroups)
|
196 |
+
with torch.cuda.device(a.device.index):
|
197 |
+
_bmm_chunk_fwd_kernel[grid](
|
198 |
+
a, b, out, seq_idx,
|
199 |
+
seqlen, chunk_size, k, ngroups if has_groups else 1,
|
200 |
+
a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
|
201 |
+
b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
|
202 |
+
out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
|
203 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
204 |
+
causal,
|
205 |
+
dot_dtype,
|
206 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
207 |
+
)
|
208 |
+
return out
|
209 |
+
|
210 |
+
|
211 |
+
def _bmm_chunk_bwd(a, dout, residual=None, out=None):
|
212 |
+
"""
|
213 |
+
Argument:
|
214 |
+
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
215 |
+
dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
|
216 |
+
residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
217 |
+
Return:
|
218 |
+
out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
219 |
+
|
220 |
+
If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
|
221 |
+
zeroed out before calling this function.
|
222 |
+
"""
|
223 |
+
# Check constraints.
|
224 |
+
has_groups = a.dim() == 4
|
225 |
+
if not has_groups:
|
226 |
+
batch, seqlen, k = a.shape
|
227 |
+
else:
|
228 |
+
batch, seqlen, ngroups, k = a.shape
|
229 |
+
nchunks, chunk_size = dout.shape[1], dout.shape[-1]
|
230 |
+
if a.stride(-1) != 1 and a.stride(-2) != 1:
|
231 |
+
a = a.contiguous()
|
232 |
+
if dout.stride(-1) != 1 and dout.stride(-2) != 1:
|
233 |
+
dout = dout.contiguous()
|
234 |
+
if residual is not None:
|
235 |
+
assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)
|
236 |
+
if residual.stride(-1) != 1 and residual.stride(1) != 1:
|
237 |
+
residual = residual.contiguous()
|
238 |
+
# Allocates output.
|
239 |
+
if out is not None:
|
240 |
+
assert out.shape == a.shape
|
241 |
+
assert out.stride(-1) == 1 or out.stride(1) == 1
|
242 |
+
else:
|
243 |
+
out = torch.empty_like(a)
|
244 |
+
dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
|
245 |
+
(tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
|
246 |
+
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
|
247 |
+
nchunks if not has_groups else nchunks * ngroups)
|
248 |
+
residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
|
249 |
+
residual.stride(-1))
|
250 |
+
if residual is not None else (0, 0, 0, 0))
|
251 |
+
with torch.cuda.device(a.device.index):
|
252 |
+
_bmm_chunk_bwd_kernel[grid](
|
253 |
+
a, dout, out, residual,
|
254 |
+
seqlen, chunk_size, k, ngroups if has_groups else 1,
|
255 |
+
a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
|
256 |
+
dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
|
257 |
+
out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
|
258 |
+
residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
|
259 |
+
dot_dtype,
|
260 |
+
HAS_RESIDUAL=residual is not None,
|
261 |
+
)
|
262 |
+
return out
|
mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_scan.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_state.py
ADDED
@@ -0,0 +1,988 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
"""We want triton==2.1.0 or 2.2.0 for this
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import triton
|
11 |
+
import triton.language as tl
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
from mamba_ssm.ops.triton.softplus import softplus
|
16 |
+
|
17 |
+
|
18 |
+
def init_to_zero(names):
|
19 |
+
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
|
20 |
+
|
21 |
+
@triton.autotune(
|
22 |
+
configs=[
|
23 |
+
triton.Config({'BLOCK_SIZE_H': 1}),
|
24 |
+
triton.Config({'BLOCK_SIZE_H': 2}),
|
25 |
+
triton.Config({'BLOCK_SIZE_H': 4}),
|
26 |
+
triton.Config({'BLOCK_SIZE_H': 8}),
|
27 |
+
triton.Config({'BLOCK_SIZE_H': 16}),
|
28 |
+
triton.Config({'BLOCK_SIZE_H': 32}),
|
29 |
+
triton.Config({'BLOCK_SIZE_H': 64}),
|
30 |
+
],
|
31 |
+
key=['chunk_size', 'nheads'],
|
32 |
+
)
|
33 |
+
@triton.jit
|
34 |
+
def _chunk_cumsum_fwd_kernel(
|
35 |
+
# Pointers to matrices
|
36 |
+
dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,
|
37 |
+
# Matrix dimension
|
38 |
+
batch, seqlen, nheads, chunk_size,
|
39 |
+
dt_min, dt_max,
|
40 |
+
# Strides
|
41 |
+
stride_dt_batch, stride_dt_seqlen, stride_dt_head,
|
42 |
+
stride_A_head,
|
43 |
+
stride_dt_bias_head,
|
44 |
+
stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,
|
45 |
+
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
46 |
+
# Meta-parameters
|
47 |
+
DT_SOFTPLUS: tl.constexpr,
|
48 |
+
HAS_DT_BIAS: tl.constexpr,
|
49 |
+
BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
|
50 |
+
):
|
51 |
+
pid_b = tl.program_id(axis=0)
|
52 |
+
pid_c = tl.program_id(axis=1)
|
53 |
+
pid_h = tl.program_id(axis=2)
|
54 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
55 |
+
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
|
56 |
+
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
|
57 |
+
|
58 |
+
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
59 |
+
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
60 |
+
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
|
61 |
+
A_ptrs = A_ptr + offs_h * stride_A_head
|
62 |
+
dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)
|
63 |
+
dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)
|
64 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
65 |
+
|
66 |
+
dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
|
67 |
+
if HAS_DT_BIAS:
|
68 |
+
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
69 |
+
dt += dt_bias[:, None]
|
70 |
+
if DT_SOFTPLUS:
|
71 |
+
dt = softplus(dt)
|
72 |
+
# As of Triton 2.2.0, tl.clamp is not available yet
|
73 |
+
# dt = tl.clamp(dt, dt_min, dt_max)
|
74 |
+
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
75 |
+
dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
|
76 |
+
tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
|
77 |
+
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
78 |
+
dA = dt * A[:, None]
|
79 |
+
dA_cs = tl.cumsum(dA, axis=1)
|
80 |
+
tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
|
81 |
+
|
82 |
+
|
83 |
+
@triton.autotune(
|
84 |
+
configs=[
|
85 |
+
triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
86 |
+
triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
87 |
+
triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
88 |
+
triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
89 |
+
triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
90 |
+
triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
91 |
+
triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
92 |
+
],
|
93 |
+
key=['chunk_size', 'nheads'],
|
94 |
+
)
|
95 |
+
@triton.jit
|
96 |
+
def _chunk_cumsum_bwd_kernel(
|
97 |
+
# Pointers to matrices
|
98 |
+
ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,
|
99 |
+
ddt_ptr, dA_ptr, ddt_bias_ptr,
|
100 |
+
# Matrix dimensions
|
101 |
+
batch, seqlen, nheads, chunk_size,
|
102 |
+
dt_min, dt_max,
|
103 |
+
# Strides
|
104 |
+
stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,
|
105 |
+
stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,
|
106 |
+
stride_dt_batch, stride_dt_seqlen, stride_dt_head,
|
107 |
+
stride_A_head,
|
108 |
+
stride_dt_bias_head,
|
109 |
+
stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,
|
110 |
+
stride_dA_head,
|
111 |
+
stride_ddt_bias_head,
|
112 |
+
# Meta-parameters
|
113 |
+
DT_SOFTPLUS: tl.constexpr,
|
114 |
+
HAS_DT_BIAS: tl.constexpr,
|
115 |
+
BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
|
116 |
+
):
|
117 |
+
pid_b = tl.program_id(axis=0)
|
118 |
+
pid_c = tl.program_id(axis=1)
|
119 |
+
pid_h = tl.program_id(axis=2)
|
120 |
+
ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
|
121 |
+
ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
|
122 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
123 |
+
ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
|
124 |
+
|
125 |
+
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
126 |
+
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
127 |
+
ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)
|
128 |
+
ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize)
|
129 |
+
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
|
130 |
+
ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)
|
131 |
+
A_ptrs = A_ptr + offs_h * stride_A_head
|
132 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
133 |
+
|
134 |
+
ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
|
135 |
+
ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
|
136 |
+
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
137 |
+
ddt = ddA * A[:, None] + ddt_out
|
138 |
+
dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
|
139 |
+
if HAS_DT_BIAS:
|
140 |
+
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
141 |
+
dt += dt_bias[:, None]
|
142 |
+
if DT_SOFTPLUS:
|
143 |
+
dt_presoftplus = dt
|
144 |
+
dt = softplus(dt)
|
145 |
+
clamp_mask = (dt < dt_min) | (dt > dt_max)
|
146 |
+
# As of Triton 2.2.0, tl.clamp is not available yet
|
147 |
+
# dt = tl.clamp(dt, dt_min, dt_max)
|
148 |
+
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
149 |
+
dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
|
150 |
+
ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)
|
151 |
+
ddt = tl.where(clamp_mask, 0.0, ddt)
|
152 |
+
if DT_SOFTPLUS:
|
153 |
+
ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
|
154 |
+
tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))
|
155 |
+
dA = tl.sum(ddA * dt, axis=1)
|
156 |
+
tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
|
157 |
+
if HAS_DT_BIAS:
|
158 |
+
ddt_bias = tl.sum(ddt, axis=1)
|
159 |
+
tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)
|
160 |
+
|
161 |
+
|
162 |
+
@triton.autotune(
|
163 |
+
configs=[
|
164 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
|
165 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
166 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
167 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
168 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
169 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
170 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
171 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
172 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
|
173 |
+
],
|
174 |
+
key=['hdim', 'dstate', 'chunk_size'],
|
175 |
+
)
|
176 |
+
@triton.jit
|
177 |
+
def _chunk_state_fwd_kernel(
|
178 |
+
# Pointers to matrices
|
179 |
+
x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
|
180 |
+
# Matrix dimensions
|
181 |
+
hdim, dstate, chunk_size,
|
182 |
+
batch, seqlen, nheads_ngroups_ratio,
|
183 |
+
# Strides
|
184 |
+
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
|
185 |
+
stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
|
186 |
+
stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
|
187 |
+
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
188 |
+
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
189 |
+
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
190 |
+
# Meta-parameters
|
191 |
+
HAS_SEQ_IDX: tl.constexpr,
|
192 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
193 |
+
):
|
194 |
+
pid_bc = tl.program_id(axis=1)
|
195 |
+
pid_c = pid_bc // batch
|
196 |
+
pid_b = pid_bc - pid_c * batch
|
197 |
+
pid_h = tl.program_id(axis=2)
|
198 |
+
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
199 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
200 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
201 |
+
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
|
202 |
+
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
203 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
204 |
+
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
205 |
+
if HAS_SEQ_IDX:
|
206 |
+
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
207 |
+
|
208 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
209 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
210 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
211 |
+
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
|
212 |
+
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
|
213 |
+
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
214 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
215 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
216 |
+
if HAS_SEQ_IDX:
|
217 |
+
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
|
218 |
+
|
219 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
220 |
+
if HAS_SEQ_IDX:
|
221 |
+
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
222 |
+
|
223 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
224 |
+
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
225 |
+
x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0)
|
226 |
+
b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
|
227 |
+
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
|
228 |
+
if HAS_SEQ_IDX:
|
229 |
+
seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
|
230 |
+
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
|
231 |
+
if not HAS_SEQ_IDX:
|
232 |
+
scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
|
233 |
+
else:
|
234 |
+
scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
|
235 |
+
b *= scale[:, None]
|
236 |
+
b = b.to(x_ptr.dtype.element_ty)
|
237 |
+
acc += tl.dot(x, b)
|
238 |
+
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
239 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
240 |
+
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
241 |
+
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
242 |
+
if HAS_SEQ_IDX:
|
243 |
+
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
|
244 |
+
states = acc.to(states_ptr.dtype.element_ty)
|
245 |
+
|
246 |
+
states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
|
247 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
248 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
249 |
+
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
|
250 |
+
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
251 |
+
tl.store(states_ptrs, states, mask=c_mask)
|
252 |
+
|
253 |
+
|
254 |
+
@triton.autotune(
|
255 |
+
configs=[
|
256 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
257 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
258 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
259 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
260 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
261 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
262 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
263 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
264 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
265 |
+
],
|
266 |
+
key=['chunk_size', 'hdim', 'dstate'],
|
267 |
+
)
|
268 |
+
@triton.jit
|
269 |
+
def _chunk_state_bwd_dx_kernel(
|
270 |
+
# Pointers to matrices
|
271 |
+
x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr,
|
272 |
+
dx_ptr, ddt_ptr, ddA_cumsum_ptr,
|
273 |
+
# Matrix dimensions
|
274 |
+
chunk_size, hdim, dstate,
|
275 |
+
batch, seqlen, nheads_ngroups_ratio,
|
276 |
+
# Strides
|
277 |
+
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
|
278 |
+
stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
|
279 |
+
stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
|
280 |
+
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
281 |
+
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
282 |
+
stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
|
283 |
+
stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
|
284 |
+
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
|
285 |
+
# Meta-parameters
|
286 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
287 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
288 |
+
):
|
289 |
+
pid_bc = tl.program_id(axis=1)
|
290 |
+
pid_c = pid_bc // batch
|
291 |
+
pid_b = pid_bc - pid_c * batch
|
292 |
+
pid_h = tl.program_id(axis=2)
|
293 |
+
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
294 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
295 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
296 |
+
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
297 |
+
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
|
298 |
+
dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
|
299 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
300 |
+
ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
301 |
+
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
|
302 |
+
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
303 |
+
|
304 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
305 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
306 |
+
|
307 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
308 |
+
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
309 |
+
offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
|
310 |
+
b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
|
311 |
+
dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
|
312 |
+
if BLOCK_SIZE_DSTATE <= 128:
|
313 |
+
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
|
314 |
+
dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
|
315 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
316 |
+
acc = tl.dot(b, dstates)
|
317 |
+
else:
|
318 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
319 |
+
for k in range(0, dstate, BLOCK_SIZE_K):
|
320 |
+
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
|
321 |
+
dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
|
322 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
323 |
+
acc += tl.dot(b, dstates)
|
324 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
325 |
+
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
326 |
+
|
327 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
328 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
329 |
+
|
330 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
331 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
332 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
333 |
+
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
334 |
+
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
335 |
+
acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
|
336 |
+
|
337 |
+
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
|
338 |
+
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
339 |
+
ddt = tl.sum(acc * x, axis=1)
|
340 |
+
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
341 |
+
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
342 |
+
ddA_cs = -(ddt * dt_m)
|
343 |
+
ddA_cs_last = -tl.sum(ddA_cs)
|
344 |
+
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
345 |
+
tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
346 |
+
tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
|
347 |
+
|
348 |
+
dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
|
349 |
+
dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
|
350 |
+
dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
|
351 |
+
tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
|
352 |
+
|
353 |
+
|
354 |
+
@triton.autotune(
|
355 |
+
configs=[
|
356 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
357 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
358 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
359 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
360 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
361 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
362 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
363 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
364 |
+
],
|
365 |
+
key=['chunk_size', 'dstate', 'hdim'],
|
366 |
+
)
|
367 |
+
@triton.jit
|
368 |
+
def _chunk_state_bwd_db_kernel(
|
369 |
+
# Pointers to matrices
|
370 |
+
x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
|
371 |
+
db_ptr, ddA_cumsum_ptr,
|
372 |
+
# Matrix dimensions
|
373 |
+
chunk_size, dstate, hdim,
|
374 |
+
batch, seqlen, nheads, nheads_per_program, ngroups,
|
375 |
+
# Strides
|
376 |
+
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
|
377 |
+
stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
|
378 |
+
stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
|
379 |
+
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
380 |
+
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
381 |
+
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
382 |
+
stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate,
|
383 |
+
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
|
384 |
+
# Meta-parameters
|
385 |
+
HAS_DDA_CS: tl.constexpr,
|
386 |
+
HAS_SEQ_IDX: tl.constexpr,
|
387 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
388 |
+
):
|
389 |
+
pid_bc = tl.program_id(axis=1)
|
390 |
+
pid_c = pid_bc // batch
|
391 |
+
pid_b = pid_bc - pid_c * batch
|
392 |
+
pid_sg = tl.program_id(axis=2)
|
393 |
+
pid_s = pid_sg // ngroups
|
394 |
+
pid_g = pid_sg - pid_s * ngroups
|
395 |
+
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
396 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
397 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
398 |
+
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
|
399 |
+
db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split
|
400 |
+
dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head
|
401 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
|
402 |
+
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
|
403 |
+
if HAS_DDA_CS:
|
404 |
+
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head
|
405 |
+
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head
|
406 |
+
if HAS_SEQ_IDX:
|
407 |
+
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
408 |
+
|
409 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
410 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
411 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
412 |
+
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim)
|
413 |
+
dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim)
|
414 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
415 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
416 |
+
if HAS_DDA_CS:
|
417 |
+
b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate)
|
418 |
+
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
419 |
+
|
420 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
421 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
422 |
+
if HAS_DDA_CS:
|
423 |
+
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
|
424 |
+
if HAS_SEQ_IDX:
|
425 |
+
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
426 |
+
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
427 |
+
nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
|
428 |
+
for h in range(nheads_iter):
|
429 |
+
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
|
430 |
+
dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
|
431 |
+
dstates = dstates.to(x_ptrs.dtype.element_ty)
|
432 |
+
db = tl.dot(x, dstates)
|
433 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
434 |
+
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
435 |
+
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
436 |
+
if not HAS_SEQ_IDX:
|
437 |
+
scale = tl.exp(dA_cs_last - dA_cs_m)
|
438 |
+
else:
|
439 |
+
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
440 |
+
db *= (scale * dt_m)[:, None]
|
441 |
+
if HAS_DDA_CS:
|
442 |
+
# This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
|
443 |
+
ddA_cs = tl.sum(db * b, axis=1)
|
444 |
+
tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
|
445 |
+
acc += db
|
446 |
+
x_ptrs += stride_x_head
|
447 |
+
dstates_ptrs += stride_states_head
|
448 |
+
dt_ptrs += stride_dt_head
|
449 |
+
dA_cumsum_ptr += stride_dA_cs_head
|
450 |
+
dA_cumsum_ptrs += stride_dA_cs_head
|
451 |
+
if HAS_DDA_CS:
|
452 |
+
ddA_cumsum_ptrs += stride_ddA_cs_head
|
453 |
+
|
454 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
455 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
456 |
+
# if HAS_SEQ_IDX:
|
457 |
+
# seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
458 |
+
# seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
459 |
+
# acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
|
460 |
+
db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate)
|
461 |
+
tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))
|
462 |
+
|
463 |
+
|
464 |
+
@triton.autotune(
|
465 |
+
configs=[
|
466 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
467 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
468 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
469 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
470 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
471 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
472 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
473 |
+
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
474 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
475 |
+
triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
476 |
+
triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
477 |
+
triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
478 |
+
triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
479 |
+
triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
480 |
+
triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
481 |
+
triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
482 |
+
triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
483 |
+
],
|
484 |
+
key=['chunk_size', 'hdim', 'dstate'],
|
485 |
+
)
|
486 |
+
@triton.jit
|
487 |
+
def _chunk_state_bwd_ddAcs_stable_kernel(
|
488 |
+
# Pointers to matrices
|
489 |
+
x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
|
490 |
+
ddA_cumsum_ptr,
|
491 |
+
# Matrix dimensions
|
492 |
+
chunk_size, hdim, dstate,
|
493 |
+
batch, seqlen, nheads_ngroups_ratio,
|
494 |
+
# Strides
|
495 |
+
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
|
496 |
+
stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
|
497 |
+
stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
|
498 |
+
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
499 |
+
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
500 |
+
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
501 |
+
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
|
502 |
+
# Meta-parameters
|
503 |
+
HAS_SEQ_IDX: tl.constexpr,
|
504 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
505 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
506 |
+
):
|
507 |
+
pid_bc = tl.program_id(axis=1)
|
508 |
+
pid_c = pid_bc // batch
|
509 |
+
pid_b = pid_bc - pid_c * batch
|
510 |
+
pid_h = tl.program_id(axis=2)
|
511 |
+
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
512 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
513 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
514 |
+
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
515 |
+
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
|
516 |
+
dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
|
517 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
518 |
+
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
|
519 |
+
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
520 |
+
if HAS_SEQ_IDX:
|
521 |
+
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
522 |
+
|
523 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
524 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
525 |
+
|
526 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
527 |
+
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
528 |
+
offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
|
529 |
+
b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
|
530 |
+
dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
|
531 |
+
if BLOCK_SIZE_DSTATE <= 128:
|
532 |
+
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
|
533 |
+
dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
|
534 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
535 |
+
acc = tl.dot(b, dstates)
|
536 |
+
else:
|
537 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
538 |
+
for k in range(0, dstate, BLOCK_SIZE_K):
|
539 |
+
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
|
540 |
+
dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
|
541 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
542 |
+
acc += tl.dot(b, dstates)
|
543 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
544 |
+
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
545 |
+
|
546 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
547 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
548 |
+
|
549 |
+
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
550 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
551 |
+
if not HAS_SEQ_IDX:
|
552 |
+
scale = tl.exp(dA_cs_last - dA_cs_m)
|
553 |
+
else:
|
554 |
+
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
555 |
+
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
556 |
+
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
557 |
+
acc *= scale[:, None]
|
558 |
+
|
559 |
+
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
|
560 |
+
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
561 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
562 |
+
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
563 |
+
ddt = tl.sum(acc * x, axis=1)
|
564 |
+
# ddA_cs = -(ddt * dt_m)
|
565 |
+
# Triton 2.2.0 errors if we have the cumsum here, so we just write it out
|
566 |
+
# then call torch.cumsum outside this kernel.
|
567 |
+
# ddA_cs = tl.cumsum(ddt * dt_m)
|
568 |
+
ddA_cs = ddt * dt_m
|
569 |
+
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
570 |
+
# tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
571 |
+
tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
|
572 |
+
|
573 |
+
|
574 |
+
@triton.autotune(
|
575 |
+
configs=[
|
576 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
|
577 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
578 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
579 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
580 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
581 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
582 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
583 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
584 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
|
585 |
+
],
|
586 |
+
key=['hdim', 'dstate', 'chunk_size'],
|
587 |
+
)
|
588 |
+
@triton.jit
|
589 |
+
def _chunk_state_varlen_kernel(
|
590 |
+
# Pointers to matrices
|
591 |
+
x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr,
|
592 |
+
# Matrix dimensions
|
593 |
+
hdim, dstate, chunk_size,
|
594 |
+
seqlen, nheads_ngroups_ratio,
|
595 |
+
# Strides
|
596 |
+
stride_x_seqlen, stride_x_head, stride_x_hdim,
|
597 |
+
stride_b_seqlen, stride_b_head, stride_b_dstate,
|
598 |
+
stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
599 |
+
stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
600 |
+
stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate,
|
601 |
+
stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate,
|
602 |
+
# Meta-parameters
|
603 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
604 |
+
):
|
605 |
+
pid_b = tl.program_id(axis=1)
|
606 |
+
pid_h = tl.program_id(axis=2)
|
607 |
+
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
608 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
609 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
610 |
+
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
|
611 |
+
pid_c = (end_idx - 1) // chunk_size
|
612 |
+
b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
|
613 |
+
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
614 |
+
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
615 |
+
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
616 |
+
chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
|
617 |
+
|
618 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
619 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
620 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
621 |
+
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
|
622 |
+
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
|
623 |
+
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
624 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
625 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
626 |
+
|
627 |
+
chunk_size_limit = end_idx - pid_c * chunk_size
|
628 |
+
start_idx = tl.load(cu_seqlens_ptr + pid_b)
|
629 |
+
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
|
630 |
+
|
631 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
632 |
+
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
633 |
+
x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0)
|
634 |
+
b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
|
635 |
+
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
|
636 |
+
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
|
637 |
+
scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
|
638 |
+
tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
|
639 |
+
b *= scale[:, None]
|
640 |
+
b = b.to(x_ptr.dtype.element_ty)
|
641 |
+
acc += tl.dot(x, b)
|
642 |
+
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
643 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
644 |
+
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
645 |
+
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
646 |
+
|
647 |
+
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
|
648 |
+
if start_idx < pid_c * chunk_size:
|
649 |
+
chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate)
|
650 |
+
chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
|
651 |
+
# scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
|
652 |
+
scale = tl.exp(dA_cs_last)
|
653 |
+
acc += chunk_states * scale
|
654 |
+
|
655 |
+
states = acc.to(states_ptr.dtype.element_ty)
|
656 |
+
|
657 |
+
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
658 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
659 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
660 |
+
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
|
661 |
+
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
662 |
+
tl.store(states_ptrs, states, mask=c_mask)
|
663 |
+
|
664 |
+
|
665 |
+
def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
|
666 |
+
batch, seqlen, nheads = dt.shape
|
667 |
+
assert A.shape == (nheads,)
|
668 |
+
if dt_bias is not None:
|
669 |
+
assert dt_bias.shape == (nheads,)
|
670 |
+
nchunks = math.ceil(seqlen / chunk_size)
|
671 |
+
dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
|
672 |
+
dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
|
673 |
+
grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
|
674 |
+
with torch.cuda.device(dt.device.index):
|
675 |
+
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
|
676 |
+
dt, A, dt_bias, dt_out, dA_cumsum,
|
677 |
+
batch, seqlen, nheads, chunk_size,
|
678 |
+
dt_limit[0], dt_limit[1],
|
679 |
+
dt.stride(0), dt.stride(1), dt.stride(2),
|
680 |
+
A.stride(0),
|
681 |
+
dt_bias.stride(0) if dt_bias is not None else 0,
|
682 |
+
dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),
|
683 |
+
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
684 |
+
dt_softplus,
|
685 |
+
HAS_DT_BIAS=dt_bias is not None,
|
686 |
+
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
687 |
+
)
|
688 |
+
return dA_cumsum, dt_out
|
689 |
+
|
690 |
+
|
691 |
+
def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None):
|
692 |
+
batch, seqlen, nheads = dt.shape
|
693 |
+
_, _, nchunks, chunk_size = ddA.shape
|
694 |
+
assert ddA.shape == (batch, nheads, nchunks, chunk_size)
|
695 |
+
assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
|
696 |
+
assert A.shape == (nheads,)
|
697 |
+
if dt_bias is not None:
|
698 |
+
assert dt_bias.shape == (nheads,)
|
699 |
+
ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
|
700 |
+
else:
|
701 |
+
ddt_bias = None
|
702 |
+
if ddt is not None:
|
703 |
+
assert ddt.shape == dt.shape
|
704 |
+
else:
|
705 |
+
ddt = torch.empty_like(dt)
|
706 |
+
dA = torch.empty_like(A, dtype=torch.float32)
|
707 |
+
grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
|
708 |
+
with torch.cuda.device(dt.device.index):
|
709 |
+
_chunk_cumsum_bwd_kernel[grid_chunk_cs](
|
710 |
+
ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,
|
711 |
+
batch, seqlen, nheads, chunk_size,
|
712 |
+
dt_limit[0], dt_limit[1],
|
713 |
+
ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),
|
714 |
+
ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),
|
715 |
+
dt.stride(0), dt.stride(1), dt.stride(2),
|
716 |
+
A.stride(0),
|
717 |
+
dt_bias.stride(0) if dt_bias is not None else 0,
|
718 |
+
ddt.stride(0), ddt.stride(1), ddt.stride(2),
|
719 |
+
dA.stride(0),
|
720 |
+
ddt_bias.stride(0) if ddt_bias is not None else 0,
|
721 |
+
dt_softplus,
|
722 |
+
HAS_DT_BIAS=dt_bias is not None,
|
723 |
+
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
724 |
+
)
|
725 |
+
return ddt, dA, ddt_bias
|
726 |
+
|
727 |
+
|
728 |
+
def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True):
|
729 |
+
batch, seqlen, nheads, headdim = x.shape
|
730 |
+
_, _, nchunks, chunk_size = dt.shape
|
731 |
+
_, _, ngroups, dstate = B.shape
|
732 |
+
assert nheads % ngroups == 0
|
733 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
734 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
735 |
+
assert dA_cumsum.shape == dt.shape
|
736 |
+
if seq_idx is not None:
|
737 |
+
assert seq_idx.shape == (batch, seqlen)
|
738 |
+
if states is not None:
|
739 |
+
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
|
740 |
+
else:
|
741 |
+
states_dtype = torch.float32 if states_in_fp32 else B.dtype
|
742 |
+
states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype)
|
743 |
+
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
|
744 |
+
batch * nchunks, nheads)
|
745 |
+
with torch.cuda.device(x.device.index):
|
746 |
+
_chunk_state_fwd_kernel[grid](
|
747 |
+
x, B, states, dt, dA_cumsum, seq_idx,
|
748 |
+
headdim, dstate, chunk_size,
|
749 |
+
batch, seqlen, nheads // ngroups,
|
750 |
+
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
|
751 |
+
B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
|
752 |
+
states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
|
753 |
+
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
|
754 |
+
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
755 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
756 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
757 |
+
)
|
758 |
+
return states
|
759 |
+
|
760 |
+
|
761 |
+
def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
|
762 |
+
batch, seqlen, nheads, headdim = x.shape
|
763 |
+
_, _, nchunks, chunk_size = dt.shape
|
764 |
+
_, _, ngroups, dstate = B.shape
|
765 |
+
assert nheads % ngroups == 0
|
766 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
767 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
768 |
+
assert dA_cumsum.shape == dt.shape
|
769 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
770 |
+
if dx is not None:
|
771 |
+
assert dx.shape == x.shape
|
772 |
+
else:
|
773 |
+
dx = torch.empty_like(x)
|
774 |
+
ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
|
775 |
+
ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32)
|
776 |
+
grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
|
777 |
+
batch * nchunks, nheads)
|
778 |
+
with torch.cuda.device(x.device.index):
|
779 |
+
_chunk_state_bwd_dx_kernel[grid_dx](
|
780 |
+
x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum,
|
781 |
+
chunk_size, headdim, dstate,
|
782 |
+
batch, seqlen, nheads // ngroups,
|
783 |
+
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
|
784 |
+
B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
|
785 |
+
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
|
786 |
+
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
|
787 |
+
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
788 |
+
dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
|
789 |
+
ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
|
790 |
+
ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
|
791 |
+
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
792 |
+
)
|
793 |
+
return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
|
794 |
+
|
795 |
+
|
796 |
+
def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
|
797 |
+
batch, seqlen, nheads, headdim = x.shape
|
798 |
+
_, _, nchunks, chunk_size = dt.shape
|
799 |
+
dstate = dstates.shape[-1]
|
800 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
801 |
+
assert dA_cumsum.shape == dt.shape
|
802 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
803 |
+
if seq_idx is not None:
|
804 |
+
assert seq_idx.shape == (batch, seqlen)
|
805 |
+
if B is not None:
|
806 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
807 |
+
B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
|
808 |
+
# Use torch.empty since the Triton kernel will call init_to_zero
|
809 |
+
ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
|
810 |
+
ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3))
|
811 |
+
else:
|
812 |
+
B_strides = (0, 0, 0, 0)
|
813 |
+
ddA_cumsum = None
|
814 |
+
ddA_cumsum_strides = (0, 0, 0, 0)
|
815 |
+
nheads_ngroups_ratio = nheads // ngroups
|
816 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
817 |
+
nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
|
818 |
+
nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
|
819 |
+
dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32)
|
820 |
+
grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
|
821 |
+
batch * nchunks, nsplits * ngroups)
|
822 |
+
with torch.cuda.device(x.device.index):
|
823 |
+
_chunk_state_bwd_db_kernel[grid_db](
|
824 |
+
x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum,
|
825 |
+
chunk_size, dstate, headdim,
|
826 |
+
batch, seqlen, nheads, nheads_per_program, ngroups,
|
827 |
+
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
|
828 |
+
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
|
829 |
+
*B_strides,
|
830 |
+
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
|
831 |
+
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
832 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
833 |
+
dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4),
|
834 |
+
*ddA_cumsum_strides,
|
835 |
+
HAS_DDA_CS=ddA_cumsum is not None,
|
836 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
837 |
+
BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
|
838 |
+
)
|
839 |
+
dB = dB.sum(2)
|
840 |
+
if ddA_cumsum is not None:
|
841 |
+
# The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
|
842 |
+
# to the state of the chunk.
|
843 |
+
# torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
|
844 |
+
# But it's easier to just do the cumsum for all elements, the result will be the same.
|
845 |
+
torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
|
846 |
+
return dB if B is None else (dB, ddA_cumsum)
|
847 |
+
|
848 |
+
|
849 |
+
def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
|
850 |
+
batch, seqlen, nheads, headdim = x.shape
|
851 |
+
_, _, nchunks, chunk_size = dt.shape
|
852 |
+
_, _, ngroups, dstate = B.shape
|
853 |
+
assert nheads % ngroups == 0
|
854 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
855 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
856 |
+
assert dA_cumsum.shape == dt.shape
|
857 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
858 |
+
if seq_idx is not None:
|
859 |
+
assert seq_idx.shape == (batch, seqlen)
|
860 |
+
# Use torch.empty since the Triton kernel will call init_to_zero
|
861 |
+
ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
|
862 |
+
grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
|
863 |
+
batch * nchunks, nheads)
|
864 |
+
with torch.cuda.device(x.device.index):
|
865 |
+
_chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
|
866 |
+
x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum,
|
867 |
+
chunk_size, headdim, dstate,
|
868 |
+
batch, seqlen, nheads // ngroups,
|
869 |
+
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
|
870 |
+
B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
|
871 |
+
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
|
872 |
+
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
|
873 |
+
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
874 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
875 |
+
ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
|
876 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
877 |
+
BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
|
878 |
+
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
879 |
+
)
|
880 |
+
torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
|
881 |
+
return ddA_cumsum
|
882 |
+
|
883 |
+
|
884 |
+
def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
|
885 |
+
total_seqlen, nheads, headdim = x.shape
|
886 |
+
_, nchunks, chunk_size = dt.shape
|
887 |
+
_, ngroups, dstate = B.shape
|
888 |
+
batch = cu_seqlens.shape[0] - 1
|
889 |
+
cu_seqlens = cu_seqlens.contiguous()
|
890 |
+
assert nheads % ngroups == 0
|
891 |
+
assert B.shape == (total_seqlen, ngroups, dstate)
|
892 |
+
assert dt.shape == (nheads, nchunks, chunk_size)
|
893 |
+
assert dA_cumsum.shape == dt.shape
|
894 |
+
assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
|
895 |
+
states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device)
|
896 |
+
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
|
897 |
+
batch, nheads)
|
898 |
+
with torch.cuda.device(x.device.index):
|
899 |
+
_chunk_state_varlen_kernel[grid](
|
900 |
+
x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states,
|
901 |
+
headdim, dstate, chunk_size,
|
902 |
+
total_seqlen, nheads // ngroups,
|
903 |
+
x.stride(0), x.stride(1), x.stride(2),
|
904 |
+
B.stride(0), B.stride(1), B.stride(2),
|
905 |
+
dt.stride(1), dt.stride(0), dt.stride(2),
|
906 |
+
dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2),
|
907 |
+
chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3),
|
908 |
+
states.stride(0), states.stride(1), states.stride(2), states.stride(3),
|
909 |
+
)
|
910 |
+
return states
|
911 |
+
|
912 |
+
|
913 |
+
class ChunkStateFn(torch.autograd.Function):
|
914 |
+
|
915 |
+
@staticmethod
|
916 |
+
def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
|
917 |
+
batch, seqlen, nheads, headdim = x.shape
|
918 |
+
_, _, nchunks, chunk_size = dt.shape
|
919 |
+
assert seqlen <= nchunks * chunk_size
|
920 |
+
_, _, ngroups, dstate = B.shape
|
921 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
922 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
923 |
+
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
|
924 |
+
if B.stride(-1) != 1:
|
925 |
+
B = B.contiguous()
|
926 |
+
if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
|
927 |
+
x = x.contiguous()
|
928 |
+
states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
|
929 |
+
ctx.save_for_backward(B, x, dt, dA_cumsum)
|
930 |
+
return states
|
931 |
+
|
932 |
+
@staticmethod
|
933 |
+
def backward(ctx, dstates):
|
934 |
+
B, x, dt, dA_cumsum = ctx.saved_tensors
|
935 |
+
batch, seqlen, nheads, headdim = x.shape
|
936 |
+
_, _, nchunks, chunk_size = dt.shape
|
937 |
+
_, _, ngroups, dstate = B.shape
|
938 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
939 |
+
if dstates.stride(-1) != 1:
|
940 |
+
dstates = dstates.contiguous()
|
941 |
+
dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
|
942 |
+
dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
|
943 |
+
dB = dB.to(B.dtype)
|
944 |
+
return dB, dx, ddt, ddA_cumsum, None
|
945 |
+
|
946 |
+
|
947 |
+
def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
|
948 |
+
"""
|
949 |
+
Argument:
|
950 |
+
B: (batch, seqlen, ngroups, headdim)
|
951 |
+
x: (batch, seqlen, nheads, headdim)
|
952 |
+
dt: (batch, nheads, nchunks, chunk_size)
|
953 |
+
dA_cumsum: (batch, nheads, nchunks, chunk_size)
|
954 |
+
Return:
|
955 |
+
states: (batch, nchunks, nheads, headdim, dstate)
|
956 |
+
"""
|
957 |
+
return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
|
958 |
+
|
959 |
+
|
960 |
+
def chunk_state_ref(B, x, dt, dA_cumsum):
|
961 |
+
"""
|
962 |
+
Argument:
|
963 |
+
B: (batch, seqlen, ngroups, headdim)
|
964 |
+
x: (batch, seqlen, nheads, headdim)
|
965 |
+
dt: (batch, nheads, nchunks, chunk_size)
|
966 |
+
dA_cumsum: (batch, nheads, nchunks, chunk_size)
|
967 |
+
Return:
|
968 |
+
states: (batch, nchunks, nheads, headdim, dstate)
|
969 |
+
"""
|
970 |
+
# Check constraints.
|
971 |
+
batch, seqlen, nheads, headdim = x.shape
|
972 |
+
dstate = B.shape[-1]
|
973 |
+
_, _, nchunks, chunk_size = dt.shape
|
974 |
+
assert seqlen <= nchunks * chunk_size
|
975 |
+
assert x.shape == (batch, seqlen, nheads, headdim)
|
976 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
977 |
+
ngroups = B.shape[2]
|
978 |
+
assert nheads % ngroups == 0
|
979 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
980 |
+
B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
|
981 |
+
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
|
982 |
+
if seqlen < nchunks * chunk_size:
|
983 |
+
x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
|
984 |
+
B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
|
985 |
+
x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
|
986 |
+
B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
|
987 |
+
decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
|
988 |
+
return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)
|
mamba/build/lib/mamba_ssm/ops/triton/ssd_combined.py
ADDED
@@ -0,0 +1,981 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
"""We want triton==2.1.0 or 2.2.0 for this
|
4 |
+
"""
|
5 |
+
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import math
|
9 |
+
from packaging import version
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import Tensor
|
14 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
15 |
+
|
16 |
+
import triton
|
17 |
+
import triton.language as tl
|
18 |
+
|
19 |
+
from einops import rearrange, repeat
|
20 |
+
|
21 |
+
try:
|
22 |
+
from causal_conv1d import causal_conv1d_fn
|
23 |
+
import causal_conv1d_cuda
|
24 |
+
except ImportError:
|
25 |
+
causal_conv1d_fn, causal_conv1d_cuda = None, None
|
26 |
+
|
27 |
+
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
|
28 |
+
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
|
29 |
+
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
|
30 |
+
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
|
31 |
+
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
|
32 |
+
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen
|
33 |
+
from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd
|
34 |
+
from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref
|
35 |
+
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
|
36 |
+
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
|
37 |
+
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
|
38 |
+
from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
|
39 |
+
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
|
40 |
+
from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
|
41 |
+
from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd
|
42 |
+
|
43 |
+
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
|
44 |
+
|
45 |
+
|
46 |
+
def init_to_zero(names):
|
47 |
+
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
|
48 |
+
|
49 |
+
|
50 |
+
@triton.autotune(
|
51 |
+
configs=[
|
52 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
|
53 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
54 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
55 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
56 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
57 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
58 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
59 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
60 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
61 |
+
],
|
62 |
+
key=['chunk_size', 'hdim', 'dstate'],
|
63 |
+
)
|
64 |
+
@triton.jit
|
65 |
+
def _chunk_scan_chunk_state_bwd_dx_kernel(
|
66 |
+
# Pointers to matrices
|
67 |
+
x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,
|
68 |
+
b_ptr, dstates_ptr,
|
69 |
+
dx_ptr, ddt_ptr, dD_ptr,
|
70 |
+
# Matrix dimensions
|
71 |
+
chunk_size, hdim, dstate,
|
72 |
+
batch, seqlen, nheads_ngroups_ratio,
|
73 |
+
# Strides
|
74 |
+
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
|
75 |
+
stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
|
76 |
+
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
|
77 |
+
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
78 |
+
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
79 |
+
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
80 |
+
stride_D_head,
|
81 |
+
stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
|
82 |
+
stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,
|
83 |
+
stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
|
84 |
+
stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
|
85 |
+
stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
|
86 |
+
# Meta-parameters
|
87 |
+
HAS_D: tl.constexpr,
|
88 |
+
D_HAS_HDIM: tl.constexpr,
|
89 |
+
HAS_SEQ_IDX: tl.constexpr,
|
90 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
91 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
92 |
+
IS_TRITON_22: tl.constexpr,
|
93 |
+
):
|
94 |
+
pid_bc = tl.program_id(axis=1)
|
95 |
+
pid_c = pid_bc // batch
|
96 |
+
pid_b = pid_bc - pid_c * batch
|
97 |
+
pid_h = tl.program_id(axis=2)
|
98 |
+
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
99 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
100 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
101 |
+
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
102 |
+
cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
|
103 |
+
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
|
104 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
105 |
+
ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
106 |
+
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
107 |
+
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
|
108 |
+
dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head
|
109 |
+
if HAS_SEQ_IDX:
|
110 |
+
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
111 |
+
|
112 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
113 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
114 |
+
|
115 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
116 |
+
|
117 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
118 |
+
|
119 |
+
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
|
120 |
+
|
121 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
122 |
+
if not HAS_SEQ_IDX:
|
123 |
+
scale = tl.exp(dA_cs_last - dA_cs_m)
|
124 |
+
else:
|
125 |
+
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
126 |
+
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
127 |
+
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
128 |
+
# Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
129 |
+
# However, we're getting error with the Triton compiler 2.1.0 for that code path:
|
130 |
+
# Unexpected mma -> mma layout conversion
|
131 |
+
# Triton 2.2.0 fixes this
|
132 |
+
offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
|
133 |
+
b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)
|
134 |
+
dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)
|
135 |
+
if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
|
136 |
+
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)
|
137 |
+
dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
|
138 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
139 |
+
acc = tl.dot(b, dstates) * scale[:, None]
|
140 |
+
else:
|
141 |
+
for k in range(0, dstate, BLOCK_SIZE_K):
|
142 |
+
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)
|
143 |
+
dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
|
144 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
145 |
+
acc += tl.dot(b, dstates)
|
146 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
147 |
+
dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
|
148 |
+
acc *= scale[:, None]
|
149 |
+
|
150 |
+
# x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
|
151 |
+
# x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
152 |
+
# dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
153 |
+
# dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
|
154 |
+
# ddt = tl.sum(acc * x, axis=1) * dt_m
|
155 |
+
# ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
156 |
+
# tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
157 |
+
|
158 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
159 |
+
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
|
160 |
+
dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
|
161 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
162 |
+
K_MAX = chunk_size_limit
|
163 |
+
K_MIN = pid_m * BLOCK_SIZE_M
|
164 |
+
cb_ptrs += K_MIN * stride_cb_csize_k
|
165 |
+
dout_ptrs += K_MIN * stride_dout_seqlen
|
166 |
+
dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
|
167 |
+
for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
|
168 |
+
k = tl.multiple_of(k, BLOCK_SIZE_K)
|
169 |
+
# For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
|
170 |
+
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
|
171 |
+
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
|
172 |
+
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
|
173 |
+
cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
|
174 |
+
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
|
175 |
+
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
|
176 |
+
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
|
177 |
+
# This will cause NaN in acc, and hence NaN in dx and ddt.
|
178 |
+
mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
|
179 |
+
cb = tl.where(mask, cb, 0.0)
|
180 |
+
cb = cb.to(dout_ptr.dtype.element_ty)
|
181 |
+
acc += tl.dot(cb, dout)
|
182 |
+
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
|
183 |
+
dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
|
184 |
+
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
185 |
+
|
186 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
187 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
188 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
189 |
+
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
|
190 |
+
dx = acc * dt_m[:, None]
|
191 |
+
dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
|
192 |
+
dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
|
193 |
+
if HAS_D:
|
194 |
+
dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
|
195 |
+
dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
196 |
+
if D_HAS_HDIM:
|
197 |
+
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
|
198 |
+
else:
|
199 |
+
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
|
200 |
+
dx += dout_res * D
|
201 |
+
tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
|
202 |
+
|
203 |
+
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
|
204 |
+
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
205 |
+
if HAS_D:
|
206 |
+
dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
|
207 |
+
if D_HAS_HDIM:
|
208 |
+
dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
|
209 |
+
dD = tl.sum(dout_res * x, axis=0)
|
210 |
+
tl.store(dD_ptrs, dD, mask=offs_n < hdim)
|
211 |
+
else:
|
212 |
+
dD = tl.sum(dout_res * x)
|
213 |
+
tl.store(dD_ptr, dD)
|
214 |
+
ddt = tl.sum(acc * x, axis=1)
|
215 |
+
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
216 |
+
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
217 |
+
|
218 |
+
|
219 |
+
def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):
|
220 |
+
batch, seqlen, nheads, headdim = x.shape
|
221 |
+
_, _, nchunks, chunk_size = dt.shape
|
222 |
+
_, _, ngroups, dstate = B.shape
|
223 |
+
assert nheads % ngroups == 0
|
224 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
225 |
+
assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
|
226 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
227 |
+
assert dA_cumsum.shape == dt.shape
|
228 |
+
assert dout.shape == x.shape
|
229 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
230 |
+
if seq_idx is not None:
|
231 |
+
assert seq_idx.shape == (batch, seqlen)
|
232 |
+
if D is not None:
|
233 |
+
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
234 |
+
assert D.stride(-1) == 1
|
235 |
+
BLOCK_SIZE_min = 32
|
236 |
+
dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
|
237 |
+
headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
|
238 |
+
else:
|
239 |
+
dD = None
|
240 |
+
dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
|
241 |
+
if D is not None else (0, 0, 0, 0, 0))
|
242 |
+
if dx is None:
|
243 |
+
dx = torch.empty_like(x)
|
244 |
+
else:
|
245 |
+
assert dx.shape == x.shape
|
246 |
+
ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
|
247 |
+
grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
|
248 |
+
batch * nchunks, nheads)
|
249 |
+
with torch.cuda.device(x.device.index):
|
250 |
+
_chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
|
251 |
+
x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,
|
252 |
+
chunk_size, headdim, dstate,
|
253 |
+
batch, seqlen, nheads // ngroups,
|
254 |
+
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
|
255 |
+
CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),
|
256 |
+
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
|
257 |
+
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
|
258 |
+
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
259 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
260 |
+
D.stride(0) if D is not None else 0,
|
261 |
+
B.stride(0), B.stride(1), B.stride(2), B.stride(3),
|
262 |
+
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
|
263 |
+
dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
|
264 |
+
ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
|
265 |
+
dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
|
266 |
+
D is not None,
|
267 |
+
D.dim() == 2 if D is not None else True,
|
268 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
269 |
+
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
270 |
+
IS_TRITON_22=TRITON_22
|
271 |
+
)
|
272 |
+
if D is not None:
|
273 |
+
BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"]
|
274 |
+
n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
|
275 |
+
dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
|
276 |
+
if D.dim() == 1:
|
277 |
+
dD = rearrange(dD, "h 1 -> h")
|
278 |
+
return dx, ddt.to(dtype=dt.dtype), dD
|
279 |
+
|
280 |
+
|
281 |
+
def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
|
282 |
+
batch, seqlen, nheads, headdim = x.shape
|
283 |
+
_, _, ngroups, dstate = B.shape
|
284 |
+
assert nheads % ngroups == 0
|
285 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
286 |
+
assert x.shape == (batch, seqlen, nheads, headdim)
|
287 |
+
assert dt.shape == (batch, seqlen, nheads)
|
288 |
+
assert A.shape == (nheads,)
|
289 |
+
assert C.shape == B.shape
|
290 |
+
if z is not None:
|
291 |
+
assert z.shape == x.shape
|
292 |
+
if D is not None:
|
293 |
+
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
294 |
+
if seq_idx is not None:
|
295 |
+
assert seq_idx.shape == (batch, seqlen)
|
296 |
+
if B.stride(-1) != 1:
|
297 |
+
B = B.contiguous()
|
298 |
+
if C.stride(-1) != 1:
|
299 |
+
C = C.contiguous()
|
300 |
+
if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
|
301 |
+
x = x.contiguous()
|
302 |
+
if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous
|
303 |
+
z = z.contiguous()
|
304 |
+
if D is not None and D.stride(-1) != 1:
|
305 |
+
D = D.contiguous()
|
306 |
+
if initial_states is not None:
|
307 |
+
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
308 |
+
# # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
|
309 |
+
# dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
|
310 |
+
# dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
|
311 |
+
# dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
|
312 |
+
dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
|
313 |
+
states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
|
314 |
+
# states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
|
315 |
+
# states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
|
316 |
+
# states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
|
317 |
+
states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
|
318 |
+
initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
|
319 |
+
seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)
|
320 |
+
states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]]
|
321 |
+
# states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
|
322 |
+
# states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
|
323 |
+
CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
|
324 |
+
out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
|
325 |
+
if cu_seqlens is None:
|
326 |
+
return out, out_x, dt, dA_cumsum, states, final_states
|
327 |
+
else:
|
328 |
+
assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
|
329 |
+
varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0),
|
330 |
+
cu_seqlens, states.squeeze(0))
|
331 |
+
return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
|
332 |
+
|
333 |
+
|
334 |
+
def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None,
|
335 |
+
dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False,
|
336 |
+
dt_limit=(0.0, float("inf")),
|
337 |
+
dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False):
|
338 |
+
if dout.stride(-1) != 1:
|
339 |
+
dout = dout.contiguous()
|
340 |
+
batch, seqlen, nheads, headdim = x.shape
|
341 |
+
nchunks = math.ceil(seqlen / chunk_size)
|
342 |
+
_, _, ngroups, dstate = B.shape
|
343 |
+
assert dout.shape == (batch, seqlen, nheads, headdim)
|
344 |
+
assert dt.shape == (batch, seqlen, nheads)
|
345 |
+
assert A.shape == (nheads,)
|
346 |
+
assert nheads % ngroups == 0
|
347 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
348 |
+
assert C.shape == B.shape
|
349 |
+
assert out.shape == x.shape
|
350 |
+
if initial_states is not None:
|
351 |
+
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
352 |
+
if seq_idx is not None:
|
353 |
+
assert seq_idx.shape == (batch, seqlen)
|
354 |
+
if dx is not None:
|
355 |
+
assert dx.shape == x.shape
|
356 |
+
if dB is not None:
|
357 |
+
assert dB.shape == B.shape
|
358 |
+
dB_given = dB
|
359 |
+
else:
|
360 |
+
dB_given = torch.empty_like(B)
|
361 |
+
if dC is not None:
|
362 |
+
assert dC.shape == C.shape
|
363 |
+
dC_given = dC
|
364 |
+
else:
|
365 |
+
dC_given = torch.empty_like(C)
|
366 |
+
if dz is not None:
|
367 |
+
assert z is not None
|
368 |
+
assert dz.shape == z.shape
|
369 |
+
if ddt is not None:
|
370 |
+
assert ddt.shape == dt.shape
|
371 |
+
ddt_given = ddt
|
372 |
+
else:
|
373 |
+
ddt_given = torch.empty_like(dt)
|
374 |
+
# TD: For some reason Triton (2.1.0 and 2.2.0) errors with
|
375 |
+
# "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why.
|
376 |
+
dt_in = dt.clone()
|
377 |
+
dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus,
|
378 |
+
dt_limit=dt_limit)
|
379 |
+
CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
|
380 |
+
states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
|
381 |
+
states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
|
382 |
+
initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
|
383 |
+
seq_idx=seq_idx, chunk_size=chunk_size)
|
384 |
+
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
|
385 |
+
if z is not None:
|
386 |
+
dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output)
|
387 |
+
outz = rest[0] if recompute_output else out
|
388 |
+
else:
|
389 |
+
dz = None
|
390 |
+
outz = out
|
391 |
+
dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype)
|
392 |
+
# dstates has length nchunks, containing the gradient to initial states at index 0 and
|
393 |
+
# gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
|
394 |
+
# Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
|
395 |
+
# will be used in matmul in the next kernels.
|
396 |
+
dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
|
397 |
+
rearrange(states, "... p n -> ... (p n)"),
|
398 |
+
dA_cumsum[:, :, :, -1],
|
399 |
+
rearrange(dstates, "... p n -> ... (p n)"),
|
400 |
+
dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None,
|
401 |
+
seq_idx=seq_idx,
|
402 |
+
has_initial_states=initial_states is not None,
|
403 |
+
dstates_dtype=x.dtype,
|
404 |
+
states_dtype=x.dtype,
|
405 |
+
chunk_size=chunk_size,
|
406 |
+
)
|
407 |
+
# dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
|
408 |
+
# gradient to the final states at index (nchunks - 1)
|
409 |
+
# states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
|
410 |
+
# The final states is not stored.
|
411 |
+
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
|
412 |
+
dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
|
413 |
+
dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None
|
414 |
+
dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
|
415 |
+
# dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
|
416 |
+
dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups)
|
417 |
+
# dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
|
418 |
+
dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups)
|
419 |
+
# Computing ddA with the dcb kernel is much slower, so we're not using it for now
|
420 |
+
dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
|
421 |
+
# dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
|
422 |
+
dCB = dCB.to(CB.dtype)
|
423 |
+
_bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
|
424 |
+
_bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
|
425 |
+
# If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
|
426 |
+
# than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
|
427 |
+
if z is None:
|
428 |
+
dD = dD_from_x
|
429 |
+
# Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
|
430 |
+
# ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
|
431 |
+
# However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
|
432 |
+
# be a lot of underflow.
|
433 |
+
|
434 |
+
# This is already done as part of bwd_dC kernel
|
435 |
+
# ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
|
436 |
+
ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
|
437 |
+
ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
|
438 |
+
# This is already done as part of bwd_dB kernel
|
439 |
+
# ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
|
440 |
+
# We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
|
441 |
+
ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
|
442 |
+
ddA += ddA_next + ddA_prev
|
443 |
+
|
444 |
+
ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given)
|
445 |
+
|
446 |
+
# These 2 lines are just to test ddt and dA being computed by old code
|
447 |
+
# _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
|
448 |
+
# ddt_given.copy_(ddt)
|
449 |
+
|
450 |
+
return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states)
|
451 |
+
return return_vals if not recompute_output else (*return_vals, outz)
|
452 |
+
|
453 |
+
|
454 |
+
def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
|
455 |
+
"""
|
456 |
+
Argument:
|
457 |
+
dout: (batch, seqlen, nheads, headdim)
|
458 |
+
x: (batch, seqlen, nheads, headdim)
|
459 |
+
dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
|
460 |
+
A: (nheads) or (dim, dstate)
|
461 |
+
B: (batch, seqlen, ngroups, dstate)
|
462 |
+
C: (batch, seqlen, ngroups, dstate)
|
463 |
+
D: (nheads, headdim) or (nheads,)
|
464 |
+
z: (batch, seqlen, nheads, headdim)
|
465 |
+
Return:
|
466 |
+
out: (batch, seqlen, nheads, headdim)
|
467 |
+
"""
|
468 |
+
import selective_scan
|
469 |
+
|
470 |
+
batch, seqlen, nheads, headdim = x.shape
|
471 |
+
chunk_size = dt.shape[-1]
|
472 |
+
_, _, ngroups, dstate = B.shape
|
473 |
+
assert nheads % ngroups == 0
|
474 |
+
x = rearrange(x, "b l h p -> b (h p) l")
|
475 |
+
squeeze_dt = dt.dim() == 4
|
476 |
+
if dt.dim() == 4:
|
477 |
+
dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
|
478 |
+
dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
|
479 |
+
squeeze_A = A.dim() == 1
|
480 |
+
if A.dim() == 1:
|
481 |
+
A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
|
482 |
+
else:
|
483 |
+
A = A.to(dtype=torch.float32)
|
484 |
+
B = rearrange(B, "b l g n -> b g n l")
|
485 |
+
C = rearrange(C, "b l g n -> b g n l")
|
486 |
+
if D is not None:
|
487 |
+
if D.dim() == 2:
|
488 |
+
D = rearrange(D, "h p -> (h p)")
|
489 |
+
else:
|
490 |
+
D = repeat(D, "h -> (h p)", p=headdim)
|
491 |
+
if z is not None:
|
492 |
+
z = rearrange(z, "b l h p -> b (h p) l")
|
493 |
+
|
494 |
+
if x.stride(-1) != 1:
|
495 |
+
x = x.contiguous()
|
496 |
+
if dt.stride(-1) != 1:
|
497 |
+
dt = dt.contiguous()
|
498 |
+
if D is not None:
|
499 |
+
D = D.contiguous()
|
500 |
+
if B.stride(-1) != 1:
|
501 |
+
B = B.contiguous()
|
502 |
+
if C.stride(-1) != 1:
|
503 |
+
C = C.contiguous()
|
504 |
+
if z is not None and z.stride(-1) != 1:
|
505 |
+
z = z.contiguous()
|
506 |
+
_, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False)
|
507 |
+
if z is not None:
|
508 |
+
out = rest[0]
|
509 |
+
else:
|
510 |
+
out = None
|
511 |
+
|
512 |
+
dout = rearrange(dout, "b l h p -> b (h p) l")
|
513 |
+
|
514 |
+
if dout.stride(-1) != 1:
|
515 |
+
dout = dout.contiguous()
|
516 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
517 |
+
# backward of selective_scan with the backward of chunk).
|
518 |
+
# Here we just pass in None and dz will be allocated in the C++ code.
|
519 |
+
_, ddt, dA, *rest = selective_scan.bwd(
|
520 |
+
x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False,
|
521 |
+
False # option to recompute out_z, not used here
|
522 |
+
)
|
523 |
+
ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
|
524 |
+
if squeeze_dt:
|
525 |
+
ddt = ddt.float().sum(dim=2)
|
526 |
+
if squeeze_A:
|
527 |
+
dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
|
528 |
+
return ddt, dA
|
529 |
+
|
530 |
+
|
531 |
+
class MambaChunkScanCombinedFn(torch.autograd.Function):
|
532 |
+
|
533 |
+
@staticmethod
|
534 |
+
def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):
|
535 |
+
ctx.dt_dtype = dt.dtype
|
536 |
+
if not return_varlen_states:
|
537 |
+
cu_seqlens = None
|
538 |
+
else:
|
539 |
+
assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
|
540 |
+
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
|
541 |
+
ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx)
|
542 |
+
ctx.dt_softplus = dt_softplus
|
543 |
+
ctx.chunk_size = chunk_size
|
544 |
+
ctx.dt_limit = dt_limit
|
545 |
+
ctx.return_final_states = return_final_states
|
546 |
+
ctx.return_varlen_states = return_varlen_states
|
547 |
+
if not return_varlen_states:
|
548 |
+
return out if not return_final_states else (out, final_states)
|
549 |
+
else:
|
550 |
+
varlen_states = rest[0]
|
551 |
+
return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states)
|
552 |
+
|
553 |
+
@staticmethod
|
554 |
+
def backward(ctx, dout, *args):
|
555 |
+
out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors
|
556 |
+
assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward"
|
557 |
+
dfinal_states = args[0] if ctx.return_final_states else None
|
558 |
+
dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)
|
559 |
+
return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None
|
560 |
+
|
561 |
+
|
562 |
+
def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):
|
563 |
+
"""
|
564 |
+
Argument:
|
565 |
+
x: (batch, seqlen, nheads, headdim)
|
566 |
+
dt: (batch, seqlen, nheads)
|
567 |
+
A: (nheads)
|
568 |
+
B: (batch, seqlen, ngroups, dstate)
|
569 |
+
C: (batch, seqlen, ngroups, dstate)
|
570 |
+
chunk_size: int
|
571 |
+
D: (nheads, headdim) or (nheads,)
|
572 |
+
z: (batch, seqlen, nheads, headdim)
|
573 |
+
dt_bias: (nheads,)
|
574 |
+
initial_states: (batch, nheads, headdim, dstate)
|
575 |
+
seq_idx: (batch, seqlen)
|
576 |
+
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
|
577 |
+
dt_softplus: Whether to apply softplus to dt
|
578 |
+
Return:
|
579 |
+
out: (batch, seqlen, nheads, headdim)
|
580 |
+
"""
|
581 |
+
return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states)
|
582 |
+
|
583 |
+
|
584 |
+
def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
|
585 |
+
"""
|
586 |
+
Argument:
|
587 |
+
x: (batch, seqlen, nheads, headdim)
|
588 |
+
dt: (batch, seqlen, nheads)
|
589 |
+
A: (nheads)
|
590 |
+
B: (batch, seqlen, ngroups, dstate)
|
591 |
+
C: (batch, seqlen, ngroups, dstate)
|
592 |
+
D: (nheads, headdim) or (nheads,)
|
593 |
+
z: (batch, seqlen, nheads, headdim)
|
594 |
+
dt_bias: (nheads,)
|
595 |
+
Return:
|
596 |
+
out: (batch, seqlen, nheads, headdim)
|
597 |
+
"""
|
598 |
+
batch, seqlen, nheads, headdim = x.shape
|
599 |
+
dstate = B.shape[-1]
|
600 |
+
if seqlen % chunk_size != 0:
|
601 |
+
dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
|
602 |
+
dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
|
603 |
+
dt = dt.float() # We want high precision for this before cumsum
|
604 |
+
if dt_bias is not None:
|
605 |
+
dt = dt + rearrange(dt_bias, "h -> h 1 1")
|
606 |
+
if dt_softplus:
|
607 |
+
dt = F.softplus(dt)
|
608 |
+
dA = dt * rearrange(A, "h -> h 1 1")
|
609 |
+
dA = dt * rearrange(A, "h -> h 1 1")
|
610 |
+
dA_cumsum = torch.cumsum(dA, dim=-1)
|
611 |
+
# 1. Compute the state for each chunk
|
612 |
+
states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
|
613 |
+
# 2. Pass the state to all the chunks by weighted cumsum.
|
614 |
+
states = rearrange(state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
|
615 |
+
"... (p n) -> ... p n", n=dstate)
|
616 |
+
# 3. Compute the output for each chunk
|
617 |
+
out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
|
618 |
+
return out
|
619 |
+
|
620 |
+
|
621 |
+
def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
|
622 |
+
"""
|
623 |
+
Argument:
|
624 |
+
x: (batch, seqlen, nheads, headdim)
|
625 |
+
dt: (batch, seqlen, nheads)
|
626 |
+
A: (nheads)
|
627 |
+
B: (batch, seqlen, ngroups, dstate)
|
628 |
+
C: (batch, seqlen, ngroups, dstate)
|
629 |
+
D: (nheads, headdim) or (nheads,)
|
630 |
+
z: (batch, seqlen, nheads, headdim)
|
631 |
+
dt_bias: (nheads,)
|
632 |
+
Return:
|
633 |
+
out: (batch, seqlen, nheads, headdim)
|
634 |
+
"""
|
635 |
+
batch, seqlen, nheads, headdim = x.shape
|
636 |
+
dstate = B.shape[-1]
|
637 |
+
if seqlen % chunk_size != 0:
|
638 |
+
dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
|
639 |
+
dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
|
640 |
+
dt = dt.float() # We want high precision for this before cumsum
|
641 |
+
if dt_bias is not None:
|
642 |
+
dt = dt + rearrange(dt_bias, "h -> h 1 1")
|
643 |
+
if dt_softplus:
|
644 |
+
dt = F.softplus(dt)
|
645 |
+
dA = dt * rearrange(A, "h -> h 1 1")
|
646 |
+
dA_cumsum = torch.cumsum(dA, dim=-1)
|
647 |
+
# 1. Compute the state for each chunk
|
648 |
+
states = chunk_state_ref(B, x, dt, dA_cumsum)
|
649 |
+
states_dtype = states.dtype
|
650 |
+
if states.dtype not in [torch.float32, torch.float64]:
|
651 |
+
states = states.to(torch.float32)
|
652 |
+
# 2. Pass the state to all the chunks by weighted cumsum.
|
653 |
+
# state_passing_ref is much less numerically stable
|
654 |
+
states = rearrange(state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
|
655 |
+
"... (p n) -> ... p n", n=dstate)
|
656 |
+
states = states.to(states_dtype)
|
657 |
+
# 3. Compute the output for each chunk
|
658 |
+
out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
|
659 |
+
return out
|
660 |
+
|
661 |
+
|
662 |
+
def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
|
663 |
+
"""
|
664 |
+
Argument:
|
665 |
+
x: (batch, seqlen, nheads, headdim)
|
666 |
+
dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
|
667 |
+
A: (nheads) or (dim, dstate)
|
668 |
+
B: (batch, seqlen, ngroups, dstate)
|
669 |
+
C: (batch, seqlen, ngroups, dstate)
|
670 |
+
D: (nheads, headdim) or (nheads,)
|
671 |
+
z: (batch, seqlen, nheads, headdim)
|
672 |
+
dt_bias: (nheads,) or (nheads, headdim)
|
673 |
+
Return:
|
674 |
+
out: (batch, seqlen, nheads, headdim)
|
675 |
+
"""
|
676 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
677 |
+
|
678 |
+
batch, seqlen, nheads, headdim = x.shape
|
679 |
+
_, _, ngroups, dstate = B.shape
|
680 |
+
x = rearrange(x, "b l h p -> b (h p) l")
|
681 |
+
if dt.dim() == 3:
|
682 |
+
dt = repeat(dt, "b l h -> b l h p", p=headdim)
|
683 |
+
dt = rearrange(dt, "b l h p -> b (h p) l")
|
684 |
+
if A.dim() == 1:
|
685 |
+
A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
|
686 |
+
else:
|
687 |
+
A = A.to(dtype=torch.float32)
|
688 |
+
B = rearrange(B, "b l g n -> b g n l")
|
689 |
+
C = rearrange(C, "b l g n -> b g n l")
|
690 |
+
if D is not None:
|
691 |
+
if D.dim() == 2:
|
692 |
+
D = rearrange(D, "h p -> (h p)")
|
693 |
+
else:
|
694 |
+
D = repeat(D, "h -> (h p)", p=headdim)
|
695 |
+
if z is not None:
|
696 |
+
z = rearrange(z, "b l h p -> b (h p) l")
|
697 |
+
if dt_bias is not None:
|
698 |
+
if dt_bias.dim() == 1:
|
699 |
+
dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
|
700 |
+
dt_bias = rearrange(dt_bias, "h p -> (h p)")
|
701 |
+
if dt_limit != (0.0, float("inf")):
|
702 |
+
if dt_bias is not None:
|
703 |
+
dt = dt + rearrange(dt_bias, "d -> d 1")
|
704 |
+
if dt_softplus:
|
705 |
+
dt = F.softplus(dt)
|
706 |
+
dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
|
707 |
+
dt_bias = None
|
708 |
+
dt_softplus = None
|
709 |
+
out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus)
|
710 |
+
return rearrange(out, "b (h p) l -> b l h p", p=headdim)
|
711 |
+
|
712 |
+
|
713 |
+
def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None,
|
714 |
+
dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")),
|
715 |
+
activation="silu", headdim=None, ngroups=1):
|
716 |
+
"""
|
717 |
+
Argument:
|
718 |
+
xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
|
719 |
+
conv1d_weight: (dim + 2 * ngroups * dstate, width)
|
720 |
+
conv1d_bias: (dim + 2 * ngroups * dstate,)
|
721 |
+
dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
|
722 |
+
A: (nheads)
|
723 |
+
D: (nheads, headdim) or (nheads,)
|
724 |
+
z: (batch, seqlen, dim)
|
725 |
+
dt_bias: (nheads) or (nheads, headdim)
|
726 |
+
headdim: if D is 1D and z is None, headdim must be passed in
|
727 |
+
Return:
|
728 |
+
out: (batch, seqlen, dim)
|
729 |
+
"""
|
730 |
+
batch, seqlen, nheads = dt.shape[:3]
|
731 |
+
assert nheads % ngroups == 0
|
732 |
+
if z is not None:
|
733 |
+
dim = z.shape[-1]
|
734 |
+
assert dim % nheads == 0
|
735 |
+
headdim = dim // nheads
|
736 |
+
else:
|
737 |
+
if D.dim() == 1:
|
738 |
+
assert headdim is not None
|
739 |
+
else:
|
740 |
+
headdim = D.shape[1]
|
741 |
+
dim = nheads * headdim
|
742 |
+
xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
|
743 |
+
"b d s -> b s d")
|
744 |
+
dstate = (xBC.shape[-1] - dim) // ngroups // 2
|
745 |
+
x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
|
746 |
+
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
747 |
+
B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
|
748 |
+
C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
|
749 |
+
z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
|
750 |
+
out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
|
751 |
+
return rearrange(out, "b s h p -> b s (h p)")
|
752 |
+
|
753 |
+
|
754 |
+
class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
|
755 |
+
|
756 |
+
@staticmethod
|
757 |
+
@custom_fwd
|
758 |
+
def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
|
759 |
+
rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,
|
760 |
+
ngroups=1, norm_before_gate=True):
|
761 |
+
assert activation in [None, "silu", "swish"]
|
762 |
+
if D.dim() == 1:
|
763 |
+
assert headdim is not None
|
764 |
+
nheads, = D.shape
|
765 |
+
else:
|
766 |
+
nheads, headdim = D.shape
|
767 |
+
batch, seqlen, _ = zxbcdt.shape
|
768 |
+
dim = nheads * headdim
|
769 |
+
assert nheads % ngroups == 0
|
770 |
+
dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
|
771 |
+
d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
|
772 |
+
assert d_nonssm >= 0
|
773 |
+
assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads)
|
774 |
+
assert dt_bias.shape == (nheads,)
|
775 |
+
assert A.shape == (nheads,)
|
776 |
+
zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
|
777 |
+
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
|
778 |
+
xBC_conv = rearrange(
|
779 |
+
causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
|
780 |
+
conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
|
781 |
+
"b d s -> b s d"
|
782 |
+
)
|
783 |
+
x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
|
784 |
+
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
785 |
+
B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
|
786 |
+
C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
|
787 |
+
z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
|
788 |
+
if rmsnorm_weight is None:
|
789 |
+
out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
|
790 |
+
out = rearrange(out, "b s h p -> b s (h p)")
|
791 |
+
rstd = None
|
792 |
+
if d_nonssm > 0:
|
793 |
+
out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
|
794 |
+
else:
|
795 |
+
out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
|
796 |
+
# reshape input data into 2D tensor
|
797 |
+
x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
|
798 |
+
z_rms = rearrange(z, "b s h p -> (b s) (h p)")
|
799 |
+
rmsnorm_weight = rmsnorm_weight.contiguous()
|
800 |
+
if d_nonssm == 0:
|
801 |
+
out = None
|
802 |
+
else:
|
803 |
+
out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device)
|
804 |
+
out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
|
805 |
+
_swiglu_fwd(zx0, out=out01[..., :d_nonssm])
|
806 |
+
out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out,
|
807 |
+
group_size=dim // ngroups,
|
808 |
+
norm_before_gate=norm_before_gate, is_rms_norm=True)
|
809 |
+
if d_nonssm == 0:
|
810 |
+
out = rearrange(out, "(b s) d -> b s d", b=batch)
|
811 |
+
else:
|
812 |
+
out = out01
|
813 |
+
ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None
|
814 |
+
if outproj_weight is not None:
|
815 |
+
if torch.is_autocast_enabled():
|
816 |
+
dtype = torch.get_autocast_gpu_dtype()
|
817 |
+
out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
|
818 |
+
outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None
|
819 |
+
out = F.linear(out, outproj_weight, outproj_bias)
|
820 |
+
else:
|
821 |
+
assert outproj_bias is None
|
822 |
+
ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias,
|
823 |
+
out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias)
|
824 |
+
ctx.dt_limit = dt_limit
|
825 |
+
ctx.return_final_states = return_final_states
|
826 |
+
ctx.activation = activation
|
827 |
+
ctx.rmsnorm_eps = rmsnorm_eps
|
828 |
+
ctx.norm_before_gate = norm_before_gate
|
829 |
+
ctx.chunk_size = chunk_size
|
830 |
+
ctx.headdim = headdim
|
831 |
+
ctx.ngroups = ngroups
|
832 |
+
return out if not return_final_states else (out, final_states)
|
833 |
+
|
834 |
+
@staticmethod
|
835 |
+
@custom_bwd
|
836 |
+
def backward(ctx, dout, *args):
|
837 |
+
zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
|
838 |
+
dfinal_states = args[0] if ctx.return_final_states else None
|
839 |
+
headdim = ctx.headdim
|
840 |
+
nheads = D.shape[0]
|
841 |
+
dim = nheads * headdim
|
842 |
+
assert nheads % ctx.ngroups == 0
|
843 |
+
dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
|
844 |
+
d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
|
845 |
+
assert d_nonssm >= 0
|
846 |
+
recompute_output = outproj_weight is not None
|
847 |
+
if recompute_output:
|
848 |
+
out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype)
|
849 |
+
out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1)
|
850 |
+
zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
|
851 |
+
# Recompute x, B, C
|
852 |
+
xBC_conv = rearrange(
|
853 |
+
causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
|
854 |
+
conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
|
855 |
+
"b d s -> b s d"
|
856 |
+
)
|
857 |
+
x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
|
858 |
+
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
859 |
+
B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
|
860 |
+
C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
|
861 |
+
dzxbcdt = torch.empty_like(zxbcdt)
|
862 |
+
dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
|
863 |
+
dxBC = torch.empty_like(xBC)
|
864 |
+
dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
|
865 |
+
z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
|
866 |
+
dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
|
867 |
+
dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
|
868 |
+
dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
|
869 |
+
if outproj_weight is not None:
|
870 |
+
dout_og = dout
|
871 |
+
dout = F.linear(dout, outproj_weight.t())
|
872 |
+
if d_nonssm > 0:
|
873 |
+
dout0, dout = dout.split([d_nonssm, dim], dim=-1)
|
874 |
+
_swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
|
875 |
+
dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
|
876 |
+
if rmsnorm_weight is None:
|
877 |
+
dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
|
878 |
+
dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd(
|
879 |
+
dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output
|
880 |
+
)
|
881 |
+
out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
|
882 |
+
drmsnorm_weight = None
|
883 |
+
else:
|
884 |
+
batch = dout.shape[0]
|
885 |
+
dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
|
886 |
+
dz = rearrange(dz, "b l d -> (b l) d")
|
887 |
+
x_rms = rearrange(out, "b s h p -> (b s) (h p)")
|
888 |
+
z_rms = rearrange(z, "b s h p -> (b s) (h p)")
|
889 |
+
out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None
|
890 |
+
dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
|
891 |
+
out_for_linear = out_recompute if recompute_output else None
|
892 |
+
dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
|
893 |
+
dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
|
894 |
+
dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC
|
895 |
+
)
|
896 |
+
|
897 |
+
if outproj_weight is not None:
|
898 |
+
doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
|
899 |
+
doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
|
900 |
+
else:
|
901 |
+
doutproj_weight, doutproj_bias = None, None
|
902 |
+
dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
|
903 |
+
dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
|
904 |
+
rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
|
905 |
+
rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
|
906 |
+
)
|
907 |
+
dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
|
908 |
+
return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None
|
909 |
+
|
910 |
+
|
911 |
+
def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
|
912 |
+
"""
|
913 |
+
Argument:
|
914 |
+
zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
|
915 |
+
conv1d_weight: (dim + 2 * ngroups * dstate, width)
|
916 |
+
conv1d_bias: (dim + 2 * ngroups * dstate,)
|
917 |
+
dt_bias: (nheads,)
|
918 |
+
A: (nheads)
|
919 |
+
D: (nheads, headdim) or (nheads,)
|
920 |
+
initial_states: (batch, nheads, headdim, dstate)
|
921 |
+
seq_idx: (batch, seqlen), int32
|
922 |
+
rmsnorm_weight: (dim,)
|
923 |
+
outproj_weight: (out_dim, dim)
|
924 |
+
outproj_bias: (out_dim,)
|
925 |
+
headdim: if D is 1D, headdim must be passed in
|
926 |
+
norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
|
927 |
+
Return:
|
928 |
+
out: (batch, seqlen, dim)
|
929 |
+
"""
|
930 |
+
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
|
931 |
+
|
932 |
+
|
933 |
+
def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float("inf")), activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
|
934 |
+
"""
|
935 |
+
Argument:
|
936 |
+
zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
|
937 |
+
conv1d_weight: (dim + 2 * ngroups * dstate, width)
|
938 |
+
conv1d_bias: (dim + 2 * ngroups * dstate,)
|
939 |
+
dt_bias: (nheads,)
|
940 |
+
A: (nheads)
|
941 |
+
D: (nheads, headdim) or (nheads,)
|
942 |
+
rmsnorm_weight: (dim,)
|
943 |
+
outproj_weight: (out_dim, dim)
|
944 |
+
outproj_bias: (out_dim,)
|
945 |
+
headdim: if D is 1D, headdim must be passed in
|
946 |
+
norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
|
947 |
+
Return:
|
948 |
+
out: (batch, seqlen, dim)
|
949 |
+
"""
|
950 |
+
if D.dim() == 1:
|
951 |
+
assert headdim is not None
|
952 |
+
nheads, = D.shape
|
953 |
+
else:
|
954 |
+
nheads, headdim = D.shape
|
955 |
+
assert nheads % ngroups == 0
|
956 |
+
batch, seqlen, _ = zxbcdt.shape
|
957 |
+
dim = nheads * headdim
|
958 |
+
dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
|
959 |
+
assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
|
960 |
+
assert dt_bias.shape == (nheads,)
|
961 |
+
assert A.shape == (nheads,)
|
962 |
+
if rmsnorm_weight is not None:
|
963 |
+
assert rmsnorm_weight.shape == (dim,)
|
964 |
+
z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
|
965 |
+
xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
|
966 |
+
"b d s -> b s d")
|
967 |
+
x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
|
968 |
+
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
969 |
+
B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
|
970 |
+
C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
|
971 |
+
z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
|
972 |
+
out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(),
|
973 |
+
z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit)
|
974 |
+
out = rearrange(out, "b s h p -> b s (h p)")
|
975 |
+
if rmsnorm_weight is not None:
|
976 |
+
out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, "b l h p -> b l (h p)"), eps=rmsnorm_eps,
|
977 |
+
norm_before_gate=norm_before_gate)
|
978 |
+
if outproj_weight is not None:
|
979 |
+
out = F.linear(out, outproj_weight, outproj_bias)
|
980 |
+
return out
|
981 |
+
|
mamba/build/lib/mamba_ssm/ops/triton/ssd_state_passing.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
"""We want triton==2.1.0 or 2.2.0 for this
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import triton
|
11 |
+
import triton.language as tl
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
|
16 |
+
@triton.autotune(
|
17 |
+
configs=[
|
18 |
+
triton.Config({'BLOCK_SIZE': 64}),
|
19 |
+
triton.Config({'BLOCK_SIZE': 128}),
|
20 |
+
triton.Config({'BLOCK_SIZE': 256}),
|
21 |
+
triton.Config({'BLOCK_SIZE': 512}),
|
22 |
+
triton.Config({'BLOCK_SIZE': 1024}),
|
23 |
+
triton.Config({'BLOCK_SIZE': 2048}),
|
24 |
+
],
|
25 |
+
key=['dim'],
|
26 |
+
)
|
27 |
+
@triton.jit
|
28 |
+
def _state_passing_fwd_kernel(
|
29 |
+
# Pointers to matrices
|
30 |
+
states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
|
31 |
+
# Matrix dimensions
|
32 |
+
dim, nchunks, seqlen, chunk_size,
|
33 |
+
# Strides
|
34 |
+
stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
|
35 |
+
stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
|
36 |
+
stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
|
37 |
+
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
|
38 |
+
stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
|
39 |
+
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
40 |
+
# Meta-parameters
|
41 |
+
HAS_INITSTATES: tl.constexpr,
|
42 |
+
HAS_SEQ_IDX: tl.constexpr,
|
43 |
+
BLOCK_SIZE: tl.constexpr,
|
44 |
+
):
|
45 |
+
pid_b = tl.program_id(axis=1)
|
46 |
+
pid_h = tl.program_id(axis=2)
|
47 |
+
pid_m = tl.program_id(axis=0)
|
48 |
+
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
49 |
+
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
|
50 |
+
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
51 |
+
final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
|
52 |
+
if HAS_INITSTATES:
|
53 |
+
initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
|
54 |
+
if HAS_SEQ_IDX:
|
55 |
+
seq_idx_ptr += pid_b * stride_seq_idx_batch
|
56 |
+
|
57 |
+
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
58 |
+
states_ptrs = states_ptr + offs_m * stride_states_dim
|
59 |
+
out_ptrs = out_ptr + offs_m * stride_out_dim
|
60 |
+
final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
|
61 |
+
|
62 |
+
if not HAS_INITSTATES:
|
63 |
+
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
|
64 |
+
else:
|
65 |
+
initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
|
66 |
+
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
67 |
+
tl.store(out_ptrs, states, mask=offs_m < dim)
|
68 |
+
out_ptrs += stride_out_chunk
|
69 |
+
seq_idx = 0
|
70 |
+
for c in range(nchunks):
|
71 |
+
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
72 |
+
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
73 |
+
scale = tl.exp(dA_cs)
|
74 |
+
if HAS_SEQ_IDX:
|
75 |
+
seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
|
76 |
+
scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
|
77 |
+
seq_idx = seq_idx_new
|
78 |
+
states = scale * states + new_states
|
79 |
+
if c < nchunks - 1:
|
80 |
+
tl.store(out_ptrs, states, mask=offs_m < dim)
|
81 |
+
else:
|
82 |
+
tl.store(final_states_ptrs, states, mask=offs_m < dim)
|
83 |
+
states_ptrs += stride_states_chunk
|
84 |
+
dA_cs_ptr += stride_dA_cs_chunk
|
85 |
+
out_ptrs += stride_out_chunk
|
86 |
+
|
87 |
+
|
88 |
+
@triton.autotune(
|
89 |
+
configs=[
|
90 |
+
triton.Config({'BLOCK_SIZE': 64}),
|
91 |
+
triton.Config({'BLOCK_SIZE': 128}),
|
92 |
+
triton.Config({'BLOCK_SIZE': 256}),
|
93 |
+
triton.Config({'BLOCK_SIZE': 512}),
|
94 |
+
triton.Config({'BLOCK_SIZE': 1024}),
|
95 |
+
triton.Config({'BLOCK_SIZE': 2048}),
|
96 |
+
],
|
97 |
+
key=['dim'],
|
98 |
+
)
|
99 |
+
@triton.jit
|
100 |
+
def _state_passing_bwd_kernel(
|
101 |
+
# Pointers to matrices
|
102 |
+
dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,
|
103 |
+
dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,
|
104 |
+
# Matrix dimensions
|
105 |
+
dim, nchunks, seqlen, chunk_size,
|
106 |
+
# Strides
|
107 |
+
stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,
|
108 |
+
stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
|
109 |
+
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
|
110 |
+
stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,
|
111 |
+
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
112 |
+
stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,
|
113 |
+
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,
|
114 |
+
stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,
|
115 |
+
# Meta-parameters
|
116 |
+
CONVERT_STATES: tl.constexpr,
|
117 |
+
HAS_DFINAL_STATES: tl.constexpr,
|
118 |
+
HAS_DINITSTATES: tl.constexpr,
|
119 |
+
HAS_SEQ_IDX: tl.constexpr,
|
120 |
+
BLOCK_SIZE: tl.constexpr,
|
121 |
+
):
|
122 |
+
pid_b = tl.program_id(axis=1)
|
123 |
+
pid_h = tl.program_id(axis=2)
|
124 |
+
pid_m = tl.program_id(axis=0)
|
125 |
+
dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
|
126 |
+
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
|
127 |
+
ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
|
128 |
+
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
|
129 |
+
dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
|
130 |
+
if CONVERT_STATES:
|
131 |
+
states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
|
132 |
+
if HAS_DFINAL_STATES:
|
133 |
+
dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
|
134 |
+
if HAS_DINITSTATES:
|
135 |
+
dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
|
136 |
+
if HAS_SEQ_IDX:
|
137 |
+
seq_idx_ptr += pid_b * stride_seq_idx_batch
|
138 |
+
|
139 |
+
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
140 |
+
dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
|
141 |
+
out_ptrs = out_ptr + offs_m * stride_out_dim
|
142 |
+
dout_ptrs = dout_ptr + offs_m * stride_dout_dim
|
143 |
+
if CONVERT_STATES:
|
144 |
+
states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim
|
145 |
+
|
146 |
+
if HAS_DFINAL_STATES:
|
147 |
+
dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)
|
148 |
+
else:
|
149 |
+
dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
|
150 |
+
tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
|
151 |
+
if HAS_SEQ_IDX:
|
152 |
+
seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
|
153 |
+
dstates_ptrs -= stride_dstates_chunk
|
154 |
+
for c in range(nchunks - 1):
|
155 |
+
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
156 |
+
scale = tl.exp(dA_cs)
|
157 |
+
if HAS_SEQ_IDX:
|
158 |
+
seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
|
159 |
+
scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
|
160 |
+
seq_idx = seq_idx_new
|
161 |
+
out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
162 |
+
if CONVERT_STATES:
|
163 |
+
tl.store(states_converted_ptrs, out, mask=offs_m < dim)
|
164 |
+
ddA = tl.sum(out * dstates) * scale
|
165 |
+
tl.store(ddA_cs_ptr, ddA)
|
166 |
+
dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
167 |
+
dstates = scale * dstates + dout
|
168 |
+
tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
|
169 |
+
dout_ptrs -= stride_dout_chunk
|
170 |
+
dstates_ptrs -= stride_dstates_chunk
|
171 |
+
dA_cs_ptr -= stride_dA_cs_chunk
|
172 |
+
ddA_cs_ptr -= stride_ddA_cs_chunk
|
173 |
+
out_ptrs -= stride_out_chunk
|
174 |
+
if CONVERT_STATES:
|
175 |
+
states_converted_ptrs -= stride_out_chunk
|
176 |
+
if CONVERT_STATES:
|
177 |
+
out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
178 |
+
tl.store(states_converted_ptrs, out, mask=offs_m < dim)
|
179 |
+
if not HAS_DINITSTATES:
|
180 |
+
tl.store(ddA_cs_ptr, 0.0)
|
181 |
+
else:
|
182 |
+
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
183 |
+
scale = tl.exp(dA_cs)
|
184 |
+
if HAS_SEQ_IDX:
|
185 |
+
scale = tl.where(seq_idx == 0, scale, 0.0)
|
186 |
+
out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
187 |
+
ddA = tl.sum(out * dstates) * scale
|
188 |
+
tl.store(ddA_cs_ptr, ddA)
|
189 |
+
dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
190 |
+
dstates = scale * dstates + dout
|
191 |
+
tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)
|
192 |
+
|
193 |
+
|
194 |
+
def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
|
195 |
+
out_dtype=None):
|
196 |
+
batch, nchunks, nheads, dim = states.shape
|
197 |
+
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
198 |
+
if initial_states is not None:
|
199 |
+
assert initial_states.shape == (batch, nheads, dim)
|
200 |
+
if seq_idx is not None:
|
201 |
+
assert chunk_size is not None
|
202 |
+
seqlen = seq_idx.shape[-1]
|
203 |
+
assert seq_idx.shape == (batch, seqlen)
|
204 |
+
out_dtype = states.dtype if out_dtype is None else out_dtype
|
205 |
+
out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
|
206 |
+
final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
|
207 |
+
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
|
208 |
+
with torch.cuda.device(states.device.index):
|
209 |
+
_state_passing_fwd_kernel[grid](
|
210 |
+
states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
|
211 |
+
dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
|
212 |
+
states.stride(0), states.stride(1), states.stride(2), states.stride(3),
|
213 |
+
out.stride(0), out.stride(1), out.stride(2), out.stride(3),
|
214 |
+
final_states.stride(0), final_states.stride(1), final_states.stride(2),
|
215 |
+
dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
|
216 |
+
*((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
|
217 |
+
if initial_states is not None else (0, 0, 0)),
|
218 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
219 |
+
HAS_INITSTATES=initial_states is not None,
|
220 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
221 |
+
)
|
222 |
+
return out, final_states
|
223 |
+
|
224 |
+
|
225 |
+
def _state_passing_bwd(
|
226 |
+
states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,
|
227 |
+
dstates_dtype=None, states_dtype=None, chunk_size=None
|
228 |
+
):
|
229 |
+
"""
|
230 |
+
states contains the initial_states at index 0. The final states are not included in states.
|
231 |
+
"""
|
232 |
+
batch, nchunks, nheads, dim = states.shape
|
233 |
+
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
234 |
+
assert dout.shape == (batch, nchunks, nheads, dim)
|
235 |
+
if seq_idx is not None:
|
236 |
+
assert chunk_size is not None
|
237 |
+
seqlen = seq_idx.shape[-1]
|
238 |
+
assert seq_idx.shape == (batch, seqlen)
|
239 |
+
dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
|
240 |
+
if states_dtype is not None and states_dtype != states.dtype:
|
241 |
+
states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
|
242 |
+
assert states_converted.stride() == states.stride()
|
243 |
+
else:
|
244 |
+
states_converted = None
|
245 |
+
if has_initial_states:
|
246 |
+
dinitstates = torch.empty_like(dstates[:, 0])
|
247 |
+
else:
|
248 |
+
dinitstates = None
|
249 |
+
if dfinal_states is not None:
|
250 |
+
assert dfinal_states.shape == (batch, nheads, dim)
|
251 |
+
BLOCK_SIZE_min = 64
|
252 |
+
n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
|
253 |
+
ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,
|
254 |
+
dtype=torch.float32, device=dA_chunk_cumsum.device)
|
255 |
+
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
|
256 |
+
with torch.cuda.device(dout.device.index):
|
257 |
+
_state_passing_bwd_kernel[grid](
|
258 |
+
dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,
|
259 |
+
dstates, ddA_chunk_cumsum, dinitstates, states_converted,
|
260 |
+
dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
|
261 |
+
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
|
262 |
+
states.stride(0), states.stride(1), states.stride(2), states.stride(3),
|
263 |
+
dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
|
264 |
+
*((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))
|
265 |
+
if dfinal_states is not None else (0, 0, 0)),
|
266 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
267 |
+
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),
|
268 |
+
ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),
|
269 |
+
*((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))
|
270 |
+
if dinitstates is not None else (0, 0, 0)),
|
271 |
+
CONVERT_STATES=states_converted is not None,
|
272 |
+
HAS_DFINAL_STATES=dfinal_states is not None,
|
273 |
+
HAS_DINITSTATES=dinitstates is not None,
|
274 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
275 |
+
)
|
276 |
+
BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
|
277 |
+
n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
|
278 |
+
ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)
|
279 |
+
if states_dtype is not None and states_dtype == states.dtype:
|
280 |
+
states_converted = states
|
281 |
+
return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)
|
282 |
+
|
283 |
+
|
284 |
+
class StatePassingFn(torch.autograd.Function):
|
285 |
+
|
286 |
+
@staticmethod
|
287 |
+
def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
|
288 |
+
batch, nchunks, nheads, dim = states.shape
|
289 |
+
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
290 |
+
if states.stride(-1) != 1:
|
291 |
+
states = states.contiguous()
|
292 |
+
out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
|
293 |
+
ctx.save_for_backward(out, dA_chunk_cumsum)
|
294 |
+
ctx.has_initial_states = initial_states is not None
|
295 |
+
return out, final_states
|
296 |
+
|
297 |
+
@staticmethod
|
298 |
+
def backward(ctx, dout, dfinal_states):
|
299 |
+
out, dA_chunk_cumsum = ctx.saved_tensors
|
300 |
+
batch, nchunks, nheads, dim = out.shape
|
301 |
+
assert dout.shape == (batch, nchunks, nheads, dim)
|
302 |
+
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
303 |
+
assert dfinal_states.shape == (batch, nheads, dim)
|
304 |
+
if dout.stride(-1) != 1:
|
305 |
+
dout = dout.contiguous()
|
306 |
+
dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
|
307 |
+
out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states
|
308 |
+
)
|
309 |
+
return dstates, ddA_chunk_cumsum, dinitstates
|
310 |
+
|
311 |
+
|
312 |
+
def state_passing(states, dA_chunk_cumsum, initial_states=None):
|
313 |
+
"""
|
314 |
+
Argument:
|
315 |
+
states: (batch, nchunks, nheads, dim)
|
316 |
+
dA_chunk_cumsum: (batch, nheads, nchunks)
|
317 |
+
initial_states: (batch, nheads, dim)
|
318 |
+
Return:
|
319 |
+
out: (batch, nchunks, nheads, dim)
|
320 |
+
final_states: (batch, nheads, dim)
|
321 |
+
"""
|
322 |
+
return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)
|
323 |
+
|
324 |
+
|
325 |
+
def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
|
326 |
+
"""
|
327 |
+
Argument:
|
328 |
+
states: (batch, nchunks, nheads, dim)
|
329 |
+
dA_chunk_cumsum: (batch, nheads, nchunks)
|
330 |
+
initial_states: (batch, nheads, dim)
|
331 |
+
Return:
|
332 |
+
out: (batch, nchunks, nheads, dim)
|
333 |
+
final_states: (batch, nheads, dim)
|
334 |
+
"""
|
335 |
+
if initial_states is None:
|
336 |
+
initial_states = torch.zeros_like(states[:, 0])
|
337 |
+
states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1)
|
338 |
+
dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
|
339 |
+
dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
|
340 |
+
nchunks = dA_chunk_cumsum.shape[-1]
|
341 |
+
# (batch, nheads, nchunks, nchunks)
|
342 |
+
dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
|
343 |
+
# (batch, nheads, nchunks, nchunks)
|
344 |
+
decay_chunk = torch.exp(dt_chunk_segment_sum)
|
345 |
+
causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)
|
346 |
+
decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)
|
347 |
+
out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
|
348 |
+
return out[:, :-1], out[:, -1]
|
mamba/build/lib/mamba_ssm/utils/__init__.py
ADDED
File without changes
|
mamba/build/lib/mamba_ssm/utils/generation.py
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
2 |
+
import gc
|
3 |
+
import time
|
4 |
+
from collections import namedtuple
|
5 |
+
from dataclasses import dataclass, field
|
6 |
+
from functools import partial
|
7 |
+
from typing import Callable, Optional, Sequence, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
from torch import Tensor
|
13 |
+
from torch.profiler import ProfilerActivity, profile, record_function
|
14 |
+
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
|
18 |
+
class InferenceParams:
|
19 |
+
"""Inference parameters that are passed to the main model in order
|
20 |
+
to efficienly calculate and store the context during inference."""
|
21 |
+
|
22 |
+
max_seqlen: int
|
23 |
+
max_batch_size: int
|
24 |
+
seqlen_offset: int = 0
|
25 |
+
batch_size_offset: int = 0
|
26 |
+
key_value_memory_dict: dict = field(default_factory=dict)
|
27 |
+
lengths_per_sample: Optional[Tensor] = None
|
28 |
+
|
29 |
+
def reset(self, max_seqlen, max_batch_size):
|
30 |
+
self.max_seqlen = max_seqlen
|
31 |
+
self.max_batch_size = max_batch_size
|
32 |
+
self.seqlen_offset = 0
|
33 |
+
if self.lengths_per_sample is not None:
|
34 |
+
self.lengths_per_sample.zero_()
|
35 |
+
|
36 |
+
|
37 |
+
def modify_logits_for_min_p_filtering(logits, min_p):
|
38 |
+
"""Set the logits for none min_p values to -inf. Done in-place."""
|
39 |
+
if min_p <= 0.0 or min_p >= 1.0:
|
40 |
+
return
|
41 |
+
indices_to_remove = logits < min_p
|
42 |
+
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
43 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
44 |
+
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
|
45 |
+
def modify_logits_for_top_k_filtering(logits, top_k):
|
46 |
+
"""Set the logits for none top-k values to -inf. Done in-place."""
|
47 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
48 |
+
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
49 |
+
|
50 |
+
|
51 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
52 |
+
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
53 |
+
def modify_logits_for_top_p_filtering(logits, top_p):
|
54 |
+
"""Set the logits for none top-p values to -inf. Done in-place."""
|
55 |
+
if top_p <= 0.0 or top_p >= 1.0:
|
56 |
+
return
|
57 |
+
# First sort and calculate cumulative sum of probabilities.
|
58 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
59 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
60 |
+
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
61 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
62 |
+
# scatter sorted tensors to original indexing
|
63 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
64 |
+
1, sorted_indices, sorted_indices_to_remove
|
65 |
+
)
|
66 |
+
logits.masked_fill_(indices_to_remove, float("-inf"))
|
67 |
+
|
68 |
+
|
69 |
+
def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
|
70 |
+
"""Apply repetition penalty. See https://arxiv.org/abs/1909.05858
|
71 |
+
logits: (batch_size, vocab_size)
|
72 |
+
prev_output_tokens: (batch_size, seq_len)
|
73 |
+
"""
|
74 |
+
if repetition_penalty == 1.0:
|
75 |
+
return logits
|
76 |
+
score = torch.gather(logits, 1, prev_output_tokens)
|
77 |
+
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
78 |
+
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
|
79 |
+
logits.scatter_(1, prev_output_tokens, score)
|
80 |
+
return logits
|
81 |
+
|
82 |
+
|
83 |
+
def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
|
84 |
+
"""Sample from top-k logits.
|
85 |
+
Arguments:
|
86 |
+
logits: Tensor of shape (batch_size, vocab_size)
|
87 |
+
"""
|
88 |
+
if top_k == 1: # Short-circuit for greedy decoding
|
89 |
+
return logits.argmax(dim=-1)
|
90 |
+
else:
|
91 |
+
if top_p > 0.0:
|
92 |
+
assert top_p <= 1.0, "top-p should be in (0, 1]."
|
93 |
+
if top_k > 0:
|
94 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
95 |
+
logits_top, indices = torch.topk(logits, top_k, dim=-1)
|
96 |
+
if temperature != 1.0:
|
97 |
+
logits_top /= temperature
|
98 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
99 |
+
return indices[
|
100 |
+
torch.arange(indices.shape[0], device=indices.device),
|
101 |
+
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
|
102 |
+
]
|
103 |
+
else:
|
104 |
+
if min_p > 0.0:
|
105 |
+
logits_top = logits.clone()
|
106 |
+
max_prob = logits_top[..., 0].item()
|
107 |
+
min_prob = max_prob * min_p
|
108 |
+
modify_logits_for_min_p_filtering(logits_top, min_prob)
|
109 |
+
if temperature != 1.0:
|
110 |
+
logits_top /= temperature
|
111 |
+
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
|
112 |
+
# Clone so that when we modify for top_p we don't change the original logits
|
113 |
+
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
|
114 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
115 |
+
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
|
116 |
+
dim=-1
|
117 |
+
)
|
118 |
+
|
119 |
+
|
120 |
+
@torch.inference_mode()
|
121 |
+
def decode(
|
122 |
+
input_ids,
|
123 |
+
model,
|
124 |
+
max_length,
|
125 |
+
top_k=1,
|
126 |
+
top_p=0.0,
|
127 |
+
min_p=0.0,
|
128 |
+
temperature=1.0,
|
129 |
+
repetition_penalty=1.0,
|
130 |
+
eos_token_id=None,
|
131 |
+
teacher_outputs=None,
|
132 |
+
vocab_size=None,
|
133 |
+
cg=False,
|
134 |
+
enable_timing=False,
|
135 |
+
streamer: Optional[TextStreamer] = None
|
136 |
+
):
|
137 |
+
"""Decoding, either greedy or with top-k or top-p sampling.
|
138 |
+
If top-k = 0, don't limit the number of candidates (pure sampling).
|
139 |
+
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
|
140 |
+
then top-p.
|
141 |
+
We assume that all sequences in the same batch have the same length.
|
142 |
+
|
143 |
+
Arguments:
|
144 |
+
input_ids: (batch, seq_len)
|
145 |
+
max_length: int
|
146 |
+
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
147 |
+
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
148 |
+
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
|
149 |
+
sequences: (batch, max_length)
|
150 |
+
scores: tuples of (batch, vocab_size)
|
151 |
+
"""
|
152 |
+
if streamer is not None:
|
153 |
+
streamer.put(input_ids.cpu())
|
154 |
+
|
155 |
+
batch_size, seqlen_og = input_ids.shape
|
156 |
+
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
|
157 |
+
if cg:
|
158 |
+
if not hasattr(model, "_decoding_cache"):
|
159 |
+
model._decoding_cache = None
|
160 |
+
model._decoding_cache = update_graph_cache(
|
161 |
+
model,
|
162 |
+
model._decoding_cache,
|
163 |
+
batch_size,
|
164 |
+
seqlen_og,
|
165 |
+
max_length,
|
166 |
+
)
|
167 |
+
inference_params = model._decoding_cache.inference_params
|
168 |
+
inference_params.reset(max_length, batch_size)
|
169 |
+
else:
|
170 |
+
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
|
171 |
+
|
172 |
+
def get_logits(input_ids, inference_params):
|
173 |
+
decoding = inference_params.seqlen_offset > 0
|
174 |
+
if decoding:
|
175 |
+
position_ids = torch.full(
|
176 |
+
(batch_size, 1),
|
177 |
+
inference_params.seqlen_offset,
|
178 |
+
dtype=torch.long,
|
179 |
+
device=input_ids.device,
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
position_ids = None
|
183 |
+
if not cg or not decoding:
|
184 |
+
logits = model(
|
185 |
+
input_ids,
|
186 |
+
position_ids=position_ids,
|
187 |
+
inference_params=inference_params,
|
188 |
+
num_last_tokens=1,
|
189 |
+
).logits.squeeze(dim=1)
|
190 |
+
else:
|
191 |
+
logits = model._decoding_cache.run(
|
192 |
+
input_ids, position_ids, inference_params.seqlen_offset
|
193 |
+
).squeeze(dim=1)
|
194 |
+
return logits[..., :vocab_size] if vocab_size is not None else logits
|
195 |
+
|
196 |
+
def sample_tokens(logits, inference_params):
|
197 |
+
if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
|
198 |
+
token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
|
199 |
+
else:
|
200 |
+
token = teacher_outputs[:, inference_params.seqlen_offset]
|
201 |
+
# return rearrange(token, "b -> b 1")
|
202 |
+
return token.unsqueeze(1)
|
203 |
+
|
204 |
+
def should_stop(current_token, inference_params):
|
205 |
+
if inference_params.seqlen_offset == 0:
|
206 |
+
return False
|
207 |
+
if eos_token_id is not None and (current_token == eos_token_id).all():
|
208 |
+
return True
|
209 |
+
if inference_params.seqlen_offset >= max_length - 1:
|
210 |
+
return True
|
211 |
+
return False
|
212 |
+
|
213 |
+
start = torch.cuda.Event(enable_timing=enable_timing)
|
214 |
+
end = torch.cuda.Event(enable_timing=enable_timing)
|
215 |
+
|
216 |
+
if enable_timing:
|
217 |
+
start.record()
|
218 |
+
scores, sequences = [], [input_ids]
|
219 |
+
sequences_cat = input_ids
|
220 |
+
while not should_stop(sequences[-1], inference_params):
|
221 |
+
scores.append(get_logits(sequences[-1], inference_params))
|
222 |
+
inference_params.seqlen_offset += sequences[-1].shape[1]
|
223 |
+
if repetition_penalty == 1.0:
|
224 |
+
sampled_tokens = sample_tokens(scores[-1], inference_params)
|
225 |
+
else:
|
226 |
+
logits = modify_logit_for_repetition_penalty(
|
227 |
+
scores[-1].clone(), sequences_cat, repetition_penalty
|
228 |
+
)
|
229 |
+
sampled_tokens = sample_tokens(logits, inference_params)
|
230 |
+
sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
|
231 |
+
sequences.append(sampled_tokens)
|
232 |
+
if streamer is not None:
|
233 |
+
streamer.put(sampled_tokens.cpu())
|
234 |
+
if streamer is not None:
|
235 |
+
streamer.end()
|
236 |
+
if enable_timing:
|
237 |
+
end.record()
|
238 |
+
torch.cuda.synchronize()
|
239 |
+
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
|
240 |
+
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
241 |
+
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
|
242 |
+
|
243 |
+
|
244 |
+
class GenerationMixin:
|
245 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
246 |
+
raise NotImplementedError
|
247 |
+
|
248 |
+
def generate(
|
249 |
+
self,
|
250 |
+
input_ids,
|
251 |
+
max_length,
|
252 |
+
top_k=1,
|
253 |
+
top_p=0.0,
|
254 |
+
min_p=0.0,
|
255 |
+
temperature=1.0,
|
256 |
+
return_dict_in_generate=False,
|
257 |
+
output_scores=False,
|
258 |
+
**kwargs,
|
259 |
+
):
|
260 |
+
output = decode(
|
261 |
+
input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
|
262 |
+
)
|
263 |
+
if not output_scores:
|
264 |
+
output.scores = None
|
265 |
+
return output if return_dict_in_generate else output.sequences
|
266 |
+
|
267 |
+
|
268 |
+
@dataclass
|
269 |
+
class DecodingCGCache:
|
270 |
+
max_batch_size: int = 0
|
271 |
+
max_seqlen: int = 0
|
272 |
+
device = None
|
273 |
+
dtype = None
|
274 |
+
callables: dict = field(default_factory=dict)
|
275 |
+
mempool = None
|
276 |
+
inference_params: Optional[InferenceParams] = None
|
277 |
+
run: Optional[Callable] = None
|
278 |
+
|
279 |
+
|
280 |
+
@torch.inference_mode()
|
281 |
+
def update_graph_cache(
|
282 |
+
model,
|
283 |
+
cache,
|
284 |
+
batch_size,
|
285 |
+
seqlen_og,
|
286 |
+
max_seqlen,
|
287 |
+
decoding_seqlens=(1,),
|
288 |
+
dtype=None,
|
289 |
+
n_warmups=2,
|
290 |
+
):
|
291 |
+
if cache is None:
|
292 |
+
cache = DecodingCGCache()
|
293 |
+
param_example = next(iter(model.parameters()))
|
294 |
+
device = param_example.device
|
295 |
+
if dtype is None:
|
296 |
+
dtype = param_example.dtype
|
297 |
+
if (
|
298 |
+
(device, dtype) != (cache.device, cache.dtype)
|
299 |
+
or batch_size > cache.max_batch_size
|
300 |
+
or max_seqlen > cache.max_seqlen
|
301 |
+
): # Invalidate the cache
|
302 |
+
cache.callables = {}
|
303 |
+
cache.mempool = None
|
304 |
+
cache.inference_params = None
|
305 |
+
gc.collect()
|
306 |
+
cache.device, cache.dtype = device, dtype
|
307 |
+
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
|
308 |
+
assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
|
309 |
+
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
|
310 |
+
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
|
311 |
+
cache.inference_params = InferenceParams(
|
312 |
+
max_seqlen=max_seqlen,
|
313 |
+
max_batch_size=batch_size,
|
314 |
+
seqlen_offset=seqlen_og,
|
315 |
+
key_value_memory_dict=inf_cache,
|
316 |
+
lengths_per_sample=lengths_per_sample,
|
317 |
+
)
|
318 |
+
cache.mempool = torch.cuda.graphs.graph_pool_handle()
|
319 |
+
for decoding_seqlen in decoding_seqlens:
|
320 |
+
if (batch_size, decoding_seqlen) not in cache.callables:
|
321 |
+
cache.callables[batch_size, decoding_seqlen] = capture_graph(
|
322 |
+
model,
|
323 |
+
cache.inference_params,
|
324 |
+
batch_size,
|
325 |
+
max_seqlen,
|
326 |
+
decoding_seqlen=decoding_seqlen,
|
327 |
+
mempool=cache.mempool,
|
328 |
+
n_warmups=n_warmups,
|
329 |
+
)
|
330 |
+
|
331 |
+
def dispatch(input_ids, position_ids, seqlen):
|
332 |
+
batch_size, decoding_seqlen = input_ids.shape[:2]
|
333 |
+
return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
|
334 |
+
|
335 |
+
cache.run = dispatch
|
336 |
+
cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
|
337 |
+
return cache
|
338 |
+
|
339 |
+
|
340 |
+
def capture_graph(
|
341 |
+
model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
|
342 |
+
):
|
343 |
+
device = next(iter(model.parameters())).device
|
344 |
+
input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
345 |
+
position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
346 |
+
seqlen_offset_og = inference_params.seqlen_offset
|
347 |
+
inference_params.seqlen_offset = max_seqlen - decoding_seqlen
|
348 |
+
inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
|
349 |
+
|
350 |
+
# Warmup before capture
|
351 |
+
s = torch.cuda.Stream()
|
352 |
+
s.wait_stream(torch.cuda.current_stream())
|
353 |
+
with torch.cuda.stream(s):
|
354 |
+
for _ in range(n_warmups):
|
355 |
+
logits = model(
|
356 |
+
input_ids,
|
357 |
+
position_ids=position_ids,
|
358 |
+
inference_params=inference_params,
|
359 |
+
num_last_tokens=decoding_seqlen,
|
360 |
+
).logits
|
361 |
+
s.synchronize()
|
362 |
+
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
|
363 |
+
# which requires that graph launch and non-captured launch to not overlap (I think,
|
364 |
+
# that's how I interpret the documentation). I'm not sure if this is required.
|
365 |
+
if torch.distributed.is_initialized():
|
366 |
+
torch.distributed.barrier()
|
367 |
+
torch.cuda.current_stream().wait_stream(s)
|
368 |
+
# Captures the graph
|
369 |
+
# To allow capture, automatically sets a side stream as the current stream in the context
|
370 |
+
graph = torch.cuda.CUDAGraph()
|
371 |
+
with torch.cuda.graph(graph, pool=mempool):
|
372 |
+
logits = model(
|
373 |
+
input_ids,
|
374 |
+
position_ids=position_ids,
|
375 |
+
inference_params=inference_params,
|
376 |
+
num_last_tokens=decoding_seqlen,
|
377 |
+
).logits
|
378 |
+
|
379 |
+
def run(new_input_ids, new_position_ids, seqlen):
|
380 |
+
inference_params.lengths_per_sample[:] = seqlen
|
381 |
+
input_ids.copy_(new_input_ids)
|
382 |
+
position_ids.copy_(new_position_ids)
|
383 |
+
graph.replay()
|
384 |
+
return logits.clone()
|
385 |
+
|
386 |
+
inference_params.seqlen_offset = seqlen_offset_og
|
387 |
+
return run
|
mamba/build/lib/mamba_ssm/utils/hf.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
|
6 |
+
from transformers.utils.hub import cached_file
|
7 |
+
|
8 |
+
|
9 |
+
def load_config_hf(model_name):
|
10 |
+
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
|
11 |
+
return json.load(open(resolved_archive_file))
|
12 |
+
|
13 |
+
|
14 |
+
def load_state_dict_hf(model_name, device=None, dtype=None):
|
15 |
+
# If not fp32, then we don't want to load directly to the GPU
|
16 |
+
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
|
17 |
+
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
|
18 |
+
return torch.load(resolved_archive_file, map_location=mapped_device)
|
19 |
+
# Convert dtype before moving to GPU to save memory
|
20 |
+
if dtype is not None:
|
21 |
+
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
|
22 |
+
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
|
23 |
+
return state_dict
|
mamba/csrc/selective_scan/reverse_scan.cuh
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#ifndef USE_ROCM
|
8 |
+
#include <cub/config.cuh>
|
9 |
+
|
10 |
+
#include <cub/util_ptx.cuh>
|
11 |
+
#include <cub/util_type.cuh>
|
12 |
+
#include <cub/block/block_raking_layout.cuh>
|
13 |
+
// #include <cub/detail/uninitialized_copy.cuh>
|
14 |
+
#else
|
15 |
+
#include <hipcub/hipcub.hpp>
|
16 |
+
namespace cub = hipcub;
|
17 |
+
#endif
|
18 |
+
#include "uninitialized_copy.cuh"
|
19 |
+
|
20 |
+
/**
|
21 |
+
* Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
|
22 |
+
*/
|
23 |
+
template <
|
24 |
+
int LENGTH,
|
25 |
+
typename T,
|
26 |
+
typename ReductionOp>
|
27 |
+
__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
|
28 |
+
static_assert(LENGTH > 0);
|
29 |
+
T retval = input[LENGTH - 1];
|
30 |
+
#pragma unroll
|
31 |
+
for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
|
32 |
+
return retval;
|
33 |
+
}
|
34 |
+
|
35 |
+
/**
|
36 |
+
* Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
|
37 |
+
*/
|
38 |
+
template <
|
39 |
+
int LENGTH,
|
40 |
+
typename T,
|
41 |
+
typename ScanOp>
|
42 |
+
__device__ __forceinline__ T ThreadReverseScanInclusive(
|
43 |
+
const T (&input)[LENGTH],
|
44 |
+
T (&output)[LENGTH],
|
45 |
+
ScanOp scan_op,
|
46 |
+
const T postfix)
|
47 |
+
{
|
48 |
+
T inclusive = postfix;
|
49 |
+
#pragma unroll
|
50 |
+
for (int i = LENGTH - 1; i >= 0; --i) {
|
51 |
+
inclusive = scan_op(inclusive, input[i]);
|
52 |
+
output[i] = inclusive;
|
53 |
+
}
|
54 |
+
return inclusive;
|
55 |
+
}
|
56 |
+
|
57 |
+
/**
|
58 |
+
* Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
|
59 |
+
*/
|
60 |
+
template <
|
61 |
+
int LENGTH,
|
62 |
+
typename T,
|
63 |
+
typename ScanOp>
|
64 |
+
__device__ __forceinline__ T ThreadReverseScanExclusive(
|
65 |
+
const T (&input)[LENGTH],
|
66 |
+
T (&output)[LENGTH],
|
67 |
+
ScanOp scan_op,
|
68 |
+
const T postfix)
|
69 |
+
{
|
70 |
+
// Careful, output maybe be aliased to input
|
71 |
+
T exclusive = postfix;
|
72 |
+
T inclusive;
|
73 |
+
#pragma unroll
|
74 |
+
for (int i = LENGTH - 1; i >= 0; --i) {
|
75 |
+
inclusive = scan_op(exclusive, input[i]);
|
76 |
+
output[i] = exclusive;
|
77 |
+
exclusive = inclusive;
|
78 |
+
}
|
79 |
+
return inclusive;
|
80 |
+
}
|
81 |
+
|
82 |
+
|
83 |
+
/**
|
84 |
+
* \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
|
85 |
+
*
|
86 |
+
* LOGICAL_WARP_THREADS must be a power-of-two
|
87 |
+
*/
|
88 |
+
template <
|
89 |
+
typename T, ///< Data type being scanned
|
90 |
+
int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
|
91 |
+
>
|
92 |
+
struct WarpReverseScan {
|
93 |
+
//---------------------------------------------------------------------
|
94 |
+
// Constants and type definitions
|
95 |
+
//---------------------------------------------------------------------
|
96 |
+
|
97 |
+
/// Whether the logical warp size and the PTX warp size coincide
|
98 |
+
|
99 |
+
// In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size()
|
100 |
+
// While in cub, it's defined as a macro that takes a redundant unused argument.
|
101 |
+
#ifndef USE_ROCM
|
102 |
+
#define WARP_THREADS CUB_WARP_THREADS(0)
|
103 |
+
#else
|
104 |
+
#define WARP_THREADS HIPCUB_WARP_THREADS
|
105 |
+
#endif
|
106 |
+
static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS);
|
107 |
+
/// The number of warp scan steps
|
108 |
+
static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
|
109 |
+
static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
|
110 |
+
|
111 |
+
|
112 |
+
//---------------------------------------------------------------------
|
113 |
+
// Thread fields
|
114 |
+
//---------------------------------------------------------------------
|
115 |
+
|
116 |
+
/// Lane index in logical warp
|
117 |
+
unsigned int lane_id;
|
118 |
+
|
119 |
+
/// Logical warp index in 32-thread physical warp
|
120 |
+
unsigned int warp_id;
|
121 |
+
|
122 |
+
/// 32-thread physical warp member mask of logical warp
|
123 |
+
unsigned int member_mask;
|
124 |
+
|
125 |
+
//---------------------------------------------------------------------
|
126 |
+
// Construction
|
127 |
+
//---------------------------------------------------------------------
|
128 |
+
|
129 |
+
/// Constructor
|
130 |
+
explicit __device__ __forceinline__
|
131 |
+
WarpReverseScan()
|
132 |
+
: lane_id(cub::LaneId())
|
133 |
+
, warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
|
134 |
+
, member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
|
135 |
+
{
|
136 |
+
if (!IS_ARCH_WARP) {
|
137 |
+
lane_id = lane_id % LOGICAL_WARP_THREADS;
|
138 |
+
}
|
139 |
+
}
|
140 |
+
|
141 |
+
|
142 |
+
/// Broadcast
|
143 |
+
__device__ __forceinline__ T Broadcast(
|
144 |
+
T input, ///< [in] The value to broadcast
|
145 |
+
int src_lane) ///< [in] Which warp lane is to do the broadcasting
|
146 |
+
{
|
147 |
+
return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
|
148 |
+
}
|
149 |
+
|
150 |
+
|
151 |
+
/// Inclusive scan
|
152 |
+
template <typename ScanOpT>
|
153 |
+
__device__ __forceinline__ void InclusiveReverseScan(
|
154 |
+
T input, ///< [in] Calling thread's input item.
|
155 |
+
T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
|
156 |
+
ScanOpT scan_op) ///< [in] Binary scan operator
|
157 |
+
{
|
158 |
+
inclusive_output = input;
|
159 |
+
#pragma unroll
|
160 |
+
for (int STEP = 0; STEP < STEPS; STEP++) {
|
161 |
+
int offset = 1 << STEP;
|
162 |
+
T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
163 |
+
inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
|
164 |
+
);
|
165 |
+
// Perform scan op if from a valid peer
|
166 |
+
inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
|
167 |
+
? inclusive_output : scan_op(temp, inclusive_output);
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
/// Exclusive scan
|
172 |
+
// Get exclusive from inclusive
|
173 |
+
template <typename ScanOpT>
|
174 |
+
__device__ __forceinline__ void ExclusiveReverseScan(
|
175 |
+
T input, ///< [in] Calling thread's input item.
|
176 |
+
T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
|
177 |
+
ScanOpT scan_op, ///< [in] Binary scan operator
|
178 |
+
T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
|
179 |
+
{
|
180 |
+
T inclusive_output;
|
181 |
+
InclusiveReverseScan(input, inclusive_output, scan_op);
|
182 |
+
warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
|
183 |
+
// initial value unknown
|
184 |
+
exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
185 |
+
inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
|
186 |
+
);
|
187 |
+
}
|
188 |
+
|
189 |
+
/**
|
190 |
+
* \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
|
191 |
+
*/
|
192 |
+
template <typename ScanOpT>
|
193 |
+
__device__ __forceinline__ void ReverseScan(
|
194 |
+
T input, ///< [in] Calling thread's input item.
|
195 |
+
T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item.
|
196 |
+
T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
|
197 |
+
ScanOpT scan_op) ///< [in] Binary scan operator
|
198 |
+
{
|
199 |
+
InclusiveReverseScan(input, inclusive_output, scan_op);
|
200 |
+
// initial value unknown
|
201 |
+
exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
202 |
+
inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
|
203 |
+
);
|
204 |
+
}
|
205 |
+
|
206 |
+
};
|
207 |
+
|
208 |
+
/**
|
209 |
+
* \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
|
210 |
+
*/
|
211 |
+
template <
|
212 |
+
typename T, ///< Data type being scanned
|
213 |
+
int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
|
214 |
+
bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
|
215 |
+
>
|
216 |
+
struct BlockReverseScan {
|
217 |
+
//---------------------------------------------------------------------
|
218 |
+
// Types and constants
|
219 |
+
//---------------------------------------------------------------------
|
220 |
+
|
221 |
+
/// Constants
|
222 |
+
/// The thread block size in threads
|
223 |
+
static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
|
224 |
+
|
225 |
+
/// Layout type for padded thread block raking grid
|
226 |
+
using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
|
227 |
+
// The number of reduction elements is not a multiple of the number of raking threads for now
|
228 |
+
static_assert(BlockRakingLayout::UNGUARDED);
|
229 |
+
|
230 |
+
/// Number of raking threads
|
231 |
+
static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
|
232 |
+
/// Number of raking elements per warp synchronous raking thread
|
233 |
+
static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
|
234 |
+
/// Cooperative work can be entirely warp synchronous
|
235 |
+
static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
|
236 |
+
|
237 |
+
/// WarpReverseScan utility type
|
238 |
+
using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
|
239 |
+
|
240 |
+
/// Shared memory storage layout type
|
241 |
+
struct _TempStorage {
|
242 |
+
typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
|
243 |
+
};
|
244 |
+
|
245 |
+
|
246 |
+
/// Alias wrapper allowing storage to be unioned
|
247 |
+
struct TempStorage : cub::Uninitialized<_TempStorage> {};
|
248 |
+
|
249 |
+
|
250 |
+
//---------------------------------------------------------------------
|
251 |
+
// Per-thread fields
|
252 |
+
//---------------------------------------------------------------------
|
253 |
+
|
254 |
+
// Thread fields
|
255 |
+
_TempStorage &temp_storage;
|
256 |
+
unsigned int linear_tid;
|
257 |
+
T cached_segment[SEGMENT_LENGTH];
|
258 |
+
|
259 |
+
|
260 |
+
//---------------------------------------------------------------------
|
261 |
+
// Utility methods
|
262 |
+
//---------------------------------------------------------------------
|
263 |
+
|
264 |
+
/// Performs upsweep raking reduction, returning the aggregate
|
265 |
+
template <typename ScanOp>
|
266 |
+
__device__ __forceinline__ T Upsweep(ScanOp scan_op) {
|
267 |
+
T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
|
268 |
+
// Read data into registers
|
269 |
+
#pragma unroll
|
270 |
+
for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
|
271 |
+
T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
|
272 |
+
#pragma unroll
|
273 |
+
for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
|
274 |
+
raking_partial = scan_op(raking_partial, cached_segment[i]);
|
275 |
+
}
|
276 |
+
return raking_partial;
|
277 |
+
}
|
278 |
+
|
279 |
+
|
280 |
+
/// Performs exclusive downsweep raking scan
|
281 |
+
template <typename ScanOp>
|
282 |
+
__device__ __forceinline__ void ExclusiveDownsweep(
|
283 |
+
ScanOp scan_op,
|
284 |
+
T raking_partial)
|
285 |
+
{
|
286 |
+
T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
|
287 |
+
// Read data back into registers
|
288 |
+
if (!MEMOIZE) {
|
289 |
+
#pragma unroll
|
290 |
+
for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
|
291 |
+
}
|
292 |
+
ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
|
293 |
+
// Write data back to smem
|
294 |
+
#pragma unroll
|
295 |
+
for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
|
296 |
+
}
|
297 |
+
|
298 |
+
|
299 |
+
//---------------------------------------------------------------------
|
300 |
+
// Constructors
|
301 |
+
//---------------------------------------------------------------------
|
302 |
+
|
303 |
+
/// Constructor
|
304 |
+
__device__ __forceinline__ BlockReverseScan(
|
305 |
+
TempStorage &temp_storage)
|
306 |
+
:
|
307 |
+
temp_storage(temp_storage.Alias()),
|
308 |
+
linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
|
309 |
+
{}
|
310 |
+
|
311 |
+
|
312 |
+
/// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
|
313 |
+
template <
|
314 |
+
typename ScanOp,
|
315 |
+
typename BlockPostfixCallbackOp>
|
316 |
+
__device__ __forceinline__ void ExclusiveReverseScan(
|
317 |
+
T input, ///< [in] Calling thread's input item
|
318 |
+
T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
|
319 |
+
ScanOp scan_op, ///< [in] Binary scan operator
|
320 |
+
BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
|
321 |
+
{
|
322 |
+
if (WARP_SYNCHRONOUS) {
|
323 |
+
// Short-circuit directly to warp-synchronous scan
|
324 |
+
T block_aggregate;
|
325 |
+
WarpReverseScan warp_scan;
|
326 |
+
warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
|
327 |
+
// Obtain warp-wide postfix in lane0, then broadcast to other lanes
|
328 |
+
T block_postfix = block_postfix_callback_op(block_aggregate);
|
329 |
+
block_postfix = warp_scan.Broadcast(block_postfix, 0);
|
330 |
+
exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
|
331 |
+
} else {
|
332 |
+
// Place thread partial into shared memory raking grid
|
333 |
+
T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
|
334 |
+
detail::uninitialized_copy(placement_ptr, input);
|
335 |
+
cub::CTA_SYNC();
|
336 |
+
// Reduce parallelism down to just raking threads
|
337 |
+
if (linear_tid < RAKING_THREADS) {
|
338 |
+
WarpReverseScan warp_scan;
|
339 |
+
// Raking upsweep reduction across shared partials
|
340 |
+
T upsweep_partial = Upsweep(scan_op);
|
341 |
+
// Warp-synchronous scan
|
342 |
+
T exclusive_partial, block_aggregate;
|
343 |
+
warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
|
344 |
+
// Obtain block-wide postfix in lane0, then broadcast to other lanes
|
345 |
+
T block_postfix = block_postfix_callback_op(block_aggregate);
|
346 |
+
block_postfix = warp_scan.Broadcast(block_postfix, 0);
|
347 |
+
// Update postfix with warpscan exclusive partial
|
348 |
+
T downsweep_postfix = linear_tid == RAKING_THREADS - 1
|
349 |
+
? block_postfix : scan_op(block_postfix, exclusive_partial);
|
350 |
+
// Exclusive raking downsweep scan
|
351 |
+
ExclusiveDownsweep(scan_op, downsweep_postfix);
|
352 |
+
}
|
353 |
+
cub::CTA_SYNC();
|
354 |
+
// Grab thread postfix from shared memory
|
355 |
+
exclusive_output = *placement_ptr;
|
356 |
+
|
357 |
+
// // Compute warp scan in each warp.
|
358 |
+
// // The exclusive output from the last lane in each warp is invalid.
|
359 |
+
// T inclusive_output;
|
360 |
+
// WarpReverseScan warp_scan;
|
361 |
+
// warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
|
362 |
+
|
363 |
+
// // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
|
364 |
+
// T block_aggregate;
|
365 |
+
// T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
|
366 |
+
|
367 |
+
// // Apply warp postfix to our lane's partial
|
368 |
+
// if (warp_id != 0) {
|
369 |
+
// exclusive_output = scan_op(warp_postfix, exclusive_output);
|
370 |
+
// if (lane_id == 0) { exclusive_output = warp_postfix; }
|
371 |
+
// }
|
372 |
+
|
373 |
+
// // Use the first warp to determine the thread block postfix, returning the result in lane0
|
374 |
+
// if (warp_id == 0) {
|
375 |
+
// T block_postfix = block_postfix_callback_op(block_aggregate);
|
376 |
+
// if (lane_id == 0) {
|
377 |
+
// // Share the postfix with all threads
|
378 |
+
// detail::uninitialized_copy(&temp_storage.block_postfix,
|
379 |
+
// block_postfix);
|
380 |
+
|
381 |
+
// exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
|
382 |
+
// }
|
383 |
+
// }
|
384 |
+
|
385 |
+
// cub::CTA_SYNC();
|
386 |
+
|
387 |
+
// // Incorporate thread block postfix into outputs
|
388 |
+
// T block_postfix = temp_storage.block_postfix;
|
389 |
+
// if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
|
390 |
+
}
|
391 |
+
}
|
392 |
+
|
393 |
+
|
394 |
+
/**
|
395 |
+
* \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
|
396 |
+
*/
|
397 |
+
template <
|
398 |
+
int ITEMS_PER_THREAD,
|
399 |
+
typename ScanOp,
|
400 |
+
typename BlockPostfixCallbackOp>
|
401 |
+
__device__ __forceinline__ void InclusiveReverseScan(
|
402 |
+
T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
|
403 |
+
T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
|
404 |
+
ScanOp scan_op, ///< [in] Binary scan functor
|
405 |
+
BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
|
406 |
+
{
|
407 |
+
// Reduce consecutive thread items in registers
|
408 |
+
T thread_postfix = ThreadReverseReduce(input, scan_op);
|
409 |
+
// Exclusive thread block-scan
|
410 |
+
ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
|
411 |
+
// Inclusive scan in registers with postfix as seed
|
412 |
+
ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
|
413 |
+
}
|
414 |
+
|
415 |
+
};
|
mamba/csrc/selective_scan/selective_scan.cpp
ADDED
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#include <ATen/cuda/CUDAContext.h>
|
6 |
+
#include <c10/cuda/CUDAGuard.h>
|
7 |
+
#include <torch/extension.h>
|
8 |
+
#include <vector>
|
9 |
+
|
10 |
+
#include "selective_scan.h"
|
11 |
+
|
12 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
13 |
+
|
14 |
+
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
15 |
+
if (ITYPE == at::ScalarType::Half) { \
|
16 |
+
using input_t = at::Half; \
|
17 |
+
__VA_ARGS__(); \
|
18 |
+
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
19 |
+
using input_t = at::BFloat16; \
|
20 |
+
__VA_ARGS__(); \
|
21 |
+
} else if (ITYPE == at::ScalarType::Float) { \
|
22 |
+
using input_t = float; \
|
23 |
+
__VA_ARGS__(); \
|
24 |
+
} else { \
|
25 |
+
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
26 |
+
}
|
27 |
+
|
28 |
+
#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
|
29 |
+
if (WTYPE == at::ScalarType::Half) { \
|
30 |
+
using weight_t = at::Half; \
|
31 |
+
__VA_ARGS__(); \
|
32 |
+
} else if (WTYPE == at::ScalarType::BFloat16) { \
|
33 |
+
using weight_t = at::BFloat16; \
|
34 |
+
__VA_ARGS__(); \
|
35 |
+
} else if (WTYPE == at::ScalarType::Float) { \
|
36 |
+
using weight_t = float; \
|
37 |
+
__VA_ARGS__(); \
|
38 |
+
} else { \
|
39 |
+
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
40 |
+
}
|
41 |
+
|
42 |
+
#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \
|
43 |
+
if (WTYPE == at::ScalarType::Float) { \
|
44 |
+
using weight_t = float; \
|
45 |
+
__VA_ARGS__(); \
|
46 |
+
} else if (WTYPE == at::ScalarType::ComplexFloat) { \
|
47 |
+
using weight_t = c10::complex<float>; \
|
48 |
+
__VA_ARGS__(); \
|
49 |
+
} else { \
|
50 |
+
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
51 |
+
}
|
52 |
+
|
53 |
+
template<typename input_t, typename weight_t>
|
54 |
+
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
55 |
+
|
56 |
+
template <typename input_t, typename weight_t>
|
57 |
+
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream);
|
58 |
+
|
59 |
+
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
60 |
+
// sizes
|
61 |
+
const size_t batch,
|
62 |
+
const size_t dim,
|
63 |
+
const size_t seqlen,
|
64 |
+
const size_t dstate,
|
65 |
+
const size_t n_groups,
|
66 |
+
const size_t n_chunks,
|
67 |
+
const bool is_variable_B,
|
68 |
+
const bool is_variable_C,
|
69 |
+
// device pointers
|
70 |
+
const at::Tensor u,
|
71 |
+
const at::Tensor delta,
|
72 |
+
const at::Tensor A,
|
73 |
+
const at::Tensor B,
|
74 |
+
const at::Tensor C,
|
75 |
+
const at::Tensor out,
|
76 |
+
const at::Tensor z,
|
77 |
+
const at::Tensor out_z,
|
78 |
+
void* D_ptr,
|
79 |
+
void* delta_bias_ptr,
|
80 |
+
void* x_ptr,
|
81 |
+
bool has_z,
|
82 |
+
bool delta_softplus) {
|
83 |
+
|
84 |
+
// Reset the parameters
|
85 |
+
memset(¶ms, 0, sizeof(params));
|
86 |
+
|
87 |
+
params.batch = batch;
|
88 |
+
params.dim = dim;
|
89 |
+
params.seqlen = seqlen;
|
90 |
+
params.dstate = dstate;
|
91 |
+
params.n_groups = n_groups;
|
92 |
+
params.n_chunks = n_chunks;
|
93 |
+
params.dim_ngroups_ratio = dim / n_groups;
|
94 |
+
|
95 |
+
params.delta_softplus = delta_softplus;
|
96 |
+
|
97 |
+
params.is_variable_B = is_variable_B;
|
98 |
+
params.is_variable_C = is_variable_C;
|
99 |
+
|
100 |
+
// Set the pointers and strides.
|
101 |
+
params.u_ptr = u.data_ptr();
|
102 |
+
params.delta_ptr = delta.data_ptr();
|
103 |
+
params.A_ptr = A.data_ptr();
|
104 |
+
params.B_ptr = B.data_ptr();
|
105 |
+
params.C_ptr = C.data_ptr();
|
106 |
+
params.D_ptr = D_ptr;
|
107 |
+
params.delta_bias_ptr = delta_bias_ptr;
|
108 |
+
params.out_ptr = out.data_ptr();
|
109 |
+
params.x_ptr = x_ptr;
|
110 |
+
params.z_ptr = has_z ? z.data_ptr() : nullptr;
|
111 |
+
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
|
112 |
+
// All stride are in elements, not bytes.
|
113 |
+
params.A_d_stride = A.stride(0);
|
114 |
+
params.A_dstate_stride = A.stride(1);
|
115 |
+
if (!is_variable_B) {
|
116 |
+
params.B_d_stride = B.stride(0);
|
117 |
+
} else {
|
118 |
+
params.B_batch_stride = B.stride(0);
|
119 |
+
params.B_group_stride = B.stride(1);
|
120 |
+
}
|
121 |
+
params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
|
122 |
+
if (!is_variable_C) {
|
123 |
+
params.C_d_stride = C.stride(0);
|
124 |
+
} else {
|
125 |
+
params.C_batch_stride = C.stride(0);
|
126 |
+
params.C_group_stride = C.stride(1);
|
127 |
+
}
|
128 |
+
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
|
129 |
+
params.u_batch_stride = u.stride(0);
|
130 |
+
params.u_d_stride = u.stride(1);
|
131 |
+
params.delta_batch_stride = delta.stride(0);
|
132 |
+
params.delta_d_stride = delta.stride(1);
|
133 |
+
if (has_z) {
|
134 |
+
params.z_batch_stride = z.stride(0);
|
135 |
+
params.z_d_stride = z.stride(1);
|
136 |
+
params.out_z_batch_stride = out_z.stride(0);
|
137 |
+
params.out_z_d_stride = out_z.stride(1);
|
138 |
+
}
|
139 |
+
params.out_batch_stride = out.stride(0);
|
140 |
+
params.out_d_stride = out.stride(1);
|
141 |
+
}
|
142 |
+
|
143 |
+
void set_ssm_params_bwd(SSMParamsBwd ¶ms,
|
144 |
+
// sizes
|
145 |
+
const size_t batch,
|
146 |
+
const size_t dim,
|
147 |
+
const size_t seqlen,
|
148 |
+
const size_t dstate,
|
149 |
+
const size_t n_groups,
|
150 |
+
const size_t n_chunks,
|
151 |
+
const bool is_variable_B,
|
152 |
+
const bool is_variable_C,
|
153 |
+
// device pointers
|
154 |
+
const at::Tensor u,
|
155 |
+
const at::Tensor delta,
|
156 |
+
const at::Tensor A,
|
157 |
+
const at::Tensor B,
|
158 |
+
const at::Tensor C,
|
159 |
+
const at::Tensor z,
|
160 |
+
const at::Tensor out,
|
161 |
+
const at::Tensor out_z,
|
162 |
+
void* D_ptr,
|
163 |
+
void* delta_bias_ptr,
|
164 |
+
void* x_ptr,
|
165 |
+
const at::Tensor dout,
|
166 |
+
const at::Tensor du,
|
167 |
+
const at::Tensor ddelta,
|
168 |
+
const at::Tensor dA,
|
169 |
+
const at::Tensor dB,
|
170 |
+
const at::Tensor dC,
|
171 |
+
const at::Tensor dz,
|
172 |
+
void* dD_ptr,
|
173 |
+
void* ddelta_bias_ptr,
|
174 |
+
bool has_z,
|
175 |
+
bool delta_softplus,
|
176 |
+
bool recompute_out_z) {
|
177 |
+
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
|
178 |
+
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
179 |
+
u, delta, A, B, C, has_z ? out : dout,
|
180 |
+
has_z ? z : dout,
|
181 |
+
// If not recompute_out_z, pass dout instead of out_z.
|
182 |
+
// This won't be used by the bwd kernel
|
183 |
+
recompute_out_z ? out_z : dout,
|
184 |
+
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
|
185 |
+
if (!recompute_out_z) { params.out_z_ptr = nullptr; }
|
186 |
+
|
187 |
+
// Set the pointers and strides.
|
188 |
+
params.dout_ptr = dout.data_ptr();
|
189 |
+
params.du_ptr = du.data_ptr();
|
190 |
+
params.dA_ptr = dA.data_ptr();
|
191 |
+
params.dB_ptr = dB.data_ptr();
|
192 |
+
params.dC_ptr = dC.data_ptr();
|
193 |
+
params.dD_ptr = dD_ptr;
|
194 |
+
params.ddelta_ptr = ddelta.data_ptr();
|
195 |
+
params.ddelta_bias_ptr = ddelta_bias_ptr;
|
196 |
+
params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
|
197 |
+
// All stride are in elements, not bytes.
|
198 |
+
params.dout_batch_stride = dout.stride(0);
|
199 |
+
params.dout_d_stride = dout.stride(1);
|
200 |
+
params.dA_d_stride = dA.stride(0);
|
201 |
+
params.dA_dstate_stride = dA.stride(1);
|
202 |
+
if (!is_variable_B) {
|
203 |
+
params.dB_d_stride = dB.stride(0);
|
204 |
+
} else {
|
205 |
+
params.dB_batch_stride = dB.stride(0);
|
206 |
+
params.dB_group_stride = dB.stride(1);
|
207 |
+
}
|
208 |
+
params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
|
209 |
+
if (!is_variable_C) {
|
210 |
+
params.dC_d_stride = dC.stride(0);
|
211 |
+
} else {
|
212 |
+
params.dC_batch_stride = dC.stride(0);
|
213 |
+
params.dC_group_stride = dC.stride(1);
|
214 |
+
}
|
215 |
+
params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
|
216 |
+
params.du_batch_stride = du.stride(0);
|
217 |
+
params.du_d_stride = du.stride(1);
|
218 |
+
params.ddelta_batch_stride = ddelta.stride(0);
|
219 |
+
params.ddelta_d_stride = ddelta.stride(1);
|
220 |
+
if (has_z) {
|
221 |
+
params.dz_batch_stride = dz.stride(0);
|
222 |
+
params.dz_d_stride = dz.stride(1);
|
223 |
+
}
|
224 |
+
}
|
225 |
+
|
226 |
+
std::vector<at::Tensor>
|
227 |
+
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
|
228 |
+
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
229 |
+
const c10::optional<at::Tensor> &D_,
|
230 |
+
const c10::optional<at::Tensor> &z_,
|
231 |
+
const c10::optional<at::Tensor> &delta_bias_,
|
232 |
+
bool delta_softplus) {
|
233 |
+
auto input_type = u.scalar_type();
|
234 |
+
auto weight_type = A.scalar_type();
|
235 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
236 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
|
237 |
+
|
238 |
+
const bool is_variable_B = B.dim() >= 3;
|
239 |
+
const bool is_variable_C = C.dim() >= 3;
|
240 |
+
const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
|
241 |
+
|
242 |
+
TORCH_CHECK(delta.scalar_type() == input_type);
|
243 |
+
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
244 |
+
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
245 |
+
|
246 |
+
TORCH_CHECK(u.is_cuda());
|
247 |
+
TORCH_CHECK(delta.is_cuda());
|
248 |
+
TORCH_CHECK(A.is_cuda());
|
249 |
+
TORCH_CHECK(B.is_cuda());
|
250 |
+
TORCH_CHECK(C.is_cuda());
|
251 |
+
|
252 |
+
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
253 |
+
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
254 |
+
|
255 |
+
const auto sizes = u.sizes();
|
256 |
+
const int batch_size = sizes[0];
|
257 |
+
const int dim = sizes[1];
|
258 |
+
const int seqlen = sizes[2];
|
259 |
+
const int dstate = A.size(1);
|
260 |
+
const int n_groups = is_variable_B ? B.size(1) : 1;
|
261 |
+
|
262 |
+
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
263 |
+
|
264 |
+
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
265 |
+
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
266 |
+
CHECK_SHAPE(A, dim, dstate);
|
267 |
+
if (!is_variable_B) {
|
268 |
+
CHECK_SHAPE(B, dim, dstate);
|
269 |
+
} else {
|
270 |
+
CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
|
271 |
+
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
272 |
+
}
|
273 |
+
if (!is_variable_C) {
|
274 |
+
CHECK_SHAPE(C, dim, dstate);
|
275 |
+
} else {
|
276 |
+
CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
|
277 |
+
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
278 |
+
}
|
279 |
+
|
280 |
+
if (D_.has_value()) {
|
281 |
+
auto D = D_.value();
|
282 |
+
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
283 |
+
TORCH_CHECK(D.is_cuda());
|
284 |
+
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
285 |
+
CHECK_SHAPE(D, dim);
|
286 |
+
}
|
287 |
+
|
288 |
+
if (delta_bias_.has_value()) {
|
289 |
+
auto delta_bias = delta_bias_.value();
|
290 |
+
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
291 |
+
TORCH_CHECK(delta_bias.is_cuda());
|
292 |
+
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
293 |
+
CHECK_SHAPE(delta_bias, dim);
|
294 |
+
}
|
295 |
+
|
296 |
+
at::Tensor z, out_z;
|
297 |
+
const bool has_z = z_.has_value();
|
298 |
+
if (has_z) {
|
299 |
+
z = z_.value();
|
300 |
+
TORCH_CHECK(z.scalar_type() == input_type);
|
301 |
+
TORCH_CHECK(z.is_cuda());
|
302 |
+
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
303 |
+
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
304 |
+
out_z = torch::empty_like(z);
|
305 |
+
}
|
306 |
+
|
307 |
+
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
308 |
+
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
309 |
+
// at::Tensor out = torch::empty_like(u);
|
310 |
+
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
311 |
+
at::Tensor out = torch::empty_like(delta);
|
312 |
+
at::Tensor x;
|
313 |
+
x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
|
314 |
+
|
315 |
+
SSMParamsBase params;
|
316 |
+
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
317 |
+
u, delta, A, B, C, out, z, out_z,
|
318 |
+
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
319 |
+
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
320 |
+
x.data_ptr(),
|
321 |
+
has_z,
|
322 |
+
delta_softplus);
|
323 |
+
|
324 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
325 |
+
// Cast to char to avoid compiler warning about narrowing
|
326 |
+
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
327 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
328 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
329 |
+
DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
|
330 |
+
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
331 |
+
});
|
332 |
+
});
|
333 |
+
std::vector<at::Tensor> result = {out, x};
|
334 |
+
if (has_z) { result.push_back(out_z); }
|
335 |
+
return result;
|
336 |
+
}
|
337 |
+
|
338 |
+
std::vector<at::Tensor>
|
339 |
+
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
|
340 |
+
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
341 |
+
const c10::optional<at::Tensor> &D_,
|
342 |
+
const c10::optional<at::Tensor> &z_,
|
343 |
+
const c10::optional<at::Tensor> &delta_bias_,
|
344 |
+
const at::Tensor &dout,
|
345 |
+
const c10::optional<at::Tensor> &x_,
|
346 |
+
const c10::optional<at::Tensor> &out_,
|
347 |
+
c10::optional<at::Tensor> &dz_,
|
348 |
+
bool delta_softplus,
|
349 |
+
bool recompute_out_z) {
|
350 |
+
auto input_type = u.scalar_type();
|
351 |
+
auto weight_type = A.scalar_type();
|
352 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
353 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
|
354 |
+
|
355 |
+
const bool is_variable_B = B.dim() >= 3;
|
356 |
+
const bool is_variable_C = C.dim() >= 3;
|
357 |
+
const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
|
358 |
+
|
359 |
+
TORCH_CHECK(delta.scalar_type() == input_type);
|
360 |
+
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
361 |
+
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
362 |
+
TORCH_CHECK(dout.scalar_type() == input_type);
|
363 |
+
|
364 |
+
TORCH_CHECK(u.is_cuda());
|
365 |
+
TORCH_CHECK(delta.is_cuda());
|
366 |
+
TORCH_CHECK(A.is_cuda());
|
367 |
+
TORCH_CHECK(B.is_cuda());
|
368 |
+
TORCH_CHECK(C.is_cuda());
|
369 |
+
TORCH_CHECK(dout.is_cuda());
|
370 |
+
|
371 |
+
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
372 |
+
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
373 |
+
TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
|
374 |
+
|
375 |
+
const auto sizes = u.sizes();
|
376 |
+
const int batch_size = sizes[0];
|
377 |
+
const int dim = sizes[1];
|
378 |
+
const int seqlen = sizes[2];
|
379 |
+
const int dstate = A.size(1);
|
380 |
+
const int n_groups = is_variable_B ? B.size(1) : 1;
|
381 |
+
|
382 |
+
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
383 |
+
|
384 |
+
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
385 |
+
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
386 |
+
CHECK_SHAPE(A, dim, dstate);
|
387 |
+
if (!is_variable_B) {
|
388 |
+
CHECK_SHAPE(B, dim, dstate);
|
389 |
+
} else {
|
390 |
+
CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
|
391 |
+
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
392 |
+
}
|
393 |
+
if (!is_variable_C) {
|
394 |
+
CHECK_SHAPE(C, dim, dstate);
|
395 |
+
} else {
|
396 |
+
CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
|
397 |
+
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
398 |
+
}
|
399 |
+
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
400 |
+
|
401 |
+
if (D_.has_value()) {
|
402 |
+
auto D = D_.value();
|
403 |
+
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
404 |
+
TORCH_CHECK(D.is_cuda());
|
405 |
+
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
406 |
+
CHECK_SHAPE(D, dim);
|
407 |
+
}
|
408 |
+
|
409 |
+
if (delta_bias_.has_value()) {
|
410 |
+
auto delta_bias = delta_bias_.value();
|
411 |
+
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
412 |
+
TORCH_CHECK(delta_bias.is_cuda());
|
413 |
+
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
414 |
+
CHECK_SHAPE(delta_bias, dim);
|
415 |
+
}
|
416 |
+
|
417 |
+
at::Tensor z, out, dz, out_z;
|
418 |
+
const bool has_z = z_.has_value();
|
419 |
+
if (has_z) {
|
420 |
+
z = z_.value();
|
421 |
+
TORCH_CHECK(z.scalar_type() == input_type);
|
422 |
+
TORCH_CHECK(z.is_cuda());
|
423 |
+
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
424 |
+
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
425 |
+
|
426 |
+
TORCH_CHECK(out_.has_value());
|
427 |
+
out = out_.value();
|
428 |
+
TORCH_CHECK(out.scalar_type() == input_type);
|
429 |
+
TORCH_CHECK(out.is_cuda());
|
430 |
+
TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1);
|
431 |
+
CHECK_SHAPE(out, batch_size, dim, seqlen);
|
432 |
+
|
433 |
+
if (dz_.has_value()) {
|
434 |
+
dz = dz_.value();
|
435 |
+
TORCH_CHECK(dz.scalar_type() == input_type);
|
436 |
+
TORCH_CHECK(dz.is_cuda());
|
437 |
+
TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1);
|
438 |
+
CHECK_SHAPE(dz, batch_size, dim, seqlen);
|
439 |
+
} else {
|
440 |
+
dz = torch::empty_like(z);
|
441 |
+
}
|
442 |
+
if (recompute_out_z) {
|
443 |
+
out_z = torch::empty_like(out);
|
444 |
+
}
|
445 |
+
}
|
446 |
+
|
447 |
+
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
448 |
+
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
449 |
+
if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
|
450 |
+
if (x_.has_value()) {
|
451 |
+
auto x = x_.value();
|
452 |
+
TORCH_CHECK(x.scalar_type() == weight_type);
|
453 |
+
TORCH_CHECK(x.is_cuda());
|
454 |
+
TORCH_CHECK(x.is_contiguous());
|
455 |
+
CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
|
456 |
+
}
|
457 |
+
|
458 |
+
at::Tensor du = torch::empty_like(u);
|
459 |
+
at::Tensor ddelta = torch::empty_like(delta);
|
460 |
+
at::Tensor dA = torch::zeros_like(A);
|
461 |
+
at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
|
462 |
+
at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
|
463 |
+
at::Tensor dD;
|
464 |
+
if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
|
465 |
+
at::Tensor ddelta_bias;
|
466 |
+
if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
|
467 |
+
|
468 |
+
SSMParamsBwd params;
|
469 |
+
set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
470 |
+
u, delta, A, B, C, z, out, out_z,
|
471 |
+
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
472 |
+
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
473 |
+
x_.has_value() ? x_.value().data_ptr() : nullptr,
|
474 |
+
dout, du, ddelta, dA, dB, dC, dz,
|
475 |
+
D_.has_value() ? dD.data_ptr() : nullptr,
|
476 |
+
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
|
477 |
+
has_z, delta_softplus, recompute_out_z);
|
478 |
+
|
479 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
480 |
+
// Cast to char to avoid compiler warning about narrowing
|
481 |
+
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
482 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
483 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
|
484 |
+
DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
|
485 |
+
selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
|
486 |
+
});
|
487 |
+
});
|
488 |
+
std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
|
489 |
+
if (has_z) { result.push_back(dz); }
|
490 |
+
if (recompute_out_z) { result.push_back(out_z); }
|
491 |
+
return result;
|
492 |
+
}
|
493 |
+
|
494 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
495 |
+
m.def("fwd", &selective_scan_fwd, "Selective scan forward");
|
496 |
+
m.def("bwd", &selective_scan_bwd, "Selective scan backward");
|
497 |
+
}
|
mamba/csrc/selective_scan/selective_scan.h
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
8 |
+
|
9 |
+
struct SSMScanParamsBase {
|
10 |
+
using index_t = uint32_t;
|
11 |
+
|
12 |
+
int batch, seqlen, n_chunks;
|
13 |
+
index_t a_batch_stride;
|
14 |
+
index_t b_batch_stride;
|
15 |
+
index_t out_batch_stride;
|
16 |
+
|
17 |
+
// Common data pointers.
|
18 |
+
void *__restrict__ a_ptr;
|
19 |
+
void *__restrict__ b_ptr;
|
20 |
+
void *__restrict__ out_ptr;
|
21 |
+
void *__restrict__ x_ptr;
|
22 |
+
};
|
23 |
+
|
24 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
25 |
+
|
26 |
+
struct SSMParamsBase {
|
27 |
+
using index_t = uint32_t;
|
28 |
+
|
29 |
+
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
30 |
+
int dim_ngroups_ratio;
|
31 |
+
bool is_variable_B;
|
32 |
+
bool is_variable_C;
|
33 |
+
|
34 |
+
bool delta_softplus;
|
35 |
+
|
36 |
+
index_t A_d_stride;
|
37 |
+
index_t A_dstate_stride;
|
38 |
+
index_t B_batch_stride;
|
39 |
+
index_t B_d_stride;
|
40 |
+
index_t B_dstate_stride;
|
41 |
+
index_t B_group_stride;
|
42 |
+
index_t C_batch_stride;
|
43 |
+
index_t C_d_stride;
|
44 |
+
index_t C_dstate_stride;
|
45 |
+
index_t C_group_stride;
|
46 |
+
index_t u_batch_stride;
|
47 |
+
index_t u_d_stride;
|
48 |
+
index_t delta_batch_stride;
|
49 |
+
index_t delta_d_stride;
|
50 |
+
index_t z_batch_stride;
|
51 |
+
index_t z_d_stride;
|
52 |
+
index_t out_batch_stride;
|
53 |
+
index_t out_d_stride;
|
54 |
+
index_t out_z_batch_stride;
|
55 |
+
index_t out_z_d_stride;
|
56 |
+
|
57 |
+
// Common data pointers.
|
58 |
+
void *__restrict__ A_ptr;
|
59 |
+
void *__restrict__ B_ptr;
|
60 |
+
void *__restrict__ C_ptr;
|
61 |
+
void *__restrict__ D_ptr;
|
62 |
+
void *__restrict__ u_ptr;
|
63 |
+
void *__restrict__ delta_ptr;
|
64 |
+
void *__restrict__ delta_bias_ptr;
|
65 |
+
void *__restrict__ out_ptr;
|
66 |
+
void *__restrict__ x_ptr;
|
67 |
+
void *__restrict__ z_ptr;
|
68 |
+
void *__restrict__ out_z_ptr;
|
69 |
+
};
|
70 |
+
|
71 |
+
struct SSMParamsBwd: public SSMParamsBase {
|
72 |
+
index_t dout_batch_stride;
|
73 |
+
index_t dout_d_stride;
|
74 |
+
index_t dA_d_stride;
|
75 |
+
index_t dA_dstate_stride;
|
76 |
+
index_t dB_batch_stride;
|
77 |
+
index_t dB_group_stride;
|
78 |
+
index_t dB_d_stride;
|
79 |
+
index_t dB_dstate_stride;
|
80 |
+
index_t dC_batch_stride;
|
81 |
+
index_t dC_group_stride;
|
82 |
+
index_t dC_d_stride;
|
83 |
+
index_t dC_dstate_stride;
|
84 |
+
index_t du_batch_stride;
|
85 |
+
index_t du_d_stride;
|
86 |
+
index_t dz_batch_stride;
|
87 |
+
index_t dz_d_stride;
|
88 |
+
index_t ddelta_batch_stride;
|
89 |
+
index_t ddelta_d_stride;
|
90 |
+
|
91 |
+
// Common data pointers.
|
92 |
+
void *__restrict__ dout_ptr;
|
93 |
+
void *__restrict__ dA_ptr;
|
94 |
+
void *__restrict__ dB_ptr;
|
95 |
+
void *__restrict__ dC_ptr;
|
96 |
+
void *__restrict__ dD_ptr;
|
97 |
+
void *__restrict__ du_ptr;
|
98 |
+
void *__restrict__ dz_ptr;
|
99 |
+
void *__restrict__ ddelta_ptr;
|
100 |
+
void *__restrict__ ddelta_bias_ptr;
|
101 |
+
};
|
mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|