Somunia committed
Commit
306b4ac
1 Parent(s): 8b19012

Upload 116 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. mamba/.github/workflows/publish.yaml +209 -0
  3. mamba/.gitignore +6 -0
  4. mamba/.gitmodules +3 -0
  5. mamba/AUTHORS +2 -0
  6. mamba/LICENSE +201 -0
  7. mamba/README.md +243 -0
  8. mamba/assets/selection.png +0 -0
  9. mamba/assets/ssd_algorithm.png +3 -0
  10. mamba/benchmarks/benchmark_generation_mamba_simple.py +92 -0
  11. mamba/build/lib/mamba_ssm/__init__.py +6 -0
  12. mamba/build/lib/mamba_ssm/distributed/__init__.py +0 -0
  13. mamba/build/lib/mamba_ssm/distributed/distributed_utils.py +144 -0
  14. mamba/build/lib/mamba_ssm/distributed/tensor_parallel.py +296 -0
  15. mamba/build/lib/mamba_ssm/models/__init__.py +0 -0
  16. mamba/build/lib/mamba_ssm/models/config_mamba.py +18 -0
  17. mamba/build/lib/mamba_ssm/models/mixer_seq_simple.py +309 -0
  18. mamba/build/lib/mamba_ssm/modules/__init__.py +0 -0
  19. mamba/build/lib/mamba_ssm/modules/block.py +91 -0
  20. mamba/build/lib/mamba_ssm/modules/mamba2.py +383 -0
  21. mamba/build/lib/mamba_ssm/modules/mamba2_simple.py +200 -0
  22. mamba/build/lib/mamba_ssm/modules/mamba_simple.py +294 -0
  23. mamba/build/lib/mamba_ssm/modules/mha.py +294 -0
  24. mamba/build/lib/mamba_ssm/modules/mlp.py +34 -0
  25. mamba/build/lib/mamba_ssm/modules/ssd_minimal.py +103 -0
  26. mamba/build/lib/mamba_ssm/ops/__init__.py +0 -0
  27. mamba/build/lib/mamba_ssm/ops/selective_scan_interface.py +357 -0
  28. mamba/build/lib/mamba_ssm/ops/triton/__init__.py +0 -0
  29. mamba/build/lib/mamba_ssm/ops/triton/k_activations.py +169 -0
  30. mamba/build/lib/mamba_ssm/ops/triton/layer_norm.py +1113 -0
  31. mamba/build/lib/mamba_ssm/ops/triton/layernorm_gated.py +437 -0
  32. mamba/build/lib/mamba_ssm/ops/triton/selective_state_update.py +265 -0
  33. mamba/build/lib/mamba_ssm/ops/triton/softplus.py +17 -0
  34. mamba/build/lib/mamba_ssm/ops/triton/ssd_bmm.py +262 -0
  35. mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
  36. mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_state.py +988 -0
  37. mamba/build/lib/mamba_ssm/ops/triton/ssd_combined.py +981 -0
  38. mamba/build/lib/mamba_ssm/ops/triton/ssd_state_passing.py +348 -0
  39. mamba/build/lib/mamba_ssm/utils/__init__.py +0 -0
  40. mamba/build/lib/mamba_ssm/utils/generation.py +387 -0
  41. mamba/build/lib/mamba_ssm/utils/hf.py +23 -0
  42. mamba/csrc/selective_scan/reverse_scan.cuh +415 -0
  43. mamba/csrc/selective_scan/selective_scan.cpp +497 -0
  44. mamba/csrc/selective_scan/selective_scan.h +101 -0
  45. mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu +9 -0
  46. mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu +9 -0
  47. mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu +9 -0
  48. mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu +9 -0
  49. mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu +9 -0
  50. mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu +9 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ mamba/assets/ssd_algorithm.png filter=lfs diff=lfs merge=lfs -text
mamba/.github/workflows/publish.yaml ADDED
@@ -0,0 +1,209 @@
1
+ # This workflow will:
2
+ # - Create a new Github release
3
+ # - Build wheels for supported architectures
4
+ # - Deploy the wheels to the Github release
5
+ # - Release the static code to PyPi
6
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
7
+
8
+ name: Build wheels and deploy
9
+
10
+ on:
11
+ create:
12
+ tags:
13
+ - v*
14
+
15
+ jobs:
16
+
17
+ setup_release:
18
+ name: Create Release
19
+ runs-on: ubuntu-latest
20
+ steps:
21
+ - name: Get the tag version
22
+ id: extract_branch
23
+ run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
24
+ shell: bash
25
+
26
+ - name: Create Release
27
+ id: create_release
28
+ uses: actions/create-release@v1
29
+ env:
30
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
31
+ with:
32
+ tag_name: ${{ steps.extract_branch.outputs.branch }}
33
+ release_name: ${{ steps.extract_branch.outputs.branch }}
34
+
35
+ build_wheels:
36
+ name: Build Wheel
37
+ needs: setup_release
38
+ runs-on: ${{ matrix.os }}
39
+
40
+ strategy:
41
+ fail-fast: false
42
+ matrix:
43
+ # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
44
+ # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
45
+ os: [ubuntu-20.04]
46
+ python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
47
+ torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0']
48
+ cuda-version: ['11.8.0', '12.2.2']
49
+ # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
50
+ # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
51
+ # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
52
+ # when building without C++11 ABI and using it on nvcr images.
53
+ cxx11_abi: ['FALSE', 'TRUE']
54
+ exclude:
55
+ # Pytorch < 2.2 does not support Python 3.12
56
+ - torch-version: '2.0.1'
57
+ python-version: '3.12'
58
+ - torch-version: '2.1.2'
59
+ python-version: '3.12'
60
+ # Pytorch <= 2.0 only supports CUDA <= 11.8
61
+ - torch-version: '2.0.1'
62
+ cuda-version: '12.2.2'
63
+
64
+ steps:
65
+ - name: Checkout
66
+ uses: actions/checkout@v3
67
+
68
+ - name: Set up Python
69
+ uses: actions/setup-python@v4
70
+ with:
71
+ python-version: ${{ matrix.python-version }}
72
+
73
+ - name: Set CUDA and PyTorch versions
74
+ run: |
75
+ echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
76
+ echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
77
+
78
+ - name: Free up disk space
79
+ if: ${{ runner.os == 'Linux' }}
80
+ # https://github.com/easimon/maximize-build-space/blob/master/action.yml
81
+ # https://github.com/easimon/maximize-build-space/tree/test-report
82
+ run: |
83
+ sudo rm -rf /usr/share/dotnet
84
+ sudo rm -rf /opt/ghc
85
+ sudo rm -rf /opt/hostedtoolcache/CodeQL
86
+
87
+ - name: Set up swap space
88
+ if: runner.os == 'Linux'
89
+ uses: pierotofy/[email protected]
90
+ with:
91
+ swap-size-gb: 10
92
+
93
+ - name: Install CUDA ${{ matrix.cuda-version }}
94
+ if: ${{ matrix.cuda-version != 'cpu' }}
95
+ uses: Jimver/[email protected]
96
+ id: cuda-toolkit
97
+ with:
98
+ cuda: ${{ matrix.cuda-version }}
99
+ linux-local-args: '["--toolkit"]'
100
+ # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
101
+ # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
102
+ method: 'network'
103
+ # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
104
+ # not just nvcc
105
+ # sub-packages: '["nvcc"]'
106
+
107
+ - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
108
+ run: |
109
+ pip install --upgrade pip
110
+ # If we don't install lit before installing Pytorch, we get an error for torch 2.0.1
111
+ # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
112
+ pip install lit
113
+ # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
114
+ pip install setuptools
115
+ # We want to figure out the CUDA version to download pytorch
116
+ # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
117
+ # This code is ugly, maybe there's a better way to do this.
118
+ export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
119
+ minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
120
+ maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124}[env['MATRIX_TORCH_VERSION']]; \
121
+ print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
122
+ )
123
+ if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
124
+ pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
125
+ else
126
+ pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
127
+ fi
128
+ nvcc --version
129
+ python --version
130
+ python -c "import torch; print('PyTorch:', torch.__version__)"
131
+ python -c "import torch; print('CUDA:', torch.version.cuda)"
132
+ python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
133
+ shell:
134
+ bash
135
+
136
+ - name: Build wheel
137
+ run: |
138
+ # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
139
+ # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
140
+ # However this still fails so I'm using a newer version of setuptools
141
+ pip install setuptools==68.0.0
142
+ pip install ninja packaging wheel
143
+ export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
144
+ export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
145
+ # Limit MAX_JOBS otherwise the github runner goes OOM
146
+ MAX_JOBS=2 MAMBA_FORCE_BUILD="TRUE" MAMBA_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
147
+ tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
148
+ wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
149
+ ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
150
+ echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
151
+
152
+ - name: Log Built Wheels
153
+ run: |
154
+ ls dist
155
+
156
+ - name: Get the tag version
157
+ id: extract_branch
158
+ run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
159
+
160
+ - name: Get Release with tag
161
+ id: get_current_release
162
+ uses: joutvhu/get-release@v1
163
+ with:
164
+ tag_name: ${{ steps.extract_branch.outputs.branch }}
165
+ env:
166
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
167
+
168
+ - name: Upload Release Asset
169
+ id: upload_release_asset
170
+ uses: actions/upload-release-asset@v1
171
+ env:
172
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
173
+ with:
174
+ upload_url: ${{ steps.get_current_release.outputs.upload_url }}
175
+ asset_path: ./dist/${{env.wheel_name}}
176
+ asset_name: ${{env.wheel_name}}
177
+ asset_content_type: application/*
178
+
179
+ publish_package:
180
+ name: Publish package
181
+ needs: [build_wheels]
182
+
183
+ runs-on: ubuntu-latest
184
+
185
+ steps:
186
+ - uses: actions/checkout@v3
187
+
188
+ - uses: actions/setup-python@v4
189
+ with:
190
+ python-version: '3.10'
191
+
192
+ - name: Install dependencies
193
+ run: |
194
+ pip install ninja packaging setuptools wheel twine
195
+ # We don't want to download anything CUDA-related here
196
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
197
+
198
+ - name: Build core package
199
+ env:
200
+ MAMBA_SKIP_CUDA_BUILD: "TRUE"
201
+ run: |
202
+ python setup.py sdist --dist-dir=dist
203
+
204
+ - name: Deploy
205
+ env:
206
+ TWINE_USERNAME: "__token__"
207
+ TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
208
+ run: |
209
+ python -m twine upload dist/*
mamba/.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ *__pycache__/
2
+ *.egg-info/
3
+ build/
4
+ **.so
5
+ *.hip
6
+ *_hip.*
mamba/.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "3rdparty/lm-evaluation-harness"]
2
+ path = 3rdparty/lm-evaluation-harness
3
+ url = https://github.com/EleutherAI/lm-evaluation-harness/
mamba/AUTHORS ADDED
@@ -0,0 +1,2 @@
1
+ Tri Dao, [email protected]
2
+ Albert Gu, [email protected]
mamba/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2023 Tri Dao, Albert Gu
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
mamba/README.md ADDED
@@ -0,0 +1,243 @@
1
+ # Mamba
2
+
3
+ ![Mamba](assets/selection.png "Selective State Space")
4
+ > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
5
+ > Albert Gu*, Tri Dao*\
6
+ > Paper: https://arxiv.org/abs/2312.00752
7
+
8
+ ![Mamba-2](assets/ssd_algorithm.png "State Space Dual Model")
9
+ > **Transformers are SSMs: Generalized Models and Efficient Algorithms**\
10
+ > **Through Structured State Space Duality**\
11
+ > Tri Dao*, Albert Gu*\
12
+ > Paper: https://arxiv.org/abs/2405.21060
13
+
14
+ ## About
15
+
16
+ Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
17
+ It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
18
+ with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
19
+
20
+ ## Installation
21
+
22
+ - [Optional] `pip install causal-conv1d>=1.4.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
23
+ - `pip install mamba-ssm`: the core Mamba package.
24
+ - `pip install mamba-ssm[causal-conv1d]`: installs the core Mamba package together with causal-conv1d.
25
+ - `pip install mamba-ssm[dev]`: installs the core Mamba package together with the dev dependencies.
26
+
27
+ It can also be built from source with `pip install .` from this repository.
28
+
29
+ If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
30
+
31
+ Other requirements (a quick environment check is sketched below):
32
+ - Linux
33
+ - NVIDIA GPU
34
+ - PyTorch 1.12+
35
+ - CUDA 11.6+
36
+
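+ A quick environment check (a minimal sketch; the exact values printed depend on your setup):
+ ``` python
+ import torch
+
+ print("PyTorch:", torch.__version__)                         # needs 1.12 or newer
+ print("CUDA available:", torch.cuda.is_available())          # needs an NVIDIA GPU
+ print("CUDA version used by PyTorch:", torch.version.cuda)   # needs 11.6 or newer
+ if torch.cuda.is_available():
+     print("GPU:", torch.cuda.get_device_name(0))
+ ```
+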
37
+ For AMD cards, see additional prerequisites below.
38
+
39
+ ## Usage
40
+
41
+ We expose several levels of interface with the Mamba model.
42
+
43
+ ### Selective SSM
44
+
45
+ Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
46
+
47
+ Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
48
+
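+ For intuition, the computation behind the fused kernel can be sketched as a plain PyTorch loop. This is a slow reference sketch, not the actual `selective_scan_fn` API; the argument names and shapes below are illustrative assumptions:
+ ``` python
+ import torch
+
+ def selective_scan_reference(x, delta, A, B, C, D):
+     # x, delta: (batch, dim, length); A: (dim, state); B, C: (batch, state, length); D: (dim,)
+     batch, dim, length = x.shape
+     h = torch.zeros(batch, dim, A.shape[1], device=x.device, dtype=x.dtype)
+     ys = []
+     for t in range(length):
+         dt = delta[:, :, t].unsqueeze(-1)                    # (batch, dim, 1)
+         dA = torch.exp(dt * A)                               # input-dependent discretization of A
+         dB = dt * B[:, :, t].unsqueeze(1)                    # (batch, dim, state)
+         h = dA * h + dB * x[:, :, t].unsqueeze(-1)           # selective state update
+         ys.append((h * C[:, :, t].unsqueeze(1)).sum(-1))     # y_t = C_t h_t
+     return torch.stack(ys, dim=-1) + D[None, :, None] * x   # skip connection through D
+ ```
+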
49
+ ### Mamba Block
50
+
51
+ The main module of this repository is the Mamba architecture block wrapping the selective SSM.
52
+
53
+ Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
54
+
55
+ Usage:
56
+ ``` python
57
+ import torch
58
+ from mamba_ssm import Mamba
59
+
60
+ batch, length, dim = 2, 64, 16
61
+ x = torch.randn(batch, length, dim).to("cuda")
62
+ model = Mamba(
63
+ # This module uses roughly 3 * expand * d_model^2 parameters
64
+ d_model=dim, # Model dimension d_model
65
+ d_state=16, # SSM state expansion factor
66
+ d_conv=4, # Local convolution width
67
+ expand=2, # Block expansion factor
68
+ ).to("cuda")
69
+ y = model(x)
70
+ assert y.shape == x.shape
71
+ ```
72
+
73
+ ### Mamba-2
74
+
75
+ The Mamba-2 block is implemented at [modules/mamba2.py](mamba_ssm/modules/mamba2.py).
76
+
77
+ A simpler version is at [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py)
78
+
79
+ The usage is similar to Mamba(-1):
80
+ ``` python
81
+ from mamba_ssm import Mamba2
82
+ model = Mamba2(
83
+ # This module uses roughly 3 * expand * d_model^2 parameters
84
+ d_model=dim, # Model dimension d_model
85
+ d_state=64, # SSM state expansion factor, typically 64 or 128
86
+ d_conv=4, # Local convolution width
87
+ expand=2, # Block expansion factor
88
+ ).to("cuda")
89
+ y = model(x)
90
+ assert y.shape == x.shape
91
+ ```
92
+
93
+ #### SSD
94
+
95
+ A minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between "discrete" and "continuous" SSM versions
96
+ is at [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py).
97
+
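+ For intuition, the quadratic ("attention-like") form of this computation fits in a few lines: the output is a masked matrix product in which the mask encodes the cumulative decay. The sketch below is illustrative only and does not follow the exact `ssd_minimal.py` signature; names and shapes are assumptions:
+ ``` python
+ import torch
+
+ def ssd_naive(x, log_a, B, C):
+     # x: (batch, length, head_dim); log_a: (batch, length), with a_t = exp(log_a_t) in (0, 1]
+     # B, C: (batch, length, state_dim)
+     length = x.shape[1]
+     cum = torch.cumsum(log_a, dim=1)                          # prefix sums of the log decays
+     diff = cum[:, :, None] - cum[:, None, :]                  # sum of log a over the interval (s, t]
+     mask = torch.ones(length, length, dtype=torch.bool, device=x.device).tril()
+     L = torch.exp(diff.masked_fill(~mask, float("-inf")))     # causal 1-semiseparable decay mask
+     M = (C @ B.transpose(1, 2)) * L                           # (batch, length, length)
+     return M @ x                                              # y_t = sum_{s<=t} M[t, s] x_s
+ ```
+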
98
+ ### Mamba Language Model
99
+
100
+ Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
101
+
102
+ Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
103
+
104
+ This is an example of how to integrate Mamba into an end-to-end neural network.
105
+ This example is used in the generation scripts below.
106
+
107
+
108
+ ## Pretrained Models
109
+
110
+ Pretrained models are uploaded to
111
+ [Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
112
+ `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`,
113
+ `mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`, `transformerpp-2.7b`, `mamba2attn-2.7b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
114
+ (trained on 600B tokens on the SlimPajama dataset).
115
+
116
+
117
+ The models will be autodownloaded by the generation script below.
118
+
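+ They can also be loaded directly in Python. The snippet below is a minimal sketch mirroring the calls made in the benchmark script further down; the prompt and generation settings are arbitrary:
+ ``` python
+ import torch
+ from transformers import AutoTokenizer
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+ device, dtype = "cuda", torch.float16
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")   # Mamba checkpoints reuse the GPT-NeoX tokenizer
+ model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device=device, dtype=dtype)
+ model.eval()
+
+ input_ids = tokenizer("Mamba is a state space model that", return_tensors="pt").input_ids.to(device)
+ out = model.generate(
+     input_ids=input_ids,
+     max_length=input_ids.shape[1] + 50,
+     cg=True,                         # capture the decoding step in a CUDA graph
+     temperature=1.0,
+     top_k=1,                         # greedy decoding (the benchmark script's defaults)
+     top_p=1.0,
+     return_dict_in_generate=True,
+ )
+ print(tokenizer.batch_decode(out.sequences.tolist())[0])
+ ```
+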
119
+ These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
120
+
121
+ | Parameters | Layers | Model dim. |
122
+ |------------|--------|------------|
123
+ | 130M | 24 | 768 |
124
+ | 370M | 48 | 1024 |
125
+ | 790M | 48 | 1536 |
126
+ | 1.4B | 48 | 2048 |
127
+ | 2.8B | 64 | 2560 |
128
+
129
+ (The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
130
+
131
+ Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
132
+ Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
133
+
134
+
135
+ ## Evaluations
136
+
137
+ To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
138
+ we use the
139
+ [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
140
+ library.
141
+
142
+ 1. Install `lm-evaluation-harness` with `pip install lm-eval==0.4.2`.
143
+ 2. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
144
+ ``` sh
145
+ lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
146
+ python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
147
+ ```
148
+
149
+ To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts:
150
+ ``` sh
151
+ lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
152
+ lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
153
+ ```
154
+
155
+ To run evaluations on Mamba-2 models, simply replace the model names:
156
+ ``` sh
157
+ lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
158
+ lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
159
+ lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
160
+ ```
161
+
162
+ Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
163
+
164
+ ## Inference
165
+
166
+ The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
167
+ 1. autoloads a model from the Hugging Face Hub,
168
+ 2. generates completions of a user-specified prompt,
169
+ 3. benchmarks the inference speed of this generation.
170
+
171
+ Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
172
+
173
+ ### Examples
174
+
175
+ To test generation latency (e.g. batch size = 1) with different sampling strategies:
176
+
177
+ ``` sh
178
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
179
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
180
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
181
+ ```
182
+
183
+ To test generation throughput with random prompts (e.g. large batch size):
184
+ ``` sh
185
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64
186
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64
187
+ ```
188
+
189
+ With Mamba-2, you just need to change the model name:
190
+ ``` sh
191
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
192
+ ```
193
+
194
+
195
+ ## Troubleshooting
196
+
197
+ ### Precision
198
+ Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary.
199
+ Other frameworks, such as DeepSpeed, store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation).
200
+
201
+ We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities,
202
+ as a first step please try an approach that stores parameters in fp32 (such as AMP).
203
+
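+ As a concrete reference, a standard AMP training step keeps the parameters (and optimizer state) in fp32 and only runs the forward/backward math in half precision. The sketch below is generic PyTorch, assumes a causal-LM-style model whose output exposes `.logits`, and is not tied to any particular trainer:
+ ``` python
+ import torch
+ import torch.nn.functional as F
+
+ def amp_train_step(model, optimizer, scaler, input_ids, labels):
+     # scaler is a torch.cuda.amp.GradScaler(); the model parameters stay in float32
+     optimizer.zero_grad(set_to_none=True)
+     with torch.cuda.amp.autocast(dtype=torch.float16):        # compute in fp16, weights remain fp32
+         logits = model(input_ids).logits
+         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
+     scaler.scale(loss).backward()                             # scale the loss to avoid fp16 gradient underflow
+     scaler.step(optimizer)                                    # unscale, then update the fp32 master weights
+     scaler.update()
+     return loss.detach()
+ ```
+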
204
+ ### Initialization
205
+ Some parts of the model have initializations inherited from prior work on S4 models.
206
+ For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter is kept within a targeted range by initializing the bias of its linear projection.
207
+ However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero).
208
+ If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework)
209
+ that is specific to the training framework.
210
+
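+ As an illustration, such framework-specific logic could be a re-initialization pass that skips any parameter carrying a "do not re-init" marker. The attribute name `_no_reinit` below mirrors the flag set in the line linked above, but treat it as an assumption and adapt it to whatever your trainer uses:
+ ``` python
+ import torch.nn as nn
+
+ def zero_linear_biases(model: nn.Module):
+     """Example post-init hook: zero Linear biases, but respect parameters marked as pre-initialized."""
+     for module in model.modules():
+         if isinstance(module, nn.Linear) and module.bias is not None:
+             if getattr(module.bias, "_no_reinit", False):
+                 continue                  # e.g. the dt_proj bias, whose init encodes the targeted Delta range
+             nn.init.zeros_(module.bias)
+ ```
+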
211
+ ## Additional Prerequisites for AMD cards
212
+
213
+ ### Patching ROCm
214
+
215
+ If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
216
+
217
+ 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
218
+
219
+ 2. Apply the patch. Run with `sudo` if you encounter permission issues.
220
+ ```bash
221
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
222
+ ```
223
+
224
+
225
+ ## Citation
226
+
227
+ If you use this codebase, or otherwise find our work valuable, please cite Mamba:
228
+ ```
229
+ @article{mamba,
230
+ title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
231
+ author={Gu, Albert and Dao, Tri},
232
+ journal={arXiv preprint arXiv:2312.00752},
233
+ year={2023}
234
+ }
235
+
236
+ @inproceedings{mamba2,
237
+ title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
238
+ author={Dao, Tri and Gu, Albert},
239
+ booktitle={International Conference on Machine Learning (ICML)},
240
+ year={2024}
241
+ }
242
+
243
+ ```
mamba/assets/selection.png ADDED
mamba/assets/ssd_algorithm.png ADDED

Git LFS Details

  • SHA256: 91ab82b330761250c3241e4f16fed54a35081115c26777a4cc087c2f6e47f466
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
mamba/benchmarks/benchmark_generation_mamba_simple.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import argparse
4
+ import time
5
+ import json
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange
11
+
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
13
+
14
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
15
+
16
+
17
+ parser = argparse.ArgumentParser(description="Generation benchmarking")
18
+ parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
19
+ parser.add_argument("--prompt", type=str, default=None)
20
+ parser.add_argument("--promptlen", type=int, default=100)
21
+ parser.add_argument("--genlen", type=int, default=100)
22
+ parser.add_argument("--temperature", type=float, default=1.0)
23
+ parser.add_argument("--topk", type=int, default=1)
24
+ parser.add_argument("--topp", type=float, default=1.0)
25
+ parser.add_argument("--minp", type=float, default=0.0)
26
+ parser.add_argument("--repetition-penalty", type=float, default=1.0)
27
+ parser.add_argument("--batch", type=int, default=1)
28
+ args = parser.parse_args()
29
+
30
+ repeats = 3
31
+ device = "cuda"
32
+ dtype = torch.float16
33
+
34
+ print(f"Loading model {args.model_name}")
35
+ is_mamba = args.model_name.startswith("state-spaces/mamba") or args.model_name.startswith("state-spaces/transformerpp")
36
+ if is_mamba:
37
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
38
+ model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
39
+ else:
40
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
41
+ model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
42
+ model.eval()
43
+ print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
44
+
45
+ torch.random.manual_seed(0)
46
+ if args.prompt is None:
47
+ input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
48
+ attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
49
+ else:
50
+ tokens = tokenizer(args.prompt, return_tensors="pt")
51
+ input_ids = tokens.input_ids.to(device=device)
52
+ attn_mask = tokens.attention_mask.to(device=device)
53
+ max_length = input_ids.shape[1] + args.genlen
54
+
55
+ if is_mamba:
56
+ fn = lambda: model.generate(
57
+ input_ids=input_ids,
58
+ max_length=max_length,
59
+ cg=True,
60
+ return_dict_in_generate=True,
61
+ output_scores=True,
62
+ enable_timing=False,
63
+ temperature=args.temperature,
64
+ top_k=args.topk,
65
+ top_p=args.topp,
66
+ min_p=args.minp,
67
+ repetition_penalty=args.repetition_penalty,
68
+ )
69
+ else:
70
+ fn = lambda: model.generate(
71
+ input_ids=input_ids,
72
+ attention_mask=attn_mask,
73
+ max_length=max_length,
74
+ return_dict_in_generate=True,
75
+ pad_token_id=tokenizer.eos_token_id,
76
+ do_sample=True,
77
+ temperature=args.temperature,
78
+ top_k=args.topk,
79
+ top_p=args.topp,
80
+ repetition_penalty=args.repetition_penalty,
81
+ )
82
+ out = fn()
83
+ if args.prompt is not None:
84
+ print(tokenizer.batch_decode(out.sequences.tolist()))
85
+
86
+ torch.cuda.synchronize()
87
+ start = time.time()
88
+ for _ in range(repeats):
89
+ fn()
90
+ torch.cuda.synchronize()
91
+ print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
92
+ print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
mamba/build/lib/mamba_ssm/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ __version__ = "2.2.2"
2
+
3
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4
+ from mamba_ssm.modules.mamba_simple import Mamba
5
+ from mamba_ssm.modules.mamba2 import Mamba2
6
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
mamba/build/lib/mamba_ssm/distributed/__init__.py ADDED
File without changes
mamba/build/lib/mamba_ssm/distributed/distributed_utils.py ADDED
@@ -0,0 +1,144 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.distributed import ProcessGroup
6
+
7
+ # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
8
+ # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
9
+ # version of PyTorch. The following 4 lines are for backward compatibility with
10
+ # older PyTorch.
11
+ if "all_gather_into_tensor" not in dir(torch.distributed):
12
+ torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
13
+ if "reduce_scatter_tensor" not in dir(torch.distributed):
14
+ torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
15
+
16
+
17
+ # Raw operation, does not support autograd, but does support async
18
+ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
19
+ world_size = torch.distributed.get_world_size(process_group)
20
+ output = torch.empty(
21
+ world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
22
+ )
23
+ handle = torch.distributed.all_gather_into_tensor(
24
+ output, input_.contiguous(), group=process_group, async_op=async_op
25
+ )
26
+ return output, handle
27
+
28
+
29
+ # Raw operation, does not support autograd, but does support async
30
+ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
31
+ world_size = torch.distributed.get_world_size(process_group)
32
+ assert input_.shape[0] % world_size == 0
33
+ output = torch.empty(
34
+ input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
35
+ )
36
+ handle = torch.distributed.reduce_scatter_tensor(
37
+ output, input_.contiguous(), group=process_group, async_op=async_op
38
+ )
39
+ return output, handle
40
+
41
+
42
+ # Raw operation, does not support autograd, but does support async
43
+ def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
44
+ input_ = input_.contiguous()
45
+ handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
46
+ return input_, handle
47
+
48
+
49
+ class AllGatherFunc(torch.autograd.Function):
50
+ """Gather the input from sequence parallel region and concatenate."""
51
+
52
+ @staticmethod
53
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
54
+ ctx.process_group = process_group
55
+ output, _ = all_gather_raw(input_, process_group)
56
+ return output
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output: Tensor):
60
+ grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
61
+ return grad_input, None
62
+
63
+
64
+ # Supports autograd, but does not support async
65
+ all_gather = AllGatherFunc.apply
66
+
67
+
68
+ class ReduceScatterFunc(torch.autograd.Function):
69
+ """Reduce scatter the input from the sequence parallel region and concatenate."""
70
+
71
+ @staticmethod
72
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
73
+ ctx.process_group = process_group
74
+ output, _ = reduce_scatter_raw(input_, process_group)
75
+ return output
76
+
77
+ @staticmethod
78
+ def backward(ctx, grad_output: Tensor):
79
+ grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
80
+ return grad_input, None
81
+
82
+
83
+ # Supports autograd, but does not support async
84
+ reduce_scatter = ReduceScatterFunc.apply
85
+
86
+
87
+ class AllReduceFunc(torch.autograd.Function):
88
+ """Gather the input from sequence parallel region and concatenate."""
89
+
90
+ @staticmethod
91
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
92
+ ctx.process_group = process_group
93
+ output, _ = all_reduce_raw(input_, process_group)
94
+ return output
95
+
96
+ @staticmethod
97
+ def backward(ctx, grad_output: Tensor):
98
+ return grad_output, None
99
+
100
+
101
+ # Supports autograd, but does not support async
102
+ all_reduce = AllReduceFunc.apply
103
+
104
+
105
+ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
106
+ # We want to iterate over parameters with _shared_params=True in the same order,
107
+ # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
108
+ pamams_shared = {
109
+ name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
110
+ }
111
+ for _, p in sorted(pamams_shared.items()):
112
+ with torch.no_grad():
113
+ # Broadcast needs src to be global rank, not group rank
114
+ torch.distributed.broadcast(
115
+ p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
116
+ )
117
+
118
+
119
+ # Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
120
+ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
121
+ # We want to iterate over parameters with _sequence_parallel=True in the same order,
122
+ # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
123
+ params_seqparallel = {
124
+ name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
125
+ }
126
+ grads = [p.grad for _, p in sorted(params_seqparallel.items())]
127
+ if grads:
128
+ with torch.no_grad():
129
+ coalesced = torch._utils._flatten_dense_tensors(grads)
130
+ torch.distributed.all_reduce(coalesced, group=process_group)
131
+ for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
132
+ buf.copy_(synced)
133
+
134
+
135
+ def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
136
+ """Get the dim for the local rank derived from splitting dim on world_size processes.
137
+
138
+ The split may not be even across the world_size processes.
139
+ """
140
+ multiple = dim // multiple_of
141
+ div = multiple // world_size
142
+ mod = multiple % world_size
143
+ local_multiple = div + int(local_rank < mod)
144
+ return local_multiple * multiple_of
mamba/build/lib/mamba_ssm/distributed/tensor_parallel.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from torch.cuda.amp import custom_bwd, custom_fwd
10
+ from torch.distributed import ProcessGroup
11
+
12
+ from einops import rearrange
13
+
14
+ from mamba_ssm.distributed.distributed_utils import (
15
+ all_gather_raw,
16
+ all_reduce,
17
+ all_reduce_raw,
18
+ reduce_scatter,
19
+ reduce_scatter_raw,
20
+ )
21
+
22
+
23
+ class ParallelLinearFunc(torch.autograd.Function):
24
+ @staticmethod
25
+ @custom_fwd
26
+ def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
27
+ """
28
+ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
29
+ with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
30
+ """
31
+ ctx.compute_weight_gradient = weight.requires_grad
32
+ ctx.process_group = process_group
33
+ ctx.sequence_parallel = sequence_parallel
34
+
35
+ if torch.is_autocast_enabled():
36
+ x = x.to(dtype=torch.get_autocast_gpu_dtype())
37
+ x = x.contiguous()
38
+ if process_group is not None and sequence_parallel:
39
+ # We want to kick off the all_gather early, before weight dtype conversion
40
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
41
+ else:
42
+ total_x = x
43
+
44
+ if torch.is_autocast_enabled():
45
+ weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
46
+ bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
47
+ weight = weight.contiguous()
48
+ if process_group is not None and sequence_parallel:
49
+ handle_x.wait()
50
+ batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
51
+ batch_dim = batch_shape.numel()
52
+ # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
53
+ output = F.linear(total_x, weight, bias)
54
+ if ctx.compute_weight_gradient:
55
+ ctx.save_for_backward(x, weight)
56
+ else:
57
+ ctx.save_for_backward(weight)
58
+ return output
59
+
60
+ @staticmethod
61
+ @custom_bwd
62
+ def backward(ctx, grad_output):
63
+ grad_output = grad_output.contiguous()
64
+ process_group = ctx.process_group
65
+ sequence_parallel = ctx.sequence_parallel
66
+ if ctx.compute_weight_gradient:
67
+ x, weight = ctx.saved_tensors
68
+ if process_group is not None and sequence_parallel:
69
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
70
+ else:
71
+ total_x = x
72
+ else:
73
+ (weight,) = ctx.saved_tensors
74
+ total_x = None
75
+ batch_shape = grad_output.shape[:-1]
76
+ batch_dim = batch_shape.numel()
77
+ grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
78
+ if ctx.needs_input_grad[0]:
79
+ grad_input = F.linear(grad_output, weight.t())
80
+ grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
81
+ if process_group is not None:
82
+ reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
83
+ grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
84
+ else:
85
+ grad_input = None
86
+ if ctx.needs_input_grad[1]:
87
+ assert ctx.compute_weight_gradient
88
+ if process_group is not None and sequence_parallel:
89
+ handle_x.wait()
90
+ grad_weight = torch.einsum(
91
+ "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
92
+ )
93
+ else:
94
+ grad_weight = None
95
+ grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
96
+ if process_group is not None and ctx.needs_input_grad[0]:
97
+ handle_grad_input.wait()
98
+ return grad_input, grad_weight, grad_bias, None, None
99
+
100
+
101
+ def parallel_linear_func(
102
+ x: Tensor,
103
+ weight: Tensor,
104
+ bias: Optional[Tensor] = None,
105
+ process_group: Optional[ProcessGroup] = None,
106
+ sequence_parallel: bool = True,
107
+ ):
108
+ return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
109
+
110
+
111
+ class ColumnParallelLinear(nn.Linear):
112
+ def __init__(
113
+ self,
114
+ in_features: int,
115
+ out_features: int,
116
+ process_group: ProcessGroup,
117
+ bias: bool = True,
118
+ sequence_parallel=True,
119
+ multiple_of=1,
120
+ device=None,
121
+ dtype=None,
122
+ ) -> None:
123
+ world_size = torch.distributed.get_world_size(process_group)
124
+ if out_features % multiple_of:
125
+ raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
126
+ multiple = out_features // multiple_of
127
+ # We want to split @multiple across world_size, but it could be an uneven split
128
+ div = multiple // world_size
129
+ mod = multiple % world_size
130
+ # The first @mod ranks get @div + 1 copies, the rest get @div copies
131
+ local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
132
+ super().__init__(
133
+ in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
134
+ )
135
+ self.process_group = process_group
136
+ self.sequence_parallel = sequence_parallel
137
+
138
+ def forward(self, x):
139
+ # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
140
+ # we do an all_gather of x before doing the matmul.
141
+ # If not, then the input is already gathered.
142
+ return parallel_linear_func(
143
+ x,
144
+ self.weight,
145
+ self.bias,
146
+ process_group=self.process_group,
147
+ sequence_parallel=self.sequence_parallel,
148
+ )
149
+
150
+
151
+ class RowParallelLinear(nn.Linear):
152
+ def __init__(
153
+ self,
154
+ in_features: int,
155
+ out_features: int,
156
+ process_group: ProcessGroup,
157
+ bias: bool = True,
158
+ sequence_parallel=True,
159
+ multiple_of=1,
160
+ device=None,
161
+ dtype=None,
162
+ ) -> None:
163
+ world_size = torch.distributed.get_world_size(process_group)
164
+ rank = torch.distributed.get_rank(process_group)
165
+ if in_features % multiple_of:
166
+ raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
167
+ multiple = in_features // multiple_of
168
+ # We want to split @multiple across world_size, but it could be an uneven split
169
+ div = multiple // world_size
170
+ mod = multiple % world_size
171
+ # The first @mod ranks get @div + 1 copies, the rest get @div copies
172
+ local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
173
+ # Only rank 0 will have bias
174
+ super().__init__(
175
+ local_multiple * multiple_of,
176
+ out_features,
177
+ bias=bias and rank == 0,
178
+ device=device,
179
+ dtype=dtype,
180
+ )
181
+ self.process_group = process_group
182
+ self.sequence_parallel = sequence_parallel
183
+
184
+ def forward(self, x):
185
+ """
186
+ We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
187
+ a reduce_scatter of the result.
188
+ """
189
+ out = parallel_linear_func(x, self.weight, self.bias)
190
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
191
+ return reduce_fn(out, self.process_group)
192
+
193
+
194
+ class VocabParallelEmbedding(nn.Embedding):
195
+ def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
196
+ self.process_group = process_group
197
+ if process_group is not None:
198
+ world_size = torch.distributed.get_world_size(process_group)
199
+ if num_embeddings % world_size != 0:
200
+ raise ValueError(
201
+ f"num_embeddings ({num_embeddings}) must be divisible by "
202
+ f"world_size ({world_size})"
203
+ )
204
+ if world_size > 1 and padding_idx is not None:
205
+ raise RuntimeError("ParallelEmbedding does not support padding_idx")
206
+ else:
207
+ world_size = 1
208
+ super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
209
+
210
+ def forward(self, input: Tensor) -> Tensor:
211
+ if self.process_group is None:
212
+ return super().forward(input)
213
+ else:
214
+ rank = torch.distributed.get_rank(self.process_group)
215
+ vocab_size = self.num_embeddings
216
+ vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
217
+ # Mask for out-of-range token ids (True means the id belongs to another rank's shard and must be masked).
218
+ input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
219
+ input = input - vocab_start_index
220
+ input[input_ids_mask] = 0
221
+ embeddings = super().forward(input)
222
+ embeddings[input_ids_mask] = 0.0
223
+ return embeddings
224
+
225
+
226
+ class ColumnParallelEmbedding(nn.Embedding):
227
+ def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
228
+ self.process_group = process_group
229
+ if process_group is not None:
230
+ world_size = torch.distributed.get_world_size(process_group)
231
+ if embedding_dim % world_size != 0:
232
+ raise ValueError(
233
+ f"embedding_dim ({embedding_dim}) must be divisible by "
234
+ f"world_size ({world_size})"
235
+ )
236
+ else:
237
+ world_size = 1
238
+ super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
239
+
240
+
241
+ class ParallelEmbeddings(nn.Module):
242
+ def __init__(
243
+ self,
244
+ embed_dim,
245
+ vocab_size,
246
+ max_position_embeddings,
247
+ process_group,
248
+ padding_idx=None,
249
+ sequence_parallel=True,
250
+ device=None,
251
+ dtype=None,
252
+ ):
253
+ """
254
+ If max_position_embeddings <= 0, there's no position embeddings
255
+ """
256
+ factory_kwargs = {"device": device, "dtype": dtype}
257
+ super().__init__()
258
+ self.process_group = process_group
259
+ self.sequence_parallel = sequence_parallel
260
+ self.word_embeddings = VocabParallelEmbedding(
261
+ vocab_size,
262
+ embed_dim,
263
+ padding_idx=padding_idx,
264
+ process_group=process_group,
265
+ **factory_kwargs,
266
+ )
267
+ self.max_position_embeddings = max_position_embeddings
268
+ if self.max_position_embeddings > 0:
269
+ self.position_embeddings = ColumnParallelEmbedding(
270
+ max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
271
+ )
272
+
273
+ def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
274
+ """
275
+ input_ids: (batch, seqlen)
276
+ position_ids: (batch, seqlen)
277
+ """
278
+ batch_size, seqlen = input_ids.shape
279
+ world_size = torch.distributed.get_world_size(self.process_group)
280
+ embeddings = self.word_embeddings(input_ids)
281
+ if self.max_position_embeddings > 0:
282
+ if position_ids is None:
283
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
284
+ position_embeddings = self.position_embeddings(position_ids)
285
+ if world_size <= 1:
286
+ embeddings = embeddings + position_embeddings
287
+ else:
288
+ partition_dim = self.position_embeddings.embedding_dim
289
+ rank = torch.distributed.get_rank(self.process_group)
290
+ embeddings[
291
+ ..., rank * partition_dim : (rank + 1) * partition_dim
292
+ ] += position_embeddings
293
+ if combine_batch_seqlen_dim:
294
+ embeddings = rearrange(embeddings, "b s d -> (b s) d")
295
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
296
+ return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
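The `@div`/`@mod` arithmetic commented above (the first `mod` ranks receive one extra unit of size `multiple_of`) can be sanity-checked without a process group. A minimal plain-Python sketch, purely for illustration:

```python
# Sketch of the shard-size rule used by ColumnParallelLinear / RowParallelLinear:
# split `out_features // multiple_of` units across `world_size` ranks,
# giving the first `mod` ranks one extra unit.
def local_out_features(out_features: int, multiple_of: int, world_size: int, rank: int) -> int:
    assert out_features % multiple_of == 0
    multiple = out_features // multiple_of
    div, mod = divmod(multiple, world_size)
    local_multiple = div + int(rank < mod)
    return local_multiple * multiple_of

# Example: 10 units of size 64 over 4 ranks -> shards of 3, 3, 2, 2 units.
sizes = [local_out_features(640, 64, 4, r) for r in range(4)]
assert sizes == [192, 192, 128, 128] and sum(sizes) == 640
```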
mamba/build/lib/mamba_ssm/models/__init__.py ADDED
File without changes
mamba/build/lib/mamba_ssm/models/config_mamba.py ADDED
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class MambaConfig:
6
+
7
+ d_model: int = 2560
8
+ d_intermediate: int = 0
9
+ n_layer: int = 64
10
+ vocab_size: int = 50277
11
+ ssm_cfg: dict = field(default_factory=dict)
12
+ attn_layer_idx: list = field(default_factory=list)
13
+ attn_cfg: dict = field(default_factory=dict)
14
+ rms_norm: bool = True
15
+ residual_in_fp32: bool = True
16
+ fused_add_norm: bool = True
17
+ pad_vocab_size_multiple: int = 8
18
+ tie_embeddings: bool = True
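All fields of `MambaConfig` have defaults, so it can be constructed directly. A small usage sketch; the overridden values below are illustrative, not a recommended configuration:

```python
from mamba_ssm.models.config_mamba import MambaConfig

# Default construction uses the field defaults defined above.
cfg = MambaConfig()

# A smaller illustrative config: d_intermediate=0 means no MLP sub-block,
# and an empty attn_layer_idx means every layer is an SSM mixer.
small_cfg = MambaConfig(d_model=768, n_layer=24, vocab_size=50277, d_intermediate=0)
```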
mamba/build/lib/mamba_ssm/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,309 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+ import json
6
+ import os
7
+ import copy
8
+
9
+ from collections import namedtuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from mamba_ssm.models.config_mamba import MambaConfig
15
+ from mamba_ssm.modules.mamba_simple import Mamba
16
+ from mamba_ssm.modules.mamba2 import Mamba2
17
+ from mamba_ssm.modules.mha import MHA
18
+ from mamba_ssm.modules.mlp import GatedMLP
19
+ from mamba_ssm.modules.block import Block
20
+ from mamba_ssm.utils.generation import GenerationMixin
21
+ from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
22
+
23
+ try:
24
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
25
+ except ImportError:
26
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
27
+
28
+
29
+ def create_block(
30
+ d_model,
31
+ d_intermediate,
32
+ ssm_cfg=None,
33
+ attn_layer_idx=None,
34
+ attn_cfg=None,
35
+ norm_epsilon=1e-5,
36
+ rms_norm=False,
37
+ residual_in_fp32=False,
38
+ fused_add_norm=False,
39
+ layer_idx=None,
40
+ device=None,
41
+ dtype=None,
42
+ ):
43
+ if ssm_cfg is None:
44
+ ssm_cfg = {}
45
+ if attn_layer_idx is None:
46
+ attn_layer_idx = []
47
+ if attn_cfg is None:
48
+ attn_cfg = {}
49
+ factory_kwargs = {"device": device, "dtype": dtype}
50
+ if layer_idx not in attn_layer_idx:
51
+ # Create a copy of the config to modify
52
+ ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
53
+ ssm_layer = ssm_cfg.pop("layer", "Mamba1")
54
+ if ssm_layer not in ["Mamba1", "Mamba2"]:
55
+ raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
56
+ mixer_cls = partial(
57
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
58
+ layer_idx=layer_idx,
59
+ **ssm_cfg,
60
+ **factory_kwargs
61
+ )
62
+ else:
63
+ mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
64
+ norm_cls = partial(
65
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
66
+ )
67
+ if d_intermediate == 0:
68
+ mlp_cls = nn.Identity
69
+ else:
70
+ mlp_cls = partial(
71
+ GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
72
+ )
73
+ block = Block(
74
+ d_model,
75
+ mixer_cls,
76
+ mlp_cls,
77
+ norm_cls=norm_cls,
78
+ fused_add_norm=fused_add_norm,
79
+ residual_in_fp32=residual_in_fp32,
80
+ )
81
+ block.layer_idx = layer_idx
82
+ return block
83
+
84
+
85
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
86
+ def _init_weights(
87
+ module,
88
+ n_layer,
89
+ initializer_range=0.02, # Now only used for embedding layer.
90
+ rescale_prenorm_residual=True,
91
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
92
+ ):
93
+ if isinstance(module, nn.Linear):
94
+ if module.bias is not None:
95
+ if not getattr(module.bias, "_no_reinit", False):
96
+ nn.init.zeros_(module.bias)
97
+ elif isinstance(module, nn.Embedding):
98
+ nn.init.normal_(module.weight, std=initializer_range)
99
+
100
+ if rescale_prenorm_residual:
101
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
102
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
103
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
104
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
105
+ #
106
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
107
+ for name, p in module.named_parameters():
108
+ if name in ["out_proj.weight", "fc2.weight"]:
109
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
110
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
111
+ # We need to reinit p since this code could be called multiple times
112
+ # Having just p *= scale would repeatedly scale it down
113
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
114
+ with torch.no_grad():
115
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
116
+
117
+
118
+ class MixerModel(nn.Module):
119
+ def __init__(
120
+ self,
121
+ d_model: int,
122
+ n_layer: int,
123
+ d_intermediate: int,
124
+ vocab_size: int,
125
+ ssm_cfg=None,
126
+ attn_layer_idx=None,
127
+ attn_cfg=None,
128
+ norm_epsilon: float = 1e-5,
129
+ rms_norm: bool = False,
130
+ initializer_cfg=None,
131
+ fused_add_norm=False,
132
+ residual_in_fp32=False,
133
+ device=None,
134
+ dtype=None,
135
+ ) -> None:
136
+ factory_kwargs = {"device": device, "dtype": dtype}
137
+ super().__init__()
138
+ self.residual_in_fp32 = residual_in_fp32
139
+
140
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
141
+
142
+ # We change the order of residual and layer norm:
143
+ # Instead of LN -> Attn / MLP -> Add, we do:
144
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
145
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
146
+ # This is for performance reason: we can fuse add + layer_norm.
147
+ self.fused_add_norm = fused_add_norm
148
+ if self.fused_add_norm:
149
+ if layer_norm_fn is None or rms_norm_fn is None:
150
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
151
+
152
+ self.layers = nn.ModuleList(
153
+ [
154
+ create_block(
155
+ d_model,
156
+ d_intermediate=d_intermediate,
157
+ ssm_cfg=ssm_cfg,
158
+ attn_layer_idx=attn_layer_idx,
159
+ attn_cfg=attn_cfg,
160
+ norm_epsilon=norm_epsilon,
161
+ rms_norm=rms_norm,
162
+ residual_in_fp32=residual_in_fp32,
163
+ fused_add_norm=fused_add_norm,
164
+ layer_idx=i,
165
+ **factory_kwargs,
166
+ )
167
+ for i in range(n_layer)
168
+ ]
169
+ )
170
+
171
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
172
+ d_model, eps=norm_epsilon, **factory_kwargs
173
+ )
174
+
175
+ self.apply(
176
+ partial(
177
+ _init_weights,
178
+ n_layer=n_layer,
179
+ **(initializer_cfg if initializer_cfg is not None else {}),
180
+ n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP
181
+ )
182
+ )
183
+
184
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
185
+ return {
186
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
187
+ for i, layer in enumerate(self.layers)
188
+ }
189
+
190
+ def forward(self, input_ids, inference_params=None, **mixer_kwargs):
191
+ hidden_states = self.embedding(input_ids)
192
+ residual = None
193
+ for layer in self.layers:
194
+ hidden_states, residual = layer(
195
+ hidden_states, residual, inference_params=inference_params, **mixer_kwargs
196
+ )
197
+ if not self.fused_add_norm:
198
+ residual = (hidden_states + residual) if residual is not None else hidden_states
199
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
200
+ else:
201
+ # Set prenorm=False here since we don't need the residual
202
+ hidden_states = layer_norm_fn(
203
+ hidden_states,
204
+ self.norm_f.weight,
205
+ self.norm_f.bias,
206
+ eps=self.norm_f.eps,
207
+ residual=residual,
208
+ prenorm=False,
209
+ residual_in_fp32=self.residual_in_fp32,
210
+ is_rms_norm=isinstance(self.norm_f, RMSNorm)
211
+ )
212
+ return hidden_states
213
+
214
+
215
+ class MambaLMHeadModel(nn.Module, GenerationMixin):
216
+
217
+ def __init__(
218
+ self,
219
+ config: MambaConfig,
220
+ initializer_cfg=None,
221
+ device=None,
222
+ dtype=None,
223
+ ) -> None:
224
+ self.config = config
225
+ d_model = config.d_model
226
+ n_layer = config.n_layer
227
+ d_intermediate = config.d_intermediate
228
+ vocab_size = config.vocab_size
229
+ ssm_cfg = config.ssm_cfg
230
+ attn_layer_idx = config.attn_layer_idx
231
+ attn_cfg = config.attn_cfg
232
+ rms_norm = config.rms_norm
233
+ residual_in_fp32 = config.residual_in_fp32
234
+ fused_add_norm = config.fused_add_norm
235
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
236
+ factory_kwargs = {"device": device, "dtype": dtype}
237
+
238
+ super().__init__()
239
+ if vocab_size % pad_vocab_size_multiple != 0:
240
+ vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
241
+ self.backbone = MixerModel(
242
+ d_model=d_model,
243
+ n_layer=n_layer,
244
+ d_intermediate=d_intermediate,
245
+ vocab_size=vocab_size,
246
+ ssm_cfg=ssm_cfg,
247
+ attn_layer_idx=attn_layer_idx,
248
+ attn_cfg=attn_cfg,
249
+ rms_norm=rms_norm,
250
+ initializer_cfg=initializer_cfg,
251
+ fused_add_norm=fused_add_norm,
252
+ residual_in_fp32=residual_in_fp32,
253
+ **factory_kwargs,
254
+ )
255
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
256
+
257
+ # Initialize weights and apply final processing
258
+ self.apply(
259
+ partial(
260
+ _init_weights,
261
+ n_layer=n_layer,
262
+ **(initializer_cfg if initializer_cfg is not None else {}),
263
+ )
264
+ )
265
+ self.tie_weights()
266
+
267
+ def tie_weights(self):
268
+ if self.config.tie_embeddings:
269
+ self.lm_head.weight = self.backbone.embedding.weight
270
+
271
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
272
+ return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
273
+
274
+ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
275
+ """
276
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
277
+ num_last_tokens: if > 0, only return the logits for the last n tokens
278
+ """
279
+ hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
280
+ if num_last_tokens > 0:
281
+ hidden_states = hidden_states[:, -num_last_tokens:]
282
+ lm_logits = self.lm_head(hidden_states)
283
+ CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
284
+ return CausalLMOutput(logits=lm_logits)
285
+
286
+ @classmethod
287
+ def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
288
+ config_data = load_config_hf(pretrained_model_name)
289
+ config = MambaConfig(**config_data)
290
+ model = cls(config, device=device, dtype=dtype, **kwargs)
291
+ model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
292
+ return model
293
+
294
+ def save_pretrained(self, save_directory):
295
+ """
296
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
297
+ Save the model and its configuration file to a directory.
298
+ """
299
+ # Ensure save_directory exists
300
+ os.makedirs(save_directory, exist_ok=True)
301
+
302
+ # Save the model's state_dict
303
+ model_path = os.path.join(save_directory, 'pytorch_model.bin')
304
+ torch.save(self.state_dict(), model_path)
305
+
306
+ # Save the configuration of the model
307
+ config_path = os.path.join(save_directory, 'config.json')
308
+ with open(config_path, 'w') as f:
309
+ json.dump(self.config.__dict__, f, indent=4)
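A minimal usage sketch for `MambaLMHeadModel`, assuming a CUDA device and that this package's compiled selective-scan and Triton kernels are installed; the model sizes are arbitrary and chosen only for illustration:

```python
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Assumes a CUDA device plus the selective-scan / Triton kernels from this package.
cfg = MambaConfig(d_model=256, n_layer=4, vocab_size=1000)   # illustrative sizes only
model = MambaLMHeadModel(cfg, device="cuda")

input_ids = torch.randint(0, 1000, (2, 128), device="cuda")
out = model(input_ids, num_last_tokens=1)   # namedtuple with a .logits field
print(out.logits.shape)                     # torch.Size([2, 1, 1000])
```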
mamba/build/lib/mamba_ssm/modules/__init__.py ADDED
File without changes
mamba/build/lib/mamba_ssm/modules/block.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+
7
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
8
+
9
+
10
+ class Block(nn.Module):
11
+ def __init__(
12
+ self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
13
+ ):
14
+ """
15
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
16
+
17
+ This Block has a slightly different structure compared to a regular
18
+ prenorm Transformer block.
19
+ The standard block is: LN -> MHA/MLP -> Add.
20
+ [Ref: https://arxiv.org/abs/2002.04745]
21
+ Here we have: Add -> LN -> Mixer, returning both
22
+ the hidden_states (output of the mixer) and the residual.
23
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
24
+ The residual needs to be provided (except for the very first block).
25
+ """
26
+ super().__init__()
27
+ self.residual_in_fp32 = residual_in_fp32
28
+ self.fused_add_norm = fused_add_norm
29
+ self.norm = norm_cls(dim)
30
+ self.mixer = mixer_cls(dim)
31
+ if mlp_cls is not nn.Identity:
32
+ self.norm2 = norm_cls(dim)
33
+ self.mlp = mlp_cls(dim)
34
+ else:
35
+ self.mlp = None
36
+ if self.fused_add_norm:
37
+ assert RMSNorm is not None, "RMSNorm import fails"
38
+ assert isinstance(
39
+ self.norm, (nn.LayerNorm, RMSNorm)
40
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
41
+
42
+ def forward(
43
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
44
+ ):
45
+ r"""Pass the input through the encoder layer.
46
+
47
+ Args:
48
+ hidden_states: the sequence to the encoder layer (required).
49
+ residual: hidden_states = Mixer(LN(residual))
50
+ """
51
+ if not self.fused_add_norm:
52
+ residual = (hidden_states + residual) if residual is not None else hidden_states
53
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
54
+ if self.residual_in_fp32:
55
+ residual = residual.to(torch.float32)
56
+ else:
57
+ hidden_states, residual = layer_norm_fn(
58
+ hidden_states,
59
+ self.norm.weight,
60
+ self.norm.bias,
61
+ residual=residual,
62
+ prenorm=True,
63
+ residual_in_fp32=self.residual_in_fp32,
64
+ eps=self.norm.eps,
65
+ is_rms_norm=isinstance(self.norm, RMSNorm)
66
+ )
67
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
68
+
69
+ if self.mlp is not None:
70
+ if not self.fused_add_norm:
71
+ residual = hidden_states + residual
72
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
73
+ if self.residual_in_fp32:
74
+ residual = residual.to(torch.float32)
75
+ else:
76
+ hidden_states, residual = layer_norm_fn(
77
+ hidden_states,
78
+ self.norm2.weight,
79
+ self.norm2.bias,
80
+ residual=residual,
81
+ prenorm=True,
82
+ residual_in_fp32=self.residual_in_fp32,
83
+ eps=self.norm2.eps,
84
+ is_rms_norm=isinstance(self.norm2, RMSNorm)
85
+ )
86
+ hidden_states = self.mlp(hidden_states)
87
+
88
+ return hidden_states, residual
89
+
90
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
91
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
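The docstring above describes the Add -> LN -> Mixer ordering; a plain-PyTorch sketch of that non-fused control flow, using `nn.Linear` as a stand-in mixer purely for illustration:

```python
import torch
import torch.nn as nn

# The block returns (mixer_out, residual) and the *next* block performs the add,
# which is equivalent to the usual pre-norm "x = x + mixer(norm(x))" layout.
dim = 16
mixer, norm = nn.Linear(dim, dim), nn.LayerNorm(dim)   # stand-ins, reused for brevity

x = torch.randn(2, 8, dim)
hidden_states, residual = x, None
for _ in range(3):   # three stacked "blocks" (weight sharing is irrelevant to the control flow)
    residual = hidden_states + residual if residual is not None else hidden_states
    hidden_states = mixer(norm(residual))
final = hidden_states + residual   # what the final norm (norm_f in MixerModel) would see
```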
mamba/build/lib/mamba_ssm/modules/mamba2.py ADDED
@@ -0,0 +1,383 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ try:
12
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
13
+ except ImportError:
14
+ causal_conv1d_fn, causal_conv1d_update = None, None
15
+
16
+ try:
17
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
18
+ except ImportError:
19
+ causal_conv1d_varlen_states = None
20
+
21
+ try:
22
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
23
+ except ImportError:
24
+ selective_state_update = None
25
+
26
+ from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
27
+
28
+ from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
29
+ from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter
30
+
31
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
32
+ from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
33
+
34
+ from huggingface_hub import PyTorchModelHubMixin
35
+
36
+
37
+ class Mamba2(nn.Module, PyTorchModelHubMixin):
38
+ def __init__(
39
+ self,
40
+ d_model,
41
+ d_state=128,
42
+ d_conv=4,
43
+ conv_init=None,
44
+ expand=2,
45
+ headdim=64,
46
+ d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
47
+ ngroups=1,
48
+ A_init_range=(1, 16),
49
+ D_has_hdim=False,
50
+ rmsnorm=True,
51
+ norm_before_gate=False,
52
+ dt_min=0.001,
53
+ dt_max=0.1,
54
+ dt_init_floor=1e-4,
55
+ dt_limit=(0.0, float("inf")),
56
+ bias=False,
57
+ conv_bias=True,
58
+ # Fused kernel and sharding options
59
+ chunk_size=256,
60
+ use_mem_eff_path=True,
61
+ layer_idx=None, # Absorb kwarg for general module
62
+ process_group=None,
63
+ sequence_parallel=True,
64
+ device=None,
65
+ dtype=None,
66
+ ):
67
+ factory_kwargs = {"device": device, "dtype": dtype}
68
+ super().__init__()
69
+ self.d_model = d_model
70
+ self.d_state = d_state
71
+ self.d_conv = d_conv
72
+ self.conv_init = conv_init
73
+ self.expand = expand
74
+ self.process_group = process_group
75
+ self.sequence_parallel = sequence_parallel
76
+ self.world_size = 1 if process_group is None else process_group.size()
77
+ self.local_rank = 0 if process_group is None else process_group.rank()
78
+ self.d_inner = (self.expand * self.d_model) // self.world_size
79
+ assert self.d_inner * self.world_size == self.expand * self.d_model
80
+ self.headdim = headdim
81
+ self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
82
+ assert ngroups % self.world_size == 0
83
+ self.ngroups = ngroups // self.world_size
84
+ assert self.d_ssm % self.headdim == 0
85
+ self.nheads = self.d_ssm // self.headdim
86
+ self.D_has_hdim = D_has_hdim
87
+ self.rmsnorm = rmsnorm
88
+ self.norm_before_gate = norm_before_gate
89
+ self.dt_limit = dt_limit
90
+ self.activation = "silu"
91
+ self.chunk_size = chunk_size
92
+ self.use_mem_eff_path = use_mem_eff_path
93
+ self.layer_idx = layer_idx
94
+
95
+ # Order: [z, x, B, C, dt]
96
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
97
+ if self.process_group is None:
98
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
99
+ else:
100
+ self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
101
+ process_group=self.process_group, sequence_parallel=self.sequence_parallel,
102
+ **factory_kwargs)
103
+
104
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
105
+ self.conv1d = nn.Conv1d(
106
+ in_channels=conv_dim,
107
+ out_channels=conv_dim,
108
+ bias=conv_bias,
109
+ kernel_size=d_conv,
110
+ groups=conv_dim,
111
+ padding=d_conv - 1,
112
+ **factory_kwargs,
113
+ )
114
+ if self.conv_init is not None:
115
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
116
+
117
+ self.act = nn.SiLU()
118
+
119
+ # Initialize log dt bias
120
+ dt = torch.exp(
121
+ torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
122
+ + math.log(dt_min)
123
+ )
124
+ dt = torch.clamp(dt, min=dt_init_floor)
125
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
126
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
127
+ self.dt_bias = nn.Parameter(inv_dt)
128
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
129
+ # name.endswith("bias") in param_grouping.py
130
+ self.dt_bias._no_weight_decay = True
131
+
132
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
133
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
134
+ A_log = torch.log(A).to(dtype=dtype)
135
+ self.A_log = nn.Parameter(A_log)
136
+ self.A_log._no_weight_decay = True
137
+
138
+ # D "skip" parameter
139
+ self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
140
+ self.D._no_weight_decay = True
141
+
142
+ if self.rmsnorm:
143
+ assert RMSNormGated is not None
144
+ self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
145
+ group_size=self.d_ssm // ngroups, **factory_kwargs)
146
+
147
+ if self.process_group is None:
148
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
149
+ else:
150
+ self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
151
+ process_group=self.process_group, sequence_parallel=self.sequence_parallel,
152
+ **factory_kwargs)
153
+
154
+ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
155
+ """
156
+ u: (batch, seqlen, hidden_dim) if seqlen=None.
157
+ If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
158
+ split u during sequence parallel, we split the batch * seqlen dimension
159
+ (in case batch is small).
160
+ Returns: same shape as u
161
+ """
162
+ seqlen_og = seqlen
163
+ if seqlen is None:
164
+ batch, seqlen, dim = u.shape
165
+ else:
166
+ batch_seqlen, dim = u.shape
167
+ batch = batch_seqlen // seqlen
168
+
169
+ conv_state, ssm_state = None, None
170
+ if inference_params is not None:
171
+ inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
172
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
173
+ if inference_params.seqlen_offset > 0:
174
+ # The states are updated inplace
175
+ out, _, _ = self.step(u, conv_state, ssm_state)
176
+ return out
177
+
178
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
179
+ if seqlen_og is not None:
180
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
181
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
182
+ A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
183
+ dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
184
+ if self.use_mem_eff_path and inference_params is None:
185
+ out = mamba_split_conv1d_scan_combined(
186
+ zxbcdt,
187
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
188
+ self.conv1d.bias,
189
+ self.dt_bias,
190
+ A,
191
+ D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
192
+ chunk_size=self.chunk_size,
193
+ seq_idx=seq_idx,
194
+ activation=self.activation,
195
+ rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
196
+ rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
197
+ outproj_weight=self.out_proj.weight,
198
+ outproj_bias=self.out_proj.bias,
199
+ headdim=None if self.D_has_hdim else self.headdim,
200
+ ngroups=self.ngroups,
201
+ norm_before_gate=self.norm_before_gate,
202
+ **dt_limit_kwargs,
203
+ )
204
+ if seqlen_og is not None:
205
+ out = rearrange(out, "b l d -> (b l) d")
206
+ if self.process_group is not None:
207
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
208
+ out = reduce_fn(out, self.process_group)
209
+ else:
210
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
211
+ z0, x0, z, xBC, dt = torch.split(
212
+ zxbcdt,
213
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
214
+ dim=-1
215
+ )
216
+ if conv_state is not None:
217
+ if cu_seqlens is None:
218
+ # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
219
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
220
+ xBC_t = rearrange(xBC, "b l d -> b d l")
221
+ conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
222
+ else:
223
+ assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
224
+ assert batch == 1, "varlen inference only supports batch dimension 1"
225
+ conv_varlen_states = causal_conv1d_varlen_states(
226
+ xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
227
+ )
228
+ conv_state.copy_(conv_varlen_states)
229
+ assert self.activation in ["silu", "swish"]
230
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
231
+ assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
232
+ xBC = self.act(
233
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):]
234
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
235
+ else:
236
+ xBC = causal_conv1d_fn(
237
+ xBC.transpose(1, 2),
238
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
239
+ bias=self.conv1d.bias,
240
+ activation=self.activation,
241
+ seq_idx=seq_idx,
242
+ ).transpose(1, 2)
243
+ x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
244
+ y = mamba_chunk_scan_combined(
245
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
246
+ dt,
247
+ A,
248
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
249
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
250
+ chunk_size=self.chunk_size,
251
+ D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
252
+ z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
253
+ dt_bias=self.dt_bias,
254
+ dt_softplus=True,
255
+ seq_idx=seq_idx,
256
+ cu_seqlens=cu_seqlens,
257
+ **dt_limit_kwargs,
258
+ return_final_states=ssm_state is not None,
259
+ return_varlen_states=cu_seqlens is not None and inference_params is not None,
260
+ )
261
+ if ssm_state is not None:
262
+ y, last_state, *rest = y
263
+ if cu_seqlens is None:
264
+ ssm_state.copy_(last_state)
265
+ else:
266
+ varlen_states = rest[0]
267
+ ssm_state.copy_(varlen_states)
268
+ y = rearrange(y, "b l h p -> b l (h p)")
269
+ if self.rmsnorm:
270
+ y = self.norm(y, z)
271
+ if d_mlp > 0:
272
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
273
+ if seqlen_og is not None:
274
+ y = rearrange(y, "b l d -> (b l) d")
275
+ out = self.out_proj(y)
276
+ return out
277
+
278
+ def step(self, hidden_states, conv_state, ssm_state):
279
+ dtype = hidden_states.dtype
280
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
281
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
282
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
283
+ z0, x0, z, xBC, dt = torch.split(
284
+ zxbcdt,
285
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
286
+ dim=-1
287
+ )
288
+
289
+ # Conv step
290
+ if causal_conv1d_update is None:
291
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
292
+ conv_state[:, :, -1] = xBC
293
+ xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
294
+ if self.conv1d.bias is not None:
295
+ xBC = xBC + self.conv1d.bias
296
+ xBC = self.act(xBC).to(dtype=dtype)
297
+ else:
298
+ xBC = causal_conv1d_update(
299
+ xBC,
300
+ conv_state,
301
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
302
+ self.conv1d.bias,
303
+ self.activation,
304
+ )
305
+
306
+ x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
307
+ A = -torch.exp(self.A_log.float()) # (nheads,)
308
+
309
+ # SSM step
310
+ if selective_state_update is None:
311
+ assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
312
+ # Discretize A and B
313
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
314
+ dA = torch.exp(dt * A) # (batch, nheads)
315
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
316
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
317
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
318
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
319
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
320
+ y = rearrange(y, "b h p -> b (h p)")
321
+ if not self.rmsnorm:
322
+ y = y * self.act(z) # (B D)
323
+ else:
324
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
325
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
326
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
327
+ D = repeat(self.D, "h -> h p", p=self.headdim)
328
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
329
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
330
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
331
+ if not self.rmsnorm:
332
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
333
+ y = selective_state_update(
334
+ ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
335
+ dt_bias=dt_bias, dt_softplus=True
336
+ )
337
+ y = rearrange(y, "b h p -> b (h p)")
338
+ if self.rmsnorm:
339
+ y = self.norm(y, z)
340
+ if d_mlp > 0:
341
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
342
+ out = self.out_proj(y)
343
+ return out.unsqueeze(1), conv_state, ssm_state
344
+
345
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
346
+ device = self.out_proj.weight.device
347
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
348
+ conv_state = torch.zeros(
349
+ batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
350
+ ).transpose(1, 2)
351
+ ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
352
+ ssm_state = torch.zeros(
353
+ batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
354
+ )
355
+ return conv_state, ssm_state
356
+
357
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
358
+ assert self.layer_idx is not None
359
+ if self.layer_idx not in inference_params.key_value_memory_dict:
360
+ batch_shape = (batch_size,)
361
+ conv_state = torch.zeros(
362
+ batch_size,
363
+ self.d_conv,
364
+ self.conv1d.weight.shape[0],
365
+ device=self.conv1d.weight.device,
366
+ dtype=self.conv1d.weight.dtype,
367
+ ).transpose(1, 2)
368
+ ssm_state = torch.zeros(
369
+ batch_size,
370
+ self.nheads,
371
+ self.headdim,
372
+ self.d_state,
373
+ device=self.in_proj.weight.device,
374
+ dtype=self.in_proj.weight.dtype,
375
+ )
376
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
377
+ else:
378
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
379
+ # TODO: What if batch size changes between generation, and we reuse the same states?
380
+ if initialize_states:
381
+ conv_state.zero_()
382
+ ssm_state.zero_()
383
+ return conv_state, ssm_state
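A minimal usage sketch for the `Mamba2` module, assuming a CUDA device with the `causal_conv1d` and Triton SSD kernels available; the sizes and dtype below are illustrative:

```python
import torch
from mamba_ssm.modules.mamba2 import Mamba2

# Assumes a CUDA device plus the causal-conv1d and Triton SSD kernels are installed.
batch, seqlen, d_model = 2, 512, 256
layer = Mamba2(d_model, d_state=128, headdim=64, device="cuda", dtype=torch.bfloat16)

x = torch.randn(batch, seqlen, d_model, device="cuda", dtype=torch.bfloat16)
y = layer(x)                 # output has the same shape as the input
assert y.shape == x.shape
```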
mamba/build/lib/mamba_ssm/modules/mamba2_simple.py ADDED
@@ -0,0 +1,200 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ try:
11
+ from causal_conv1d import causal_conv1d_fn
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+
15
+ try:
16
+ from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
17
+ except ImportError:
18
+ RMSNormGated, LayerNorm = None, None
19
+
20
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
21
+ from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
22
+
23
+
24
+ class Mamba2Simple(nn.Module):
25
+ def __init__(
26
+ self,
27
+ d_model,
28
+ d_state=64,
29
+ d_conv=4,
30
+ conv_init=None,
31
+ expand=2,
32
+ headdim=128,
33
+ ngroups=1,
34
+ A_init_range=(1, 16),
35
+ dt_min=0.001,
36
+ dt_max=0.1,
37
+ dt_init_floor=1e-4,
38
+ dt_limit=(0.0, float("inf")),
39
+ learnable_init_states=False,
40
+ activation="swish",
41
+ bias=False,
42
+ conv_bias=True,
43
+ # Fused kernel and sharding options
44
+ chunk_size=256,
45
+ use_mem_eff_path=True,
46
+ layer_idx=None, # Absorb kwarg for general module
47
+ device=None,
48
+ dtype=None,
49
+ ):
50
+ factory_kwargs = {"device": device, "dtype": dtype}
51
+ super().__init__()
52
+ self.d_model = d_model
53
+ self.d_state = d_state
54
+ self.d_conv = d_conv
55
+ self.conv_init = conv_init
56
+ self.expand = expand
57
+ self.d_inner = self.expand * self.d_model
58
+ self.headdim = headdim
59
+ self.ngroups = ngroups
60
+ assert self.d_inner % self.headdim == 0
61
+ self.nheads = self.d_inner // self.headdim
62
+ self.dt_limit = dt_limit
63
+ self.learnable_init_states = learnable_init_states
64
+ self.activation = activation
65
+ self.chunk_size = chunk_size
66
+ self.use_mem_eff_path = use_mem_eff_path
67
+ self.layer_idx = layer_idx
68
+
69
+ # Order: [z, x, B, C, dt]
70
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
71
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
72
+
73
+ conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
74
+ self.conv1d = nn.Conv1d(
75
+ in_channels=conv_dim,
76
+ out_channels=conv_dim,
77
+ bias=conv_bias,
78
+ kernel_size=d_conv,
79
+ groups=conv_dim,
80
+ padding=d_conv - 1,
81
+ **factory_kwargs,
82
+ )
83
+ if self.conv_init is not None:
84
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
85
+ # self.conv1d.weight._no_weight_decay = True
86
+
87
+ if self.learnable_init_states:
88
+ self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
89
+ self.init_states._no_weight_decay = True
90
+
91
+ self.act = nn.SiLU()
92
+
93
+ # Initialize log dt bias
94
+ dt = torch.exp(
95
+ torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
96
+ + math.log(dt_min)
97
+ )
98
+ dt = torch.clamp(dt, min=dt_init_floor)
99
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
100
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
101
+ self.dt_bias = nn.Parameter(inv_dt)
102
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
103
+ # name.endswith("bias") in param_grouping.py
104
+ self.dt_bias._no_weight_decay = True
105
+
106
+ # A parameter
107
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
108
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
109
+ A_log = torch.log(A).to(dtype=dtype)
110
+ self.A_log = nn.Parameter(A_log)
111
+ # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
112
+ self.A_log._no_weight_decay = True
113
+
114
+ # D "skip" parameter
115
+ self.D = nn.Parameter(torch.ones(self.nheads, device=device))
116
+ self.D._no_weight_decay = True
117
+
118
+ # Extra normalization layer right before output projection
119
+ assert RMSNormGated is not None
120
+ self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)
121
+
122
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
123
+
124
+ def forward(self, u, seq_idx=None):
125
+ """
126
+ u: (B, L, D)
127
+ Returns: same shape as u
128
+ """
129
+ batch, seqlen, dim = u.shape
130
+
131
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
132
+ A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
133
+ initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
134
+ dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
135
+
136
+ if self.use_mem_eff_path:
137
+ # Fully fused path
138
+ out = mamba_split_conv1d_scan_combined(
139
+ zxbcdt,
140
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
141
+ self.conv1d.bias,
142
+ self.dt_bias,
143
+ A,
144
+ D=self.D,
145
+ chunk_size=self.chunk_size,
146
+ seq_idx=seq_idx,
147
+ activation=self.activation,
148
+ rmsnorm_weight=self.norm.weight,
149
+ rmsnorm_eps=self.norm.eps,
150
+ outproj_weight=self.out_proj.weight,
151
+ outproj_bias=self.out_proj.bias,
152
+ headdim=self.headdim,
153
+ ngroups=self.ngroups,
154
+ norm_before_gate=False,
155
+ initial_states=initial_states,
156
+ **dt_limit_kwargs,
157
+ )
158
+ else:
159
+ z, xBC, dt = torch.split(
160
+ zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
161
+ )
162
+ dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
163
+ assert self.activation in ["silu", "swish"]
164
+
165
+ # 1D Convolution
166
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
167
+ xBC = self.act(
168
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
169
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
170
+ xBC = xBC[:, :seqlen, :]
171
+ else:
172
+ xBC = causal_conv1d_fn(
173
+ x=xBC.transpose(1, 2),
174
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
175
+ bias=self.conv1d.bias,
176
+ activation=self.activation,
177
+ ).transpose(1, 2)
178
+
179
+ # Split into 3 main branches: X, B, C
180
+ # These correspond to V, K, Q respectively in the SSM/attention duality
181
+ x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
182
+ y = mamba_chunk_scan_combined(
183
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
184
+ dt,
185
+ A,
186
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
187
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
188
+ chunk_size=self.chunk_size,
189
+ D=self.D,
190
+ z=None,
191
+ seq_idx=seq_idx,
192
+ initial_states=initial_states,
193
+ **dt_limit_kwargs,
194
+ )
195
+ y = rearrange(y, "b l h p -> b l (h p)")
196
+
197
+ # Multiply "gate" branch and apply extra normalization layer
198
+ y = self.norm(y, z)
199
+ out = self.out_proj(y)
200
+ return out
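Both Mamba2 modules initialize `dt_bias` via the inverse-softplus formula noted in the comments above; a quick numeric check of that identity:

```python
import torch
import torch.nn.functional as F

# inv_dt = dt + log(-expm1(-dt)) is the inverse of softplus, so softplus(inv_dt) recovers dt.
dt = torch.tensor([0.001, 0.01, 0.1])
inv_dt = dt + torch.log(-torch.expm1(-dt))
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)
```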
mamba/build/lib/mamba_ssm/modules/mamba_simple.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ ):
51
+ factory_kwargs = {"device": device, "dtype": dtype}
52
+ super().__init__()
53
+ self.d_model = d_model
54
+ self.d_state = d_state
55
+ self.d_conv = d_conv
56
+ self.expand = expand
57
+ self.d_inner = int(self.expand * self.d_model)
58
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
59
+ self.use_fast_path = use_fast_path
60
+ self.layer_idx = layer_idx
61
+
62
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
63
+
64
+ self.conv1d = nn.Conv1d(
65
+ in_channels=self.d_inner,
66
+ out_channels=self.d_inner,
67
+ bias=conv_bias,
68
+ kernel_size=d_conv,
69
+ groups=self.d_inner,
70
+ padding=d_conv - 1,
71
+ **factory_kwargs,
72
+ )
73
+
74
+ self.activation = "silu"
75
+ self.act = nn.SiLU()
76
+
77
+ self.x_proj = nn.Linear(
78
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
79
+ )
80
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
81
+
82
+ # Initialize special dt projection to preserve variance at initialization
83
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
84
+ if dt_init == "constant":
85
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
86
+ elif dt_init == "random":
87
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
88
+ else:
89
+ raise NotImplementedError
90
+
91
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
92
+ dt = torch.exp(
93
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
94
+ + math.log(dt_min)
95
+ ).clamp(min=dt_init_floor)
96
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
97
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
98
+ with torch.no_grad():
99
+ self.dt_proj.bias.copy_(inv_dt)
100
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
101
+ self.dt_proj.bias._no_reinit = True
102
+
103
+ # S4D real initialization
104
+ A = repeat(
105
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
106
+ "n -> d n",
107
+ d=self.d_inner,
108
+ ).contiguous()
109
+ A_log = torch.log(A) # Keep A_log in fp32
110
+ self.A_log = nn.Parameter(A_log)
111
+ self.A_log._no_weight_decay = True
112
+
113
+ # D "skip" parameter
114
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
115
+ self.D._no_weight_decay = True
116
+
117
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
118
+
119
+ def forward(self, hidden_states, inference_params=None):
120
+ """
121
+ hidden_states: (B, L, D)
122
+ Returns: same shape as hidden_states
123
+ """
124
+ batch, seqlen, dim = hidden_states.shape
125
+
126
+ conv_state, ssm_state = None, None
127
+ if inference_params is not None:
128
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
129
+ if inference_params.seqlen_offset > 0:
130
+ # The states are updated inplace
131
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
132
+ return out
133
+
134
+ # We do matmul and transpose BLH -> HBL at the same time
135
+ xz = rearrange(
136
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
137
+ "d (b l) -> b d l",
138
+ l=seqlen,
139
+ )
140
+ if self.in_proj.bias is not None:
141
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
142
+
143
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
144
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
145
+ if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
146
+ out = mamba_inner_fn(
147
+ xz,
148
+ self.conv1d.weight,
149
+ self.conv1d.bias,
150
+ self.x_proj.weight,
151
+ self.dt_proj.weight,
152
+ self.out_proj.weight,
153
+ self.out_proj.bias,
154
+ A,
155
+ None, # input-dependent B
156
+ None, # input-dependent C
157
+ self.D.float(),
158
+ delta_bias=self.dt_proj.bias.float(),
159
+ delta_softplus=True,
160
+ )
161
+ else:
162
+ x, z = xz.chunk(2, dim=1)
163
+ # Compute short convolution
164
+ if conv_state is not None:
165
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
166
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
167
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
168
+ if causal_conv1d_fn is None:
169
+ x = self.act(self.conv1d(x)[..., :seqlen])
170
+ else:
171
+ assert self.activation in ["silu", "swish"]
172
+ x = causal_conv1d_fn(
173
+ x=x,
174
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
175
+ bias=self.conv1d.bias,
176
+ activation=self.activation,
177
+ )
178
+
179
+ # We're careful here about the layout, to avoid extra transposes.
180
+ # We want dt to have d as the slowest moving dimension
181
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
182
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
183
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
184
+ dt = self.dt_proj.weight @ dt.t()
185
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
186
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
187
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
188
+ assert self.activation in ["silu", "swish"]
189
+ y = selective_scan_fn(
190
+ x,
191
+ dt,
192
+ A,
193
+ B,
194
+ C,
195
+ self.D.float(),
196
+ z=z,
197
+ delta_bias=self.dt_proj.bias.float(),
198
+ delta_softplus=True,
199
+ return_last_state=ssm_state is not None,
200
+ )
201
+ if ssm_state is not None:
202
+ y, last_state = y
203
+ ssm_state.copy_(last_state)
204
+ y = rearrange(y, "b d l -> b l d")
205
+ out = self.out_proj(y)
206
+ return out
207
+
208
+ def step(self, hidden_states, conv_state, ssm_state):
209
+ dtype = hidden_states.dtype
210
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
211
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
212
+ x, z = xz.chunk(2, dim=-1) # (B D)
213
+
214
+ # Conv step
215
+ if causal_conv1d_update is None:
216
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
217
+ conv_state[:, :, -1] = x
218
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
219
+ if self.conv1d.bias is not None:
220
+ x = x + self.conv1d.bias
221
+ x = self.act(x).to(dtype=dtype)
222
+ else:
223
+ x = causal_conv1d_update(
224
+ x,
225
+ conv_state,
226
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
227
+ self.conv1d.bias,
228
+ self.activation,
229
+ )
230
+
231
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
232
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
233
+ # Don't add dt_bias here
234
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
235
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
236
+
237
+ # SSM step
238
+ if selective_state_update is None:
239
+ # Discretize A and B
240
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
241
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
242
+ dB = torch.einsum("bd,bn->bdn", dt, B)
243
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
244
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
245
+ y = y + self.D.to(dtype) * x
246
+ y = y * self.act(z) # (B D)
247
+ else:
248
+ y = selective_state_update(
249
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
250
+ )
251
+
252
+ out = self.out_proj(y)
253
+ return out.unsqueeze(1), conv_state, ssm_state
254
+
255
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
256
+ device = self.out_proj.weight.device
257
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
258
+ conv_state = torch.zeros(
259
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
260
+ )
261
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
262
+ # ssm_dtype = torch.float32
263
+ ssm_state = torch.zeros(
264
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
265
+ )
266
+ return conv_state, ssm_state
267
+
268
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
269
+ assert self.layer_idx is not None
270
+ if self.layer_idx not in inference_params.key_value_memory_dict:
271
+ batch_shape = (batch_size,)
272
+ conv_state = torch.zeros(
273
+ batch_size,
274
+ self.d_model * self.expand,
275
+ self.d_conv,
276
+ device=self.conv1d.weight.device,
277
+ dtype=self.conv1d.weight.dtype,
278
+ )
279
+ ssm_state = torch.zeros(
280
+ batch_size,
281
+ self.d_model * self.expand,
282
+ self.d_state,
283
+ device=self.dt_proj.weight.device,
284
+ dtype=self.dt_proj.weight.dtype,
285
+ # dtype=torch.float32,
286
+ )
287
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
288
+ else:
289
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
290
+ # TODO: What if batch size changes between generation, and we reuse the same states?
291
+ if initialize_states:
292
+ conv_state.zero_()
293
+ ssm_state.zero_()
294
+ return conv_state, ssm_state
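The no-kernel branch of `Mamba.step()` above spells out the discretized recurrence for a single token; a sequential reference sketch of the same math over a full sequence (omitting the softplus on dt and the z gating), for illustration only:

```python
import torch

# Per timestep: h = exp(dt * A) * h + (dt * B) * x ;  y = C . h + D * x
def selective_scan_reference(x, dt, A, B, C, D):
    # x, dt: (batch, d_inner, L);  A: (d_inner, d_state);  B, C: (batch, d_state, L);  D: (d_inner,)
    batch, d_inner, L = x.shape
    d_state = A.shape[1]
    h = torch.zeros(batch, d_inner, d_state)
    ys = []
    for t in range(L):
        dA = torch.exp(dt[:, :, t, None] * A)                       # (batch, d_inner, d_state)
        dBx = dt[:, :, t, None] * B[:, None, :, t] * x[:, :, t, None]
        h = h * dA + dBx
        y = (h * C[:, None, :, t]).sum(-1) + D * x[:, :, t]         # (batch, d_inner)
        ys.append(y)
    return torch.stack(ys, dim=-1)                                   # (batch, d_inner, L)

y = selective_scan_reference(
    torch.randn(2, 8, 16), torch.rand(2, 8, 16), -torch.rand(8, 4),
    torch.randn(2, 4, 16), torch.randn(2, 4, 16), torch.ones(8),
)
```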
mamba/build/lib/mamba_ssm/modules/mha.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ try:
11
+ from flash_attn import flash_attn_with_kvcache
12
+ except ImportError:
13
+ flash_attn_with_kvcache = None
14
+
15
+ try:
16
+ from flash_attn.layers.rotary import RotaryEmbedding
17
+ except ImportError:
18
+ RotaryEmbedding = None
19
+
20
+ try:
21
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
+ except ImportError:
23
+ causal_conv1d_fn, causal_conv1d_update = None, None
24
+
25
+
26
+ def _update_kv_cache(kv, inference_params, layer_idx):
27
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
28
+ # Pre-allocate memory for key-values for inference.
29
+ num_heads, head_dim = kv.shape[-2:]
30
+ assert layer_idx in inference_params.key_value_memory_dict
31
+ kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
32
+ # Adjust key and value for inference
33
+ batch_start = inference_params.batch_size_offset
34
+ batch_end = batch_start + kv.shape[0]
35
+ sequence_start = inference_params.seqlen_offset
36
+ sequence_end = sequence_start + kv.shape[1]
37
+ assert batch_end <= kv_cache.shape[0]
38
+ assert sequence_end <= kv_cache.shape[1]
39
+ assert kv_cache is not None
40
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
41
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
42
+
43
+
44
+ class MHA(nn.Module):
45
+ """Multi-head self-attention and cross-attention"""
46
+
47
+ def __init__(
48
+ self,
49
+ embed_dim,
50
+ num_heads,
51
+ num_heads_kv=None,
52
+ head_dim=None, # If None, use embed_dim // num_heads
53
+ mlp_dim=0,
54
+ qkv_proj_bias=True,
55
+ out_proj_bias=True,
56
+ softmax_scale=None,
57
+ causal=False,
58
+ layer_idx=None,
59
+ d_conv=0,
60
+ rotary_emb_dim=0,
61
+ rotary_emb_base=10000.0,
62
+ rotary_emb_interleaved=False,
63
+ device=None,
64
+ dtype=None,
65
+ ) -> None:
66
+ """
67
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
68
+ return_residual: whether to return the input x along with the output. This is for
69
+ performance reasons: for post-norm architecture, returning the input allows us
70
+ to fuse the backward of nn.Linear with the residual connection.
71
+ """
72
+ factory_kwargs = {"device": device, "dtype": dtype}
73
+ super().__init__()
74
+ self.embed_dim = embed_dim
75
+ self.layer_idx = layer_idx
76
+ self.d_conv = d_conv
77
+ self.rotary_emb_dim = rotary_emb_dim
78
+ self.softmax_scale = softmax_scale
79
+ self.causal = causal
80
+
81
+ self.num_heads = num_heads
82
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
83
+ assert (
84
+ self.num_heads % self.num_heads_kv == 0
85
+ ), "num_heads must be divisible by num_heads_kv"
86
+ if head_dim is None:
87
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
88
+ self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
89
+ self.mlp_dim = math.ceil(mlp_dim / 256) * 256
90
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
91
+ out_dim = self.head_dim * self.num_heads
92
+
93
+ if self.rotary_emb_dim > 0:
94
+ assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
95
+ self.rotary_emb = RotaryEmbedding(
96
+ self.rotary_emb_dim,
97
+ base=rotary_emb_base,
98
+ interleaved=rotary_emb_interleaved,
99
+ device=device,
100
+ )
101
+
102
+ self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
103
+ if self.d_conv > 0:
104
+ self.conv1d = nn.Conv1d(
105
+ qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
106
+ **factory_kwargs
107
+ )
108
+ self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
109
+
110
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
111
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
112
+ device = self.out_proj.weight.device
113
+ if self.d_conv > 0:
114
+ conv_state = torch.zeros(
115
+ batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
116
+ )
117
+ else:
118
+ conv_state = None
119
+ kv_cache = torch.empty(
120
+ batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
121
+ )
122
+ return kv_cache, conv_state
123
+
124
+ def _update_kv_cache(self, kv, inference_params):
125
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
126
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
127
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
128
+
129
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
130
+ """
131
+ Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache, and apply attention.
132
+ q: (batch_size, seqlen_q, nheads, head_dim)
133
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
134
+ """
135
+ assert inference_params is not None and inference_params.seqlen_offset > 0
136
+ if self.rotary_emb_dim > 0:
137
+ self.rotary_emb._update_cos_sin_cache(
138
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
139
+ )
140
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
141
+ else:
142
+ rotary_cos, rotary_sin = None, None
143
+ batch = q.shape[0]
144
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
145
+ kv_cache = kv_cache[:batch]
146
+ cache_seqlens = (
147
+ inference_params.lengths_per_sample[:batch]
148
+ if inference_params.lengths_per_sample is not None
149
+ else inference_params.seqlen_offset
150
+ )
151
+ assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
152
+ context = flash_attn_with_kvcache(
153
+ q,
154
+ kv_cache[:, :, 0],
155
+ kv_cache[:, :, 1],
156
+ kv[:, :, 0],
157
+ kv[:, :, 1],
158
+ rotary_cos=rotary_cos,
159
+ rotary_sin=rotary_sin,
160
+ cache_seqlens=cache_seqlens,
161
+ softmax_scale=self.softmax_scale,
162
+ causal=self.causal,
163
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
164
+ )
165
+ return context
166
+
167
+ def _update_kvcache_attention(self, q, kv, inference_params):
168
+ """Write kv to inference_params, then do attention"""
169
+ if (
170
+ inference_params.seqlen_offset == 0
171
+ or flash_attn_with_kvcache is None
172
+ ):
173
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
174
+ kv = self._update_kv_cache(kv, inference_params)
175
+ k, v = kv.unbind(dim=-3)
176
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
177
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
178
+ return F.scaled_dot_product_attention(
179
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
180
+ ).transpose(1, 2)
181
+ else:
182
+ batch = q.shape[0]
183
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
184
+ kv_cache = kv_cache[:batch]
185
+ cache_seqlens = (
186
+ inference_params.lengths_per_sample[:batch]
187
+ if inference_params.lengths_per_sample is not None
188
+ else inference_params.seqlen_offset
189
+ )
190
+ return flash_attn_with_kvcache(
191
+ q,
192
+ kv_cache[:, :, 0],
193
+ kv_cache[:, :, 1],
194
+ kv[:, :, 0],
195
+ kv[:, :, 1],
196
+ cache_seqlens=cache_seqlens,
197
+ softmax_scale=self.softmax_scale,
198
+ causal=self.causal,
199
+ )
200
+
201
+ def forward(self, x, inference_params=None):
202
+ """
203
+ Arguments:
204
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
205
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
206
+ is the sum of the sequence lengths in the batch.
207
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
208
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
209
+ """
210
+ if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
211
+ inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
212
+ x.shape[0], inference_params.max_seqlen, dtype=x.dtype
213
+ )
214
+ seqlen_offset = (
215
+ 0
216
+ if inference_params is None
217
+ else (
218
+ inference_params.lengths_per_sample
219
+ if inference_params.lengths_per_sample is not None
220
+ else inference_params.seqlen_offset
221
+ )
222
+ )
223
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
224
+ qkv = self.in_proj(x)
225
+ if self.mlp_dim > 0:
226
+ qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
227
+ x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
228
+ x_mlp = x_mlp_up * F.silu(x_mlp_gate)
229
+ if self.d_conv > 0:
230
+ # The inference code for conv1d is pretty messy, should clean it up
231
+ if (inference_params is None or inference_params.seqlen_offset == 0):
232
+ if causal_conv1d_fn is None:
233
+ qkv = rearrange(
234
+ self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
235
+ ).contiguous()
236
+ else:
237
+ qkv = causal_conv1d_fn(
238
+ qkv.transpose(1, 2),
239
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
240
+ self.conv1d.bias
241
+ ).transpose(1, 2)
242
+ if inference_params is not None:
243
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
244
+ # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
245
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
246
+ qkv_t = rearrange(qkv, "b l d -> b d l")
247
+ conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
248
+ else:
249
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
250
+ assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
251
+ qkv = qkv.squeeze(1)
252
+ # Conv step
253
+ if causal_conv1d_update is None:
254
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
255
+ conv_state[:, :, -1] = qkv
256
+ qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
257
+ if self.conv1d.bias is not None:
258
+ qkv = qkv + self.conv1d.bias
259
+ else:
260
+ qkv = causal_conv1d_update(
261
+ qkv,
262
+ conv_state,
263
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
264
+ self.conv1d.bias
265
+ )
266
+ qkv = qkv.unsqueeze(1)
267
+ q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
268
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
269
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
270
+ if (
271
+ inference_params is None
272
+ or inference_params.seqlen_offset == 0
273
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
274
+ ):
275
+ if self.rotary_emb_dim > 0:
276
+ q, kv = self.rotary_emb(
277
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
278
+ )
279
+ if inference_params is None:
280
+ k, v = kv.unbind(dim=-3)
281
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
282
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
283
+ context = F.scaled_dot_product_attention(
284
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
285
+ ).transpose(1, 2)
286
+ else:
287
+ context = self._update_kvcache_attention(q, kv, inference_params)
288
+ else:
289
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
290
+ context = rearrange(context, "... h d -> ... (h d)")
291
+ if self.mlp_dim > 0:
292
+ context = torch.cat([context, x_mlp], dim=-1)
293
+ out = self.out_proj(context)
294
+ return out
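A quick way to exercise the MHA block above without any of the optional kernels (no flash-attn, no rotary embedding, no depthwise conv) is a plain forward pass, which falls back to `F.scaled_dot_product_attention`. The following is a minimal sketch with arbitrary sizes, assuming the package is importable as `mamba_ssm.modules.mha`.

# Minimal smoke test of MHA; with flash-attn absent it uses PyTorch SDPA.
import torch
from mamba_ssm.modules.mha import MHA

torch.manual_seed(0)
mha = MHA(embed_dim=256, num_heads=8, num_heads_kv=2, causal=True, layer_idx=0)
x = torch.randn(2, 16, 256)   # (batch, seqlen, embed_dim)
y = mha(x)                    # no inference_params: ordinary training-style forward
print(y.shape)                # torch.Size([2, 16, 256])

With num_heads_kv=2 the projection packs 8 query heads and 2 shared key/value heads (GQA); the keys and values are repeat-interleaved before SDPA, exactly as in the inference_params-free branch of forward above.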
mamba/build/lib/mamba_ssm/modules/mlp.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class GatedMLP(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_features,
10
+ hidden_features=None,
11
+ out_features=None,
12
+ activation=F.silu,
13
+ bias=False,
14
+ multiple_of=128,
15
+ device=None,
16
+ dtype=None,
17
+ ):
18
+ factory_kwargs = {"device": device, "dtype": dtype}
19
+ super().__init__()
20
+ out_features = out_features if out_features is not None else in_features
21
+ hidden_features = (
22
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
23
+ )
24
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
25
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
26
+ self.activation = activation
27
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
28
+
29
+ def forward(self, x):
30
+ y = self.fc1(x)
31
+ y, gate = y.chunk(2, dim=-1)
32
+ y = y * self.activation(gate)
33
+ y = self.fc2(y)
34
+ return y
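GatedMLP above is a SwiGLU-style feed-forward: `fc1` produces the value and the gate in a single projection, the two are multiplied after the activation, and `fc2` projects back. A hedged usage sketch with arbitrary sizes:

# Illustrative use of GatedMLP. With in_features=768 the default hidden size is
# int(8 * 768 / 3) = 2048, already a multiple of 128, so fc1 maps 768 -> 4096.
import torch
from mamba_ssm.modules.mlp import GatedMLP

mlp = GatedMLP(in_features=768)
x = torch.randn(4, 10, 768)
print(mlp(x).shape)           # torch.Size([4, 10, 768])
print(mlp.fc1.out_features)   # 4096 (= 2 * hidden_features)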
mamba/build/lib/mamba_ssm/modules/ssd_minimal.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) 2024, Albert Gu and Tri Dao.
2
+ """Minimal implementation of SSD.
3
+
4
+ This is the same as Listing 1 from the paper.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
12
+
13
+
14
+ def segsum_unstable(x):
15
+ """Naive segment sum calculation."""
16
+ T = x.size(-1)
17
+ x_cumsum = torch.cumsum(x, dim=-1)
18
+ x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
19
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
20
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
21
+ return x_segsum
22
+
23
+ def segsum(x):
24
+ """More stable segment sum calculation."""
25
+ T = x.size(-1)
26
+ x = repeat(x, "... d -> ... d e", e=T)
27
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
28
+ x = x.masked_fill(~mask, 0)
29
+ x_segsum = torch.cumsum(x, dim=-2)
30
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
31
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
32
+ return x_segsum
33
+
34
+ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
35
+ """
36
+ Arguments:
37
+ X: (batch, length, n_heads, d_head)
38
+ A: (batch, length, n_heads)
39
+ B: (batch, length, n_heads, d_state)
40
+ C: (batch, length, n_heads, d_state)
41
+ Return:
42
+ Y: (batch, length, n_heads, d_head)
43
+ """
44
+ assert X.dtype == A.dtype == B.dtype == C.dtype
45
+ assert X.shape[1] % block_len == 0
46
+
47
+ # Rearrange into blocks/chunks
48
+ X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
49
+
50
+ A = rearrange(A, "b c l h -> b h c l")
51
+ A_cumsum = torch.cumsum(A, dim=-1)
52
+
53
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
54
+ L = torch.exp(segsum(A))
55
+ Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
56
+
57
+ # 2. Compute the state for each intra-chunk
58
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
59
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
60
+ states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
61
+
62
+ # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
63
+ # (middle term of factorization of off-diag blocks; A terms)
64
+ if initial_states is None:
65
+ initial_states = torch.zeros_like(states[:, :1])
66
+ states = torch.cat([initial_states, states], dim=1)
67
+ decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
68
+ new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
69
+ states, final_state = new_states[:, :-1], new_states[:, -1]
70
+
71
+ # 4. Compute state -> output conversion per chunk
72
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
73
+ state_decay_out = torch.exp(A_cumsum)
74
+ Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
75
+
76
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
77
+ Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
78
+ return Y, final_state
79
+
80
+
81
+ # Simple test
82
+ def test_correctness():
83
+ torch.manual_seed(42)
84
+
85
+ ## Dimensions
86
+ # Denoted (B, T, Q, D, P) in the paper
87
+ batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
88
+ nheads = dim // headdim # (H) in the paper
89
+ ngroups = 1 # (G) in the paper
90
+ dstate = 64 # (N) in the paper
91
+ dtype = torch.float32
92
+ device = "cuda"
93
+
94
+ x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
95
+ dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_()
96
+ A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_()
97
+ B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
98
+ C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
99
+ D = torch.randn(nheads, dtype=dtype, device=device)
100
+
101
+ # Comparing fused version and minimal version
102
+ y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
103
+ y_min, _ = ssd_minimal_discrete(x*dt.unsqueeze(-1), A*dt, B, C, chunk_size)
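`test_correctness` above ends right after computing the fused output `y` and the reference output `y_min`; a natural closing check (not part of the file, added here only as a sketch) would compare the two numerically, e.g.:

# Hypothetical final lines for test_correctness(): the fused Triton kernel and the
# minimal reference should agree up to accumulation-order differences.
assert y.shape == y_min.shape
assert torch.allclose(y, y_min, rtol=1e-3, atol=1e-3)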
mamba/build/lib/mamba_ssm/ops/__init__.py ADDED
File without changes
mamba/build/lib/mamba_ssm/ops/selective_scan_interface.py ADDED
@@ -0,0 +1,357 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.cuda.amp import custom_bwd, custom_fwd
6
+
7
+ from einops import rearrange, repeat
8
+
9
+ try:
10
+ from causal_conv1d import causal_conv1d_fn
11
+ import causal_conv1d_cuda
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+ causal_conv1d_cuda = None
15
+
16
+ import selective_scan_cuda
17
+
18
+
19
+ class SelectiveScanFn(torch.autograd.Function):
20
+
21
+ @staticmethod
22
+ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
23
+ return_last_state=False):
24
+ if u.stride(-1) != 1:
25
+ u = u.contiguous()
26
+ if delta.stride(-1) != 1:
27
+ delta = delta.contiguous()
28
+ if D is not None:
29
+ D = D.contiguous()
30
+ if B.stride(-1) != 1:
31
+ B = B.contiguous()
32
+ if C.stride(-1) != 1:
33
+ C = C.contiguous()
34
+ if z is not None and z.stride(-1) != 1:
35
+ z = z.contiguous()
36
+ if B.dim() == 3:
37
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
38
+ ctx.squeeze_B = True
39
+ if C.dim() == 3:
40
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
41
+ ctx.squeeze_C = True
42
+ out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
43
+ ctx.delta_softplus = delta_softplus
44
+ ctx.has_z = z is not None
45
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
46
+ if not ctx.has_z:
47
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
48
+ return out if not return_last_state else (out, last_state)
49
+ else:
50
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
51
+ out_z = rest[0]
52
+ return out_z if not return_last_state else (out_z, last_state)
53
+
54
+ @staticmethod
55
+ def backward(ctx, dout, *args):
56
+ if not ctx.has_z:
57
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
58
+ z = None
59
+ out = None
60
+ else:
61
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
62
+ if dout.stride(-1) != 1:
63
+ dout = dout.contiguous()
64
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
65
+ # backward of selective_scan_cuda with the backward of chunk).
66
+ # Here we just pass in None and dz will be allocated in the C++ code.
67
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
68
+ u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
69
+ False # option to recompute out_z, not used here
70
+ )
71
+ dz = rest[0] if ctx.has_z else None
72
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
73
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
74
+ return (du, ddelta, dA, dB, dC,
75
+ dD if D is not None else None,
76
+ dz,
77
+ ddelta_bias if delta_bias is not None else None,
78
+ None,
79
+ None)
80
+
81
+
82
+ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
83
+ return_last_state=False):
84
+ """if return_last_state is True, returns (out, last_state)
85
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
86
+ not considered in the backward pass.
87
+ """
88
+ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
89
+
90
+
91
+ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
92
+ return_last_state=False):
93
+ """
94
+ u: r(B D L)
95
+ delta: r(B D L)
96
+ A: c(D N) or r(D N)
97
+ B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
98
+ C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
99
+ D: r(D)
100
+ z: r(B D L)
101
+ delta_bias: r(D), fp32
102
+
103
+ out: r(B D L)
104
+ last_state (optional): r(B D dstate) or c(B D dstate)
105
+ """
106
+ dtype_in = u.dtype
107
+ u = u.float()
108
+ delta = delta.float()
109
+ if delta_bias is not None:
110
+ delta = delta + delta_bias[..., None].float()
111
+ if delta_softplus:
112
+ delta = F.softplus(delta)
113
+ batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
114
+ is_variable_B = B.dim() >= 3
115
+ is_variable_C = C.dim() >= 3
116
+ if A.is_complex():
117
+ if is_variable_B:
118
+ B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
119
+ if is_variable_C:
120
+ C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
121
+ else:
122
+ B = B.float()
123
+ C = C.float()
124
+ x = A.new_zeros((batch, dim, dstate))
125
+ ys = []
126
+ deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
127
+ if not is_variable_B:
128
+ deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
129
+ else:
130
+ if B.dim() == 3:
131
+ deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
132
+ else:
133
+ B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
134
+ deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
135
+ if is_variable_C and C.dim() == 4:
136
+ C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
137
+ last_state = None
138
+ for i in range(u.shape[2]):
139
+ x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
140
+ if not is_variable_C:
141
+ y = torch.einsum('bdn,dn->bd', x, C)
142
+ else:
143
+ if C.dim() == 3:
144
+ y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
145
+ else:
146
+ y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
147
+ if i == u.shape[2] - 1:
148
+ last_state = x
149
+ if y.is_complex():
150
+ y = y.real * 2
151
+ ys.append(y)
152
+ y = torch.stack(ys, dim=2) # (batch dim L)
153
+ out = y if D is None else y + u * rearrange(D, "d -> d 1")
154
+ if z is not None:
155
+ out = out * F.silu(z)
156
+ out = out.to(dtype=dtype_in)
157
+ return out if not return_last_state else (out, last_state)
158
+
159
+
160
+ class MambaInnerFn(torch.autograd.Function):
161
+
162
+ @staticmethod
163
+ @custom_fwd
164
+ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
165
+ out_proj_weight, out_proj_bias,
166
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
167
+ C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
168
+ """
169
+ xz: (batch, dim, seqlen)
170
+ """
171
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
172
+ assert checkpoint_lvl in [0, 1]
173
+ L = xz.shape[-1]
174
+ delta_rank = delta_proj_weight.shape[1]
175
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
176
+ if torch.is_autocast_enabled():
177
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
178
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
179
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
180
+ out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
181
+ if out_proj_bias is not None else None)
182
+ if xz.stride(-1) != 1:
183
+ xz = xz.contiguous()
184
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
185
+ x, z = xz.chunk(2, dim=1)
186
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
187
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
188
+ x, conv1d_weight, conv1d_bias, None, None, None, True
189
+ )
190
+ # We're being very careful here about the layout, to avoid extra transposes.
191
+ # We want delta to have d as the slowest moving dimension
192
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
193
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
194
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
195
+ ctx.is_variable_B = B is None
196
+ ctx.is_variable_C = C is None
197
+ ctx.B_proj_bias_is_None = B_proj_bias is None
198
+ ctx.C_proj_bias_is_None = C_proj_bias is None
199
+ if B is None: # variable B
200
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
201
+ if B_proj_bias is not None:
202
+ B = B + B_proj_bias.to(dtype=B.dtype)
203
+ if not A.is_complex():
204
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
205
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
206
+ else:
207
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
208
+ else:
209
+ if B.stride(-1) != 1:
210
+ B = B.contiguous()
211
+ if C is None: # variable C
212
+ C = x_dbl[:, -d_state:] # (bl dstate)
213
+ if C_proj_bias is not None:
214
+ C = C + C_proj_bias.to(dtype=C.dtype)
215
+ if not A.is_complex():
216
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
217
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
218
+ else:
219
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
220
+ else:
221
+ if C.stride(-1) != 1:
222
+ C = C.contiguous()
223
+ if D is not None:
224
+ D = D.contiguous()
225
+ out, scan_intermediates, out_z = selective_scan_cuda.fwd(
226
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
227
+ )
228
+ ctx.delta_softplus = delta_softplus
229
+ ctx.out_proj_bias_is_None = out_proj_bias is None
230
+ ctx.checkpoint_lvl = checkpoint_lvl
231
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
232
+ conv1d_out, delta = None, None
233
+ ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
234
+ delta_proj_weight, out_proj_weight, conv1d_out, delta,
235
+ A, B, C, D, delta_bias, scan_intermediates, out)
236
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
237
+
238
+ @staticmethod
239
+ @custom_bwd
240
+ def backward(ctx, dout):
241
+ # dout: (batch, seqlen, dim)
242
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
243
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
244
+ conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
245
+ L = xz.shape[-1]
246
+ delta_rank = delta_proj_weight.shape[1]
247
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
248
+ x, z = xz.chunk(2, dim=1)
249
+ if dout.stride(-1) != 1:
250
+ dout = dout.contiguous()
251
+ if ctx.checkpoint_lvl == 1:
252
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
253
+ x, conv1d_weight, conv1d_bias, None, None, None, True
254
+ )
255
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
256
+ "d (b l) -> b d l", l = L)
257
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
258
+ # backward of selective_scan_cuda with the backward of chunk).
259
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
260
+ dx, dz = dxz.chunk(2, dim=1)
261
+ dout = rearrange(dout, "b l e -> e (b l)")
262
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
263
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
264
+ conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
265
+ ctx.delta_softplus,
266
+ True # option to recompute out_z
267
+ )
268
+ dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
269
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
270
+ dD = dD if D is not None else None
271
+ dx_dbl = torch.empty_like(x_dbl)
272
+ dB_proj_bias = None
273
+ if ctx.is_variable_B:
274
+ if not A.is_complex():
275
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
276
+ else:
277
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
278
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
279
+ dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
280
+ dB = None
281
+ dC_proj_bias = None
282
+ if ctx.is_variable_C:
283
+ if not A.is_complex():
284
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
285
+ else:
286
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
287
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
288
+ dx_dbl[:, -d_state:] = dC # (bl d)
289
+ dC = None
290
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
291
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
292
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
293
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
294
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
295
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
296
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
297
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
298
+ # backward of conv1d with the backward of chunk).
299
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
300
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
301
+ )
302
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
303
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
304
+ return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
305
+ dout_proj_weight, dout_proj_bias,
306
+ dA, dB, dC, dD,
307
+ ddelta_bias if delta_bias is not None else None,
308
+ dB_proj_bias, dC_proj_bias, None)
309
+
310
+
311
+ def mamba_inner_fn(
312
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
313
+ out_proj_weight, out_proj_bias,
314
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
315
+ C_proj_bias=None, delta_softplus=True
316
+ ):
317
+ return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
318
+ out_proj_weight, out_proj_bias,
319
+ A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
320
+
321
+
322
+ def mamba_inner_ref(
323
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
324
+ out_proj_weight, out_proj_bias,
325
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
326
+ C_proj_bias=None, delta_softplus=True
327
+ ):
328
+ assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
329
+ L = xz.shape[-1]
330
+ delta_rank = delta_proj_weight.shape[1]
331
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
332
+ x, z = xz.chunk(2, dim=1)
333
+ x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
334
+ # We're being very careful here about the layout, to avoid extra transposes.
335
+ # We want delta to have d as the slowest moving dimension
336
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
337
+ x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
338
+ delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
339
+ delta = rearrange(delta, "d (b l) -> b d l", l=L)
340
+ if B is None: # variable B
341
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
342
+ if B_proj_bias is not None:
343
+ B = B + B_proj_bias.to(dtype=B.dtype)
344
+ if not A.is_complex():
345
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
346
+ else:
347
+ B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
348
+ if C is None: # variable C
349
+ C = x_dbl[:, -d_state:] # (bl d)
350
+ if C_proj_bias is not None:
351
+ C = C + C_proj_bias.to(dtype=C.dtype)
352
+ if not A.is_complex():
353
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
354
+ else:
355
+ C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
356
+ y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
357
+ return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
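`selective_scan_ref` is the pure-PyTorch reference that the fused CUDA path is checked against. Note that this module imports `selective_scan_cuda` unconditionally at the top, so the compiled extension must be installed even to import the reference; the reference itself, however, runs on plain CPU tensors. A minimal sketch with shapes following the docstring (values arbitrary):

# Reference selective scan on random data, using the "variable B/C" (B N L) layout.
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_ref

batch, dim, dstate, seqlen = 2, 64, 16, 32
u = torch.randn(batch, dim, seqlen)
delta = torch.rand(batch, dim, seqlen)
A = -torch.rand(dim, dstate)               # negative real A keeps the recurrence stable
B = torch.randn(batch, dstate, seqlen)     # r(B N L)
C = torch.randn(batch, dstate, seqlen)     # r(B N L)
D = torch.randn(dim)
out, last_state = selective_scan_ref(
    u, delta, A, B, C, D, delta_softplus=True, return_last_state=True
)
print(out.shape, last_state.shape)         # (2, 64, 32) and (2, 64, 16)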
mamba/build/lib/mamba_ssm/ops/triton/__init__.py ADDED
File without changes
mamba/build/lib/mamba_ssm/ops/triton/k_activations.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
+ @triton.autotune(
10
+ configs=[
11
+ triton.Config({'BLOCK_N': 32}),
12
+ triton.Config({'BLOCK_N': 64}),
13
+ triton.Config({'BLOCK_N': 128}),
14
+ triton.Config({'BLOCK_N': 256}),
15
+ triton.Config({'BLOCK_N': 512}),
16
+ triton.Config({'BLOCK_N': 1024}),
17
+ ],
18
+ key=['ncols'],
19
+ )
20
+ @triton.jit
21
+ def _swiglu_fwd_kernel(
22
+ X,
23
+ Y,
24
+ OUT,
25
+ stride_x_row, # how much to increase the pointer when moving by 1 row
26
+ stride_y_row,
27
+ stride_out_row,
28
+ ncols,
29
+ BLOCK_N: tl.constexpr,
30
+ ):
31
+ # Map the program id to the row of X and Y it should compute.
32
+ row = tl.program_id(0)
33
+ start_col = tl.program_id(1) * BLOCK_N
34
+ X += row * stride_x_row
35
+ Y += row * stride_y_row
36
+ OUT += row * stride_out_row
37
+ cols = start_col + tl.arange(0, BLOCK_N)
38
+ x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
39
+ y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
40
+ out = x * tl.sigmoid(x) * y
41
+ tl.store(OUT + cols, out, mask=cols < ncols)
42
+
43
+
44
+ def _swiglu_fwd(xy, out=None):
45
+ if xy.stride(-1) != 1:
46
+ xy = xy.contiguous()
47
+ batch_shape = xy.shape[:-1]
48
+ xy = xy.reshape(-1, xy.shape[-1])
49
+ x, y = xy.chunk(2, dim=-1)
50
+ if out is None:
51
+ out = torch.empty_like(x)
52
+ else:
53
+ out = out.reshape(-1, out.shape[-1])
54
+ assert out.shape == x.shape
55
+ assert out.stride(-1) == 1
56
+ M, N = x.shape
57
+ grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
58
+ with torch.cuda.device(x.device.index):
59
+ _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
60
+ return out.reshape(*batch_shape, out.shape[-1])
61
+
62
+
63
+ @triton.autotune(
64
+ configs=[
65
+ triton.Config({'BLOCK_N': 32}),
66
+ triton.Config({'BLOCK_N': 64}),
67
+ triton.Config({'BLOCK_N': 128}),
68
+ triton.Config({'BLOCK_N': 256}),
69
+ triton.Config({'BLOCK_N': 512}),
70
+ triton.Config({'BLOCK_N': 1024}),
71
+ ],
72
+ key=['ncols'],
73
+ )
74
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
75
+ @triton.jit
76
+ def _swiglu_bwd_kernel(
77
+ X,
78
+ Y,
79
+ DOUT,
80
+ OUT,
81
+ DX,
82
+ DY,
83
+ stride_x_row, # how much to increase the pointer when moving by 1 row
84
+ stride_y_row,
85
+ stride_dout_row,
86
+ stride_out_row,
87
+ stride_dx_row,
88
+ stride_dy_row,
89
+ ncols,
90
+ BLOCK_N: tl.constexpr,
91
+ RECOMPUTE_OUTPUT: tl.constexpr,
92
+ ):
93
+ # Map the program id to the row of X and Y it should compute.
94
+ row = tl.program_id(0)
95
+ start_col = tl.program_id(1) * BLOCK_N
96
+ X += row * stride_x_row
97
+ Y += row * stride_y_row
98
+ DOUT += row * stride_dout_row
99
+ if RECOMPUTE_OUTPUT:
100
+ OUT += row * stride_out_row
101
+ DX += row * stride_dx_row
102
+ DY += row * stride_dy_row
103
+ cols = start_col + tl.arange(0, BLOCK_N)
104
+ x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
105
+ y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
106
+ dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
107
+ x_sigmoid = tl.sigmoid(x)
108
+ dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
109
+ dy = x * x_sigmoid * dout
110
+ tl.store(DX + cols, dx, mask=cols < ncols)
111
+ tl.store(DY + cols, dy, mask=cols < ncols)
112
+ if RECOMPUTE_OUTPUT:
113
+ out = x * x_sigmoid * y
114
+ tl.store(OUT + cols, out, mask=cols < ncols)
115
+
116
+
117
+ def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
118
+ if xy.stride(-1) != 1:
119
+ xy = xy.contiguous()
120
+ if dout.stride(-1) != 1:
121
+ dout = dout.contiguous()
122
+ batch_shape = xy.shape[:-1]
123
+ xy = xy.reshape(-1, xy.shape[-1])
124
+ x, y = xy.chunk(2, dim=-1)
125
+ dout = dout.reshape(-1, dout.shape[-1])
126
+ assert dout.shape == x.shape
127
+ if dxy is None:
128
+ dxy = torch.empty_like(xy)
129
+ else:
130
+ dxy = dxy.reshape(-1, dxy.shape[-1])
131
+ assert dxy.shape == xy.shape
132
+ dx, dy = dxy.chunk(2, dim=-1)
133
+ assert dx.stride(-1) == 1
134
+ assert dy.stride(-1) == 1
135
+ if recompute_output:
136
+ if out is None:
137
+ out = torch.empty_like(x)
138
+ else:
139
+ out = out.reshape(-1, out.shape[-1])
140
+ assert out.shape == x.shape
141
+ assert out.stride(-1) == 1
142
+ M, N = x.shape
143
+ grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
144
+ with torch.cuda.device(x.device.index):
145
+ _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
146
+ x.stride(0), y.stride(0), dout.stride(0),
147
+ out.stride(0) if recompute_output else 0,
148
+ dx.stride(0), dy.stride(0),
149
+ N)
150
+ if not recompute_output:
151
+ return dxy.reshape(*batch_shape, dxy.shape[-1])
152
+ else:
153
+ return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
154
+
155
+
156
+ class SwiGLU(torch.autograd.Function):
157
+
158
+ @staticmethod
159
+ def forward(ctx, xy):
160
+ ctx.save_for_backward(xy)
161
+ return _swiglu_fwd(xy)
162
+
163
+ @staticmethod
164
+ def backward(ctx, dout):
165
+ xy, = ctx.saved_tensors
166
+ return _swiglu_bwd(xy, dout)
167
+
168
+
169
+ swiglu = SwiGLU.apply
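`swiglu` above takes a single tensor whose last dimension packs the two halves of a gated activation and computes silu(first half) * second half with fused forward and backward Triton kernels. A hedged usage sketch (the kernels are launched under `torch.cuda.device`, so a GPU is required):

# Illustrative call of the fused SwiGLU op defined above (CUDA only).
import torch
from mamba_ssm.ops.triton.k_activations import swiglu

xy = torch.randn(4, 1024, 2 * 2048, device="cuda", requires_grad=True)
out = swiglu(xy)            # silu(xy[..., :2048]) * xy[..., 2048:]
print(out.shape)            # torch.Size([4, 1024, 2048])
out.sum().backward()        # _swiglu_bwd fills xy.grad in one fused kernel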
mamba/build/lib/mamba_ssm/ops/triton/layer_norm.py ADDED
@@ -0,0 +1,1113 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Implement dropout + residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+ import warnings
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch.cuda.amp import custom_fwd, custom_bwd
15
+
16
+ import triton
17
+ import triton.language as tl
18
+
19
+
20
+ def layer_norm_ref(
21
+ x,
22
+ weight,
23
+ bias,
24
+ residual=None,
25
+ x1=None,
26
+ weight1=None,
27
+ bias1=None,
28
+ eps=1e-6,
29
+ dropout_p=0.0,
30
+ rowscale=None,
31
+ prenorm=False,
32
+ dropout_mask=None,
33
+ dropout_mask1=None,
34
+ upcast=False,
35
+ ):
36
+ dtype = x.dtype
37
+ if upcast:
38
+ x = x.float()
39
+ weight = weight.float()
40
+ bias = bias.float() if bias is not None else None
41
+ residual = residual.float() if residual is not None else residual
42
+ x1 = x1.float() if x1 is not None else None
43
+ weight1 = weight1.float() if weight1 is not None else None
44
+ bias1 = bias1.float() if bias1 is not None else None
45
+ if x1 is not None:
46
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
47
+ if rowscale is not None:
48
+ x = x * rowscale[..., None]
49
+ if dropout_p > 0.0:
50
+ if dropout_mask is not None:
51
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
52
+ else:
53
+ x = F.dropout(x, p=dropout_p)
54
+ if x1 is not None:
55
+ if dropout_mask1 is not None:
56
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
57
+ else:
58
+ x1 = F.dropout(x1, p=dropout_p)
59
+ if x1 is not None:
60
+ x = x + x1
61
+ if residual is not None:
62
+ x = (x + residual).to(x.dtype)
63
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
64
+ dtype
65
+ )
66
+ if weight1 is None:
67
+ return out if not prenorm else (out, x)
68
+ else:
69
+ out1 = F.layer_norm(
70
+ x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
71
+ ).to(dtype)
72
+ return (out, out1) if not prenorm else (out, out1, x)
73
+
74
+
75
+ def rms_norm_ref(
76
+ x,
77
+ weight,
78
+ bias,
79
+ residual=None,
80
+ x1=None,
81
+ weight1=None,
82
+ bias1=None,
83
+ eps=1e-6,
84
+ dropout_p=0.0,
85
+ rowscale=None,
86
+ prenorm=False,
87
+ dropout_mask=None,
88
+ dropout_mask1=None,
89
+ upcast=False,
90
+ ):
91
+ dtype = x.dtype
92
+ if upcast:
93
+ x = x.float()
94
+ weight = weight.float()
95
+ bias = bias.float() if bias is not None else None
96
+ residual = residual.float() if residual is not None else residual
97
+ x1 = x1.float() if x1 is not None else None
98
+ weight1 = weight1.float() if weight1 is not None else None
99
+ bias1 = bias1.float() if bias1 is not None else None
100
+ if x1 is not None:
101
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
102
+ if rowscale is not None:
103
+ x = x * rowscale[..., None]
104
+ if dropout_p > 0.0:
105
+ if dropout_mask is not None:
106
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
107
+ else:
108
+ x = F.dropout(x, p=dropout_p)
109
+ if x1 is not None:
110
+ if dropout_mask1 is not None:
111
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
112
+ else:
113
+ x1 = F.dropout(x1, p=dropout_p)
114
+ if x1 is not None:
115
+ x = x + x1
116
+ if residual is not None:
117
+ x = (x + residual).to(x.dtype)
118
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
119
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
120
+ if weight1 is None:
121
+ return out if not prenorm else (out, x)
122
+ else:
123
+ out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
124
+ dtype
125
+ )
126
+ return (out, out1) if not prenorm else (out, out1, x)
127
+
128
+ def config_prune(configs):
129
+
130
+ if torch.version.hip:
131
+ try:
132
+ # set warp size based on gcn architecture
133
+ gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
134
+ if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
135
+ # radeon
136
+ warp_size = 32
137
+ else:
138
+ # instinct
139
+ warp_size = 64
140
+ except AttributeError as e:
141
+ # fall back to crude method to set warp size
142
+ device_name = torch.cuda.get_device_properties(0).name
143
+ if 'instinct' in device_name.lower():
144
+ warp_size = 64
145
+ else:
146
+ warp_size = 32
147
+ warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning)
148
+
149
+ else:
150
+ # cuda
151
+ warp_size = 32
152
+
153
+ max_block_sz = 1024
154
+ max_num_warps = max_block_sz // warp_size
155
+ pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
156
+ return pruned_configs
157
+
158
+ configs_autotune = [
159
+ triton.Config({}, num_warps=1),
160
+ triton.Config({}, num_warps=2),
161
+ triton.Config({}, num_warps=4),
162
+ triton.Config({}, num_warps=8),
163
+ triton.Config({}, num_warps=16),
164
+ triton.Config({}, num_warps=32),
165
+ ]
166
+
167
+ pruned_configs_autotune = config_prune(configs_autotune)
168
+
169
+ @triton.autotune(
170
+ configs = pruned_configs_autotune,
171
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
172
+ )
173
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
174
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
175
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
176
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
177
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
178
+ @triton.jit
179
+ def _layer_norm_fwd_1pass_kernel(
180
+ X, # pointer to the input
181
+ Y, # pointer to the output
182
+ W, # pointer to the weights
183
+ B, # pointer to the biases
184
+ RESIDUAL, # pointer to the residual
185
+ X1,
186
+ W1,
187
+ B1,
188
+ Y1,
189
+ RESIDUAL_OUT, # pointer to the residual
190
+ ROWSCALE,
191
+ SEEDS, # Dropout seeds for each row
192
+ DROPOUT_MASK,
193
+ Mean, # pointer to the mean
194
+ Rstd, # pointer to the 1/std
195
+ stride_x_row, # how much to increase the pointer when moving by 1 row
196
+ stride_y_row,
197
+ stride_res_row,
198
+ stride_res_out_row,
199
+ stride_x1_row,
200
+ stride_y1_row,
201
+ M, # number of rows in X
202
+ N, # number of columns in X
203
+ eps, # epsilon to avoid division by zero
204
+ dropout_p, # Dropout probability
205
+ IS_RMS_NORM: tl.constexpr,
206
+ BLOCK_N: tl.constexpr,
207
+ HAS_RESIDUAL: tl.constexpr,
208
+ STORE_RESIDUAL_OUT: tl.constexpr,
209
+ HAS_BIAS: tl.constexpr,
210
+ HAS_DROPOUT: tl.constexpr,
211
+ STORE_DROPOUT_MASK: tl.constexpr,
212
+ HAS_ROWSCALE: tl.constexpr,
213
+ HAS_X1: tl.constexpr,
214
+ HAS_W1: tl.constexpr,
215
+ HAS_B1: tl.constexpr,
216
+ ):
217
+ # Map the program id to the row of X and Y it should compute.
218
+ row = tl.program_id(0)
219
+ X += row * stride_x_row
220
+ Y += row * stride_y_row
221
+ if HAS_RESIDUAL:
222
+ RESIDUAL += row * stride_res_row
223
+ if STORE_RESIDUAL_OUT:
224
+ RESIDUAL_OUT += row * stride_res_out_row
225
+ if HAS_X1:
226
+ X1 += row * stride_x1_row
227
+ if HAS_W1:
228
+ Y1 += row * stride_y1_row
229
+ # Compute mean and variance
230
+ cols = tl.arange(0, BLOCK_N)
231
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
232
+ if HAS_ROWSCALE:
233
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
234
+ x *= rowscale
235
+ if HAS_DROPOUT:
236
+ # Compute dropout mask
237
+ # 7 rounds is good enough, and reduces register pressure
238
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
239
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
240
+ if STORE_DROPOUT_MASK:
241
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
242
+ if HAS_X1:
243
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
244
+ if HAS_ROWSCALE:
245
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
246
+ x1 *= rowscale
247
+ if HAS_DROPOUT:
248
+ # Compute dropout mask
249
+ # 7 rounds is good enough, and reduces register pressure
250
+ keep_mask = (
251
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
252
+ )
253
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
254
+ if STORE_DROPOUT_MASK:
255
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
256
+ x += x1
257
+ if HAS_RESIDUAL:
258
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
259
+ x += residual
260
+ if STORE_RESIDUAL_OUT:
261
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
262
+ if not IS_RMS_NORM:
263
+ mean = tl.sum(x, axis=0) / N
264
+ tl.store(Mean + row, mean)
265
+ xbar = tl.where(cols < N, x - mean, 0.0)
266
+ var = tl.sum(xbar * xbar, axis=0) / N
267
+ else:
268
+ xbar = tl.where(cols < N, x, 0.0)
269
+ var = tl.sum(xbar * xbar, axis=0) / N
270
+ rstd = 1 / tl.sqrt(var + eps)
271
+ tl.store(Rstd + row, rstd)
272
+ # Normalize and apply linear transformation
273
+ mask = cols < N
274
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
275
+ if HAS_BIAS:
276
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
277
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
278
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
279
+ # Write output
280
+ tl.store(Y + cols, y, mask=mask)
281
+ if HAS_W1:
282
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
283
+ if HAS_B1:
284
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
285
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
286
+ tl.store(Y1 + cols, y1, mask=mask)
287
+
288
+
289
+ def _layer_norm_fwd(
290
+ x,
291
+ weight,
292
+ bias,
293
+ eps,
294
+ residual=None,
295
+ x1=None,
296
+ weight1=None,
297
+ bias1=None,
298
+ dropout_p=0.0,
299
+ rowscale=None,
300
+ out_dtype=None,
301
+ residual_dtype=None,
302
+ is_rms_norm=False,
303
+ return_dropout_mask=False,
304
+ ):
305
+ if residual is not None:
306
+ residual_dtype = residual.dtype
307
+ M, N = x.shape
308
+ assert x.stride(-1) == 1
309
+ if residual is not None:
310
+ assert residual.stride(-1) == 1
311
+ assert residual.shape == (M, N)
312
+ assert weight.shape == (N,)
313
+ assert weight.stride(-1) == 1
314
+ if bias is not None:
315
+ assert bias.stride(-1) == 1
316
+ assert bias.shape == (N,)
317
+ if x1 is not None:
318
+ assert x1.shape == x.shape
319
+ assert rowscale is None
320
+ assert x1.stride(-1) == 1
321
+ if weight1 is not None:
322
+ assert weight1.shape == (N,)
323
+ assert weight1.stride(-1) == 1
324
+ if bias1 is not None:
325
+ assert bias1.shape == (N,)
326
+ assert bias1.stride(-1) == 1
327
+ if rowscale is not None:
328
+ assert rowscale.is_contiguous()
329
+ assert rowscale.shape == (M,)
330
+ # allocate output
331
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
332
+ assert y.stride(-1) == 1
333
+ if weight1 is not None:
334
+ y1 = torch.empty_like(y)
335
+ assert y1.stride(-1) == 1
336
+ else:
337
+ y1 = None
338
+ if (
339
+ residual is not None
340
+ or (residual_dtype is not None and residual_dtype != x.dtype)
341
+ or dropout_p > 0.0
342
+ or rowscale is not None
343
+ or x1 is not None
344
+ ):
345
+ residual_out = torch.empty(
346
+ M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
347
+ )
348
+ assert residual_out.stride(-1) == 1
349
+ else:
350
+ residual_out = None
351
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
352
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
353
+ if dropout_p > 0.0:
354
+ seeds = torch.randint(
355
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
356
+ )
357
+ else:
358
+ seeds = None
359
+ if return_dropout_mask and dropout_p > 0.0:
360
+ dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
361
+ else:
362
+ dropout_mask = None
363
+ # Less than 64KB per feature: enqueue fused kernel
364
+ MAX_FUSED_SIZE = 65536 // x.element_size()
365
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
366
+ if N > BLOCK_N:
367
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
368
+ with torch.cuda.device(x.device.index):
369
+ _layer_norm_fwd_1pass_kernel[(M,)](
370
+ x,
371
+ y,
372
+ weight,
373
+ bias,
374
+ residual,
375
+ x1,
376
+ weight1,
377
+ bias1,
378
+ y1,
379
+ residual_out,
380
+ rowscale,
381
+ seeds,
382
+ dropout_mask,
383
+ mean,
384
+ rstd,
385
+ x.stride(0),
386
+ y.stride(0),
387
+ residual.stride(0) if residual is not None else 0,
388
+ residual_out.stride(0) if residual_out is not None else 0,
389
+ x1.stride(0) if x1 is not None else 0,
390
+ y1.stride(0) if y1 is not None else 0,
391
+ M,
392
+ N,
393
+ eps,
394
+ dropout_p,
395
+ is_rms_norm,
396
+ BLOCK_N,
397
+ residual is not None,
398
+ residual_out is not None,
399
+ bias is not None,
400
+ dropout_p > 0.0,
401
+ dropout_mask is not None,
402
+ rowscale is not None,
403
+ )
404
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
405
+ if dropout_mask is not None and x1 is not None:
406
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
407
+ else:
408
+ dropout_mask1 = None
409
+ return (
410
+ y,
411
+ y1,
412
+ mean,
413
+ rstd,
414
+ residual_out if residual_out is not None else x,
415
+ seeds,
416
+ dropout_mask,
417
+ dropout_mask1,
418
+ )
419
+
420
+
421
+ @triton.autotune(
422
+ configs=pruned_configs_autotune,
423
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
424
+ )
425
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
426
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
427
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
428
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
429
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
430
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
431
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
432
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
433
+ @triton.jit
434
+ def _layer_norm_bwd_kernel(
435
+ X, # pointer to the input
436
+ W, # pointer to the weights
437
+ B, # pointer to the biases
438
+ Y, # pointer to the output to be recomputed
439
+ DY, # pointer to the output gradient
440
+ DX, # pointer to the input gradient
441
+ DW, # pointer to the partial sum of weights gradient
442
+ DB, # pointer to the partial sum of biases gradient
443
+ DRESIDUAL,
444
+ W1,
445
+ DY1,
446
+ DX1,
447
+ DW1,
448
+ DB1,
449
+ DRESIDUAL_IN,
450
+ ROWSCALE,
451
+ SEEDS,
452
+ Mean, # pointer to the mean
453
+ Rstd, # pointer to the 1/std
454
+ stride_x_row, # how much to increase the pointer when moving by 1 row
455
+ stride_y_row,
456
+ stride_dy_row,
457
+ stride_dx_row,
458
+ stride_dres_row,
459
+ stride_dy1_row,
460
+ stride_dx1_row,
461
+ stride_dres_in_row,
462
+ M, # number of rows in X
463
+ N, # number of columns in X
464
+ eps, # epsilon to avoid division by zero
465
+ dropout_p,
466
+ rows_per_program,
467
+ IS_RMS_NORM: tl.constexpr,
468
+ BLOCK_N: tl.constexpr,
469
+ HAS_DRESIDUAL: tl.constexpr,
470
+ STORE_DRESIDUAL: tl.constexpr,
471
+ HAS_BIAS: tl.constexpr,
472
+ HAS_DROPOUT: tl.constexpr,
473
+ HAS_ROWSCALE: tl.constexpr,
474
+ HAS_DY1: tl.constexpr,
475
+ HAS_DX1: tl.constexpr,
476
+ HAS_B1: tl.constexpr,
477
+ RECOMPUTE_OUTPUT: tl.constexpr,
478
+ ):
479
+ # Map the program id to the elements of X, DX, and DY it should compute.
480
+ row_block_id = tl.program_id(0)
481
+ row_start = row_block_id * rows_per_program
482
+ # Do not early exit if row_start >= M, because we need to write DW and DB
483
+ cols = tl.arange(0, BLOCK_N)
484
+ mask = cols < N
485
+ X += row_start * stride_x_row
486
+ if HAS_DRESIDUAL:
487
+ DRESIDUAL += row_start * stride_dres_row
488
+ if STORE_DRESIDUAL:
489
+ DRESIDUAL_IN += row_start * stride_dres_in_row
490
+ DY += row_start * stride_dy_row
491
+ DX += row_start * stride_dx_row
492
+ if HAS_DY1:
493
+ DY1 += row_start * stride_dy1_row
494
+ if HAS_DX1:
495
+ DX1 += row_start * stride_dx1_row
496
+ if RECOMPUTE_OUTPUT:
497
+ Y += row_start * stride_y_row
498
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
499
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
500
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
501
+ if HAS_DY1:
502
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
503
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
504
+ if HAS_BIAS:
505
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
506
+ if HAS_DY1:
507
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
508
+ if HAS_B1:
509
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
510
+ row_end = min((row_block_id + 1) * rows_per_program, M)
511
+ for row in range(row_start, row_end):
512
+ # Load data to SRAM
513
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
514
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
515
+ if HAS_DY1:
516
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
517
+ if not IS_RMS_NORM:
518
+ mean = tl.load(Mean + row)
519
+ rstd = tl.load(Rstd + row)
520
+ # Compute dx
521
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
522
+ xhat = tl.where(mask, xhat, 0.0)
523
+ if RECOMPUTE_OUTPUT:
524
+ y = xhat * w + b if HAS_BIAS else xhat * w
525
+ tl.store(Y + cols, y, mask=mask)
526
+ wdy = w * dy
527
+ dw += dy * xhat
528
+ if HAS_BIAS:
529
+ db += dy
530
+ if HAS_DY1:
531
+ wdy += w1 * dy1
532
+ dw1 += dy1 * xhat
533
+ if HAS_B1:
534
+ db1 += dy1
535
+ if not IS_RMS_NORM:
536
+ c1 = tl.sum(xhat * wdy, axis=0) / N
537
+ c2 = tl.sum(wdy, axis=0) / N
538
+ dx = (wdy - (xhat * c1 + c2)) * rstd
539
+ else:
540
+ c1 = tl.sum(xhat * wdy, axis=0) / N
541
+ dx = (wdy - xhat * c1) * rstd
542
+ if HAS_DRESIDUAL:
543
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
544
+ dx += dres
545
+ # Write dx
546
+ if STORE_DRESIDUAL:
547
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
548
+ if HAS_DX1:
549
+ if HAS_DROPOUT:
550
+ keep_mask = (
551
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
552
+ )
553
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
554
+ else:
555
+ dx1 = dx
556
+ tl.store(DX1 + cols, dx1, mask=mask)
557
+ if HAS_DROPOUT:
558
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
559
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
560
+ if HAS_ROWSCALE:
561
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
562
+ dx *= rowscale
563
+ tl.store(DX + cols, dx, mask=mask)
564
+
565
+ X += stride_x_row
566
+ if HAS_DRESIDUAL:
567
+ DRESIDUAL += stride_dres_row
568
+ if STORE_DRESIDUAL:
569
+ DRESIDUAL_IN += stride_dres_in_row
570
+ if RECOMPUTE_OUTPUT:
571
+ Y += stride_y_row
572
+ DY += stride_dy_row
573
+ DX += stride_dx_row
574
+ if HAS_DY1:
575
+ DY1 += stride_dy1_row
576
+ if HAS_DX1:
577
+ DX1 += stride_dx1_row
578
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
579
+ if HAS_BIAS:
580
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
581
+ if HAS_DY1:
582
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
583
+ if HAS_B1:
584
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
585
+
586
+
587
+ def _layer_norm_bwd(
588
+ dy,
589
+ x,
590
+ weight,
591
+ bias,
592
+ eps,
593
+ mean,
594
+ rstd,
595
+ dresidual=None,
596
+ dy1=None,
597
+ weight1=None,
598
+ bias1=None,
599
+ seeds=None,
600
+ dropout_p=0.0,
601
+ rowscale=None,
602
+ has_residual=False,
603
+ has_x1=False,
604
+ is_rms_norm=False,
605
+ x_dtype=None,
606
+ recompute_output=False,
607
+ ):
608
+ M, N = x.shape
609
+ assert x.stride(-1) == 1
610
+ assert dy.stride(-1) == 1
611
+ assert dy.shape == (M, N)
612
+ if dresidual is not None:
613
+ assert dresidual.stride(-1) == 1
614
+ assert dresidual.shape == (M, N)
615
+ assert weight.shape == (N,)
616
+ assert weight.stride(-1) == 1
617
+ if bias is not None:
618
+ assert bias.stride(-1) == 1
619
+ assert bias.shape == (N,)
620
+ if dy1 is not None:
621
+ assert weight1 is not None
622
+ assert dy1.shape == dy.shape
623
+ assert dy1.stride(-1) == 1
624
+ if weight1 is not None:
625
+ assert weight1.shape == (N,)
626
+ assert weight1.stride(-1) == 1
627
+ if bias1 is not None:
628
+ assert bias1.shape == (N,)
629
+ assert bias1.stride(-1) == 1
630
+ if seeds is not None:
631
+ assert seeds.is_contiguous()
632
+ assert seeds.shape == (M if not has_x1 else M * 2,)
633
+ if rowscale is not None:
634
+ assert rowscale.is_contiguous()
635
+ assert rowscale.shape == (M,)
636
+ # allocate output
637
+ dx = (
638
+ torch.empty_like(x)
639
+ if x_dtype is None
640
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
641
+ )
642
+ dresidual_in = (
643
+ torch.empty_like(x)
644
+ if has_residual
645
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
646
+ else None
647
+ )
648
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
649
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
650
+ if recompute_output:
651
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
652
+
653
+ # Less than 64KB per feature: enqueue fused kernel
654
+ MAX_FUSED_SIZE = 65536 // x.element_size()
655
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
656
+ if N > BLOCK_N:
657
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
658
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
659
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
660
+ _db = (
661
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
662
+ if bias is not None
663
+ else None
664
+ )
665
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
666
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
667
+ rows_per_program = math.ceil(M / sm_count)
668
+ grid = (sm_count,)
669
+ with torch.cuda.device(x.device.index):
670
+ _layer_norm_bwd_kernel[grid](
671
+ x,
672
+ weight,
673
+ bias,
674
+ y,
675
+ dy,
676
+ dx,
677
+ _dw,
678
+ _db,
679
+ dresidual,
680
+ weight1,
681
+ dy1,
682
+ dx1,
683
+ _dw1,
684
+ _db1,
685
+ dresidual_in,
686
+ rowscale,
687
+ seeds,
688
+ mean,
689
+ rstd,
690
+ x.stride(0),
691
+ 0 if not recompute_output else y.stride(0),
692
+ dy.stride(0),
693
+ dx.stride(0),
694
+ dresidual.stride(0) if dresidual is not None else 0,
695
+ dy1.stride(0) if dy1 is not None else 0,
696
+ dx1.stride(0) if dx1 is not None else 0,
697
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
698
+ M,
699
+ N,
700
+ eps,
701
+ dropout_p,
702
+ rows_per_program,
703
+ is_rms_norm,
704
+ BLOCK_N,
705
+ dresidual is not None,
706
+ dresidual_in is not None,
707
+ bias is not None,
708
+ dropout_p > 0.0,
709
+ )
710
+ dw = _dw.sum(0).to(weight.dtype)
711
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
712
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
713
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
714
+ # Don't need to compute dresidual_in separately in this case
715
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
716
+ dresidual_in = dx
717
+ if has_x1 and dropout_p == 0.0:
718
+ dx1 = dx
719
+ return (
720
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
721
+ if not recompute_output
722
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
723
+ )
724
+
725
+
726
+ class LayerNormFn(torch.autograd.Function):
727
+ @staticmethod
728
+ def forward(
729
+ ctx,
730
+ x,
731
+ weight,
732
+ bias,
733
+ residual=None,
734
+ x1=None,
735
+ weight1=None,
736
+ bias1=None,
737
+ eps=1e-6,
738
+ dropout_p=0.0,
739
+ rowscale=None,
740
+ prenorm=False,
741
+ residual_in_fp32=False,
742
+ is_rms_norm=False,
743
+ return_dropout_mask=False,
744
+ ):
745
+ x_shape_og = x.shape
746
+ # reshape input data into 2D tensor
747
+ x = x.reshape(-1, x.shape[-1])
748
+ if x.stride(-1) != 1:
749
+ x = x.contiguous()
750
+ if residual is not None:
751
+ assert residual.shape == x_shape_og
752
+ residual = residual.reshape(-1, residual.shape[-1])
753
+ if residual.stride(-1) != 1:
754
+ residual = residual.contiguous()
755
+ if x1 is not None:
756
+ assert x1.shape == x_shape_og
757
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
758
+ x1 = x1.reshape(-1, x1.shape[-1])
759
+ if x1.stride(-1) != 1:
760
+ x1 = x1.contiguous()
761
+ weight = weight.contiguous()
762
+ if bias is not None:
763
+ bias = bias.contiguous()
764
+ if weight1 is not None:
765
+ weight1 = weight1.contiguous()
766
+ if bias1 is not None:
767
+ bias1 = bias1.contiguous()
768
+ if rowscale is not None:
769
+ rowscale = rowscale.reshape(-1).contiguous()
770
+ residual_dtype = (
771
+ residual.dtype
772
+ if residual is not None
773
+ else (torch.float32 if residual_in_fp32 else None)
774
+ )
775
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
776
+ x,
777
+ weight,
778
+ bias,
779
+ eps,
780
+ residual,
781
+ x1,
782
+ weight1,
783
+ bias1,
784
+ dropout_p=dropout_p,
785
+ rowscale=rowscale,
786
+ residual_dtype=residual_dtype,
787
+ is_rms_norm=is_rms_norm,
788
+ return_dropout_mask=return_dropout_mask,
789
+ )
790
+ ctx.save_for_backward(
791
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
792
+ )
793
+ ctx.x_shape_og = x_shape_og
794
+ ctx.eps = eps
795
+ ctx.dropout_p = dropout_p
796
+ ctx.is_rms_norm = is_rms_norm
797
+ ctx.has_residual = residual is not None
798
+ ctx.has_x1 = x1 is not None
799
+ ctx.prenorm = prenorm
800
+ ctx.x_dtype = x.dtype
801
+ y = y.reshape(x_shape_og)
802
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
803
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
804
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
805
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
806
+ if not return_dropout_mask:
807
+ if weight1 is None:
808
+ return y if not prenorm else (y, residual_out)
809
+ else:
810
+ return (y, y1) if not prenorm else (y, y1, residual_out)
811
+ else:
812
+ if weight1 is None:
813
+ return (
814
+ (y, dropout_mask, dropout_mask1)
815
+ if not prenorm
816
+ else (y, residual_out, dropout_mask, dropout_mask1)
817
+ )
818
+ else:
819
+ return (
820
+ (y, y1, dropout_mask, dropout_mask1)
821
+ if not prenorm
822
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
823
+ )
824
+
825
+ @staticmethod
826
+ def backward(ctx, dy, *args):
827
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
828
+ dy = dy.reshape(-1, dy.shape[-1])
829
+ if dy.stride(-1) != 1:
830
+ dy = dy.contiguous()
831
+ assert dy.shape == x.shape
832
+ if weight1 is not None:
833
+ dy1, args = args[0], args[1:]
834
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
835
+ if dy1.stride(-1) != 1:
836
+ dy1 = dy1.contiguous()
837
+ assert dy1.shape == x.shape
838
+ else:
839
+ dy1 = None
840
+ if ctx.prenorm:
841
+ dresidual = args[0]
842
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
843
+ if dresidual.stride(-1) != 1:
844
+ dresidual = dresidual.contiguous()
845
+ assert dresidual.shape == x.shape
846
+ else:
847
+ dresidual = None
848
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
849
+ dy,
850
+ x,
851
+ weight,
852
+ bias,
853
+ ctx.eps,
854
+ mean,
855
+ rstd,
856
+ dresidual,
857
+ dy1,
858
+ weight1,
859
+ bias1,
860
+ seeds,
861
+ ctx.dropout_p,
862
+ rowscale,
863
+ ctx.has_residual,
864
+ ctx.has_x1,
865
+ ctx.is_rms_norm,
866
+ x_dtype=ctx.x_dtype,
867
+ )
868
+ return (
869
+ dx.reshape(ctx.x_shape_og),
870
+ dw,
871
+ db,
872
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
873
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
874
+ dw1,
875
+ db1,
876
+ None,
877
+ None,
878
+ None,
879
+ None,
880
+ None,
881
+ None,
882
+ None,
883
+ )
884
+
885
+
886
+ def layer_norm_fn(
887
+ x,
888
+ weight,
889
+ bias,
890
+ residual=None,
891
+ x1=None,
892
+ weight1=None,
893
+ bias1=None,
894
+ eps=1e-6,
895
+ dropout_p=0.0,
896
+ rowscale=None,
897
+ prenorm=False,
898
+ residual_in_fp32=False,
899
+ is_rms_norm=False,
900
+ return_dropout_mask=False,
901
+ ):
902
+ return LayerNormFn.apply(
903
+ x,
904
+ weight,
905
+ bias,
906
+ residual,
907
+ x1,
908
+ weight1,
909
+ bias1,
910
+ eps,
911
+ dropout_p,
912
+ rowscale,
913
+ prenorm,
914
+ residual_in_fp32,
915
+ is_rms_norm,
916
+ return_dropout_mask,
917
+ )
918
+
919
+
920
+ def rms_norm_fn(
921
+ x,
922
+ weight,
923
+ bias,
924
+ residual=None,
925
+ x1=None,
926
+ weight1=None,
927
+ bias1=None,
928
+ eps=1e-6,
929
+ dropout_p=0.0,
930
+ rowscale=None,
931
+ prenorm=False,
932
+ residual_in_fp32=False,
933
+ return_dropout_mask=False,
934
+ ):
935
+ return LayerNormFn.apply(
936
+ x,
937
+ weight,
938
+ bias,
939
+ residual,
940
+ x1,
941
+ weight1,
942
+ bias1,
943
+ eps,
944
+ dropout_p,
945
+ rowscale,
946
+ prenorm,
947
+ residual_in_fp32,
948
+ True,
949
+ return_dropout_mask,
950
+ )
951
+
952
+
953
+ class RMSNorm(torch.nn.Module):
954
+
955
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
956
+ factory_kwargs = {"device": device, "dtype": dtype}
957
+ super().__init__()
958
+ self.eps = eps
959
+ if dropout_p > 0.0:
960
+ self.drop = torch.nn.Dropout(dropout_p)
961
+ else:
962
+ self.drop = None
963
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
964
+ self.register_parameter("bias", None)
965
+ self.reset_parameters()
966
+
967
+ def reset_parameters(self):
968
+ torch.nn.init.ones_(self.weight)
969
+
970
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
971
+ return rms_norm_fn(
972
+ x,
973
+ self.weight,
974
+ self.bias,
975
+ residual=residual,
976
+ eps=self.eps,
977
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
978
+ prenorm=prenorm,
979
+ residual_in_fp32=residual_in_fp32,
980
+ )
981
+
982
+
983
+ class LayerNormLinearFn(torch.autograd.Function):
984
+ @staticmethod
985
+ @custom_fwd
986
+ def forward(
987
+ ctx,
988
+ x,
989
+ norm_weight,
990
+ norm_bias,
991
+ linear_weight,
992
+ linear_bias,
993
+ residual=None,
994
+ eps=1e-6,
995
+ prenorm=False,
996
+ residual_in_fp32=False,
997
+ is_rms_norm=False,
998
+ ):
999
+ x_shape_og = x.shape
1000
+ # reshape input data into 2D tensor
1001
+ x = x.reshape(-1, x.shape[-1])
1002
+ if x.stride(-1) != 1:
1003
+ x = x.contiguous()
1004
+ if residual is not None:
1005
+ assert residual.shape == x_shape_og
1006
+ residual = residual.reshape(-1, residual.shape[-1])
1007
+ if residual.stride(-1) != 1:
1008
+ residual = residual.contiguous()
1009
+ norm_weight = norm_weight.contiguous()
1010
+ if norm_bias is not None:
1011
+ norm_bias = norm_bias.contiguous()
1012
+ residual_dtype = (
1013
+ residual.dtype
1014
+ if residual is not None
1015
+ else (torch.float32 if residual_in_fp32 else None)
1016
+ )
1017
+ y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1018
+ x,
1019
+ norm_weight,
1020
+ norm_bias,
1021
+ eps,
1022
+ residual,
1023
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
1024
+ residual_dtype=residual_dtype,
1025
+ is_rms_norm=is_rms_norm,
1026
+ )
1027
+ y = y.reshape(x_shape_og)
1028
+ dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1029
+ linear_weight = linear_weight.to(dtype)
1030
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1031
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1032
+ # We don't store y, will be recomputed in the backward pass to save memory
1033
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
1034
+ ctx.x_shape_og = x_shape_og
1035
+ ctx.eps = eps
1036
+ ctx.is_rms_norm = is_rms_norm
1037
+ ctx.has_residual = residual is not None
1038
+ ctx.prenorm = prenorm
1039
+ ctx.x_dtype = x.dtype
1040
+ ctx.linear_bias_is_none = linear_bias is None
1041
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1042
+
1043
+ @staticmethod
1044
+ @custom_bwd
1045
+ def backward(ctx, dout, *args):
1046
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1047
+ dout = dout.reshape(-1, dout.shape[-1])
1048
+ dy = F.linear(dout, linear_weight.t())
1049
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1050
+ if dy.stride(-1) != 1:
1051
+ dy = dy.contiguous()
1052
+ assert dy.shape == x.shape
1053
+ if ctx.prenorm:
1054
+ dresidual = args[0]
1055
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1056
+ if dresidual.stride(-1) != 1:
1057
+ dresidual = dresidual.contiguous()
1058
+ assert dresidual.shape == x.shape
1059
+ else:
1060
+ dresidual = None
1061
+ dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1062
+ dy,
1063
+ x,
1064
+ norm_weight,
1065
+ norm_bias,
1066
+ ctx.eps,
1067
+ mean,
1068
+ rstd,
1069
+ dresidual=dresidual,
1070
+ has_residual=ctx.has_residual,
1071
+ is_rms_norm=ctx.is_rms_norm,
1072
+ x_dtype=ctx.x_dtype,
1073
+ recompute_output=True,
1074
+ )
1075
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1076
+ return (
1077
+ dx.reshape(ctx.x_shape_og),
1078
+ dnorm_weight,
1079
+ dnorm_bias,
1080
+ dlinear_weight,
1081
+ dlinear_bias,
1082
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1083
+ None,
1084
+ None,
1085
+ None,
1086
+ None,
1087
+ )
1088
+
1089
+
1090
+ def layer_norm_linear_fn(
1091
+ x,
1092
+ norm_weight,
1093
+ norm_bias,
1094
+ linear_weight,
1095
+ linear_bias,
1096
+ residual=None,
1097
+ eps=1e-6,
1098
+ prenorm=False,
1099
+ residual_in_fp32=False,
1100
+ is_rms_norm=False,
1101
+ ):
1102
+ return LayerNormLinearFn.apply(
1103
+ x,
1104
+ norm_weight,
1105
+ norm_bias,
1106
+ linear_weight,
1107
+ linear_bias,
1108
+ residual,
1109
+ eps,
1110
+ prenorm,
1111
+ residual_in_fp32,
1112
+ is_rms_norm,
1113
+ )
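As a quick orientation (this sketch is not part of the uploaded files; it assumes a CUDA device with a working Triton install, and the shapes/dtypes are arbitrary), the fused norm above can be used either through the RMSNorm module or the layer_norm_fn functional wrapper:

import torch
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn

x = torch.randn(2, 128, 512, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)

# Module API: with prenorm=True the updated residual stream is returned as well.
norm = RMSNorm(512, eps=1e-5, device="cuda", dtype=torch.float16)
y, new_residual = norm(x, residual=residual, prenorm=True, residual_in_fp32=True)

# Functional API: the same fused kernel; is_rms_norm=False gives a regular LayerNorm.
weight = torch.ones(512, device="cuda", dtype=torch.float16)
bias = torch.zeros(512, device="cuda", dtype=torch.float16)
y_ln = layer_norm_fn(x, weight, bias, eps=1e-6, is_rms_norm=False)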
mamba/build/lib/mamba_ssm/ops/triton/layernorm_gated.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
3
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
4
+ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
5
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ from einops import rearrange
16
+
17
+
18
+ def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
19
+ dtype = x.dtype
20
+ N = x.shape[-1]
21
+ weight = weight.float()
22
+ bias = bias.float() if bias is not None else None
23
+ if upcast:
24
+ x = x.float()
25
+ z = z.float() if z is not None else z
26
+ if z is not None and not norm_before_gate:
27
+ x = x * F.silu(z)
28
+ if group_size is None:
29
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
30
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
31
+ else:
32
+ x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
33
+ rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
34
+ out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
35
+ if bias is not None:
36
+ out = out + bias
37
+ if z is not None and norm_before_gate:
38
+ out *= F.silu(z)
39
+ return out.to(dtype)
40
+
41
+
42
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
43
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
44
+ @triton.jit
45
+ def _layer_norm_fwd_1pass_kernel(
46
+ X, # pointer to the input
47
+ Y, # pointer to the output
48
+ W, # pointer to the weights
49
+ B, # pointer to the biases
50
+ Z, # pointer to the other branch
51
+ Mean, # pointer to the mean
52
+ Rstd, # pointer to the 1/std
53
+ stride_x_row, # how much to increase the pointer when moving by 1 row
54
+ stride_y_row,
55
+ stride_z_row,
56
+ M, # number of rows in X
57
+ N, # number of columns in X
58
+ eps, # epsilon to avoid division by zero
59
+ BLOCK_N: tl.constexpr,
60
+ HAS_BIAS: tl.constexpr,
61
+ HAS_Z: tl.constexpr,
62
+ NORM_BEFORE_GATE: tl.constexpr,
63
+ IS_RMS_NORM: tl.constexpr,
64
+ ):
65
+ # Map the program id to the row of X and Y it should compute.
66
+ row = tl.program_id(0)
67
+ group = tl.program_id(1)
68
+ X += row * stride_x_row + group * N
69
+ Y += row * stride_y_row + group * N
70
+ if HAS_Z:
71
+ Z += row * stride_z_row + group * N
72
+ if not IS_RMS_NORM:
73
+ Mean += group * M
74
+ Rstd += group * M
75
+ W += group * N
76
+ if HAS_BIAS:
77
+ B += group * N
78
+ # Compute mean and variance
79
+ cols = tl.arange(0, BLOCK_N)
80
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
81
+ if HAS_Z and not NORM_BEFORE_GATE:
82
+ z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
83
+ x *= z * tl.sigmoid(z)
84
+ if not IS_RMS_NORM:
85
+ mean = tl.sum(x, axis=0) / N
86
+ tl.store(Mean + row, mean)
87
+ xbar = tl.where(cols < N, x - mean, 0.)
88
+ var = tl.sum(xbar * xbar, axis=0) / N
89
+ else:
90
+ xbar = tl.where(cols < N, x, 0.)
91
+ var = tl.sum(xbar * xbar, axis=0) / N
92
+ rstd = 1 / tl.sqrt(var + eps)
93
+ tl.store(Rstd + row, rstd)
94
+ # Normalize and apply linear transformation
95
+ mask = cols < N
96
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
97
+ if HAS_BIAS:
98
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
99
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
100
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
101
+ if HAS_Z and NORM_BEFORE_GATE:
102
+ z = tl.load(Z + cols, mask=mask).to(tl.float32)
103
+ y *= z * tl.sigmoid(z)
104
+ # Write output
105
+ tl.store(Y + cols, y, mask=mask)
106
+
107
+
108
+ def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
109
+ M, N = x.shape
110
+ if group_size is None:
111
+ group_size = N
112
+ assert N % group_size == 0
113
+ ngroups = N // group_size
114
+ assert x.stride(-1) == 1
115
+ if z is not None:
116
+ assert z.stride(-1) == 1
117
+ assert z.shape == (M, N)
118
+ assert weight.shape == (N,)
119
+ assert weight.stride(-1) == 1
120
+ if bias is not None:
121
+ assert bias.stride(-1) == 1
122
+ assert bias.shape == (N,)
123
+ # allocate output
124
+ if out is not None:
125
+ assert out.shape == x.shape
126
+ else:
127
+ out = torch.empty_like(x)
128
+ assert out.stride(-1) == 1
129
+ mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
130
+ rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
131
+ # Less than 64KB per feature: enqueue fused kernel
132
+ MAX_FUSED_SIZE = 65536 // x.element_size()
133
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
134
+ if group_size > BLOCK_N:
135
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
136
+ # heuristics for number of warps
137
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
138
+ grid = (M, ngroups)
139
+ with torch.cuda.device(x.device.index):
140
+ _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,
141
+ x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,
142
+ M, group_size, eps,
143
+ BLOCK_N=BLOCK_N,
144
+ NORM_BEFORE_GATE=norm_before_gate,
145
+ IS_RMS_NORM=is_rms_norm,
146
+ num_warps=num_warps)
147
+ return out, mean, rstd
148
+
149
+
150
+
151
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
152
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
153
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
154
+ @triton.jit
155
+ def _layer_norm_bwd_kernel(
156
+ X, # pointer to the input
157
+ W, # pointer to the weights
158
+ B, # pointer to the biases
159
+ Z, # pointer to the other branch
160
+ Y, # pointer to the output to be recomputed
161
+ DY, # pointer to the output gradient
162
+ DX, # pointer to the input gradient
163
+ DW, # pointer to the partial sum of weights gradient
164
+ DB, # pointer to the partial sum of biases gradient
165
+ DZ, # pointer to the other branch
166
+ Mean, # pointer to the mean
167
+ Rstd, # pointer to the 1/std
168
+ stride_x_row, # how much to increase the pointer when moving by 1 row
169
+ stride_z_row,
170
+ stride_y_row,
171
+ stride_dy_row,
172
+ stride_dx_row,
173
+ stride_dz_row,
174
+ stride_dw_row,
175
+ stride_db_row,
176
+ M, # number of rows in X
177
+ N, # number of columns in X
178
+ eps, # epsilon to avoid division by zero
179
+ rows_per_program,
180
+ NORM_BEFORE_GATE: tl.constexpr,
181
+ IS_RMS_NORM: tl.constexpr,
182
+ HAS_BIAS: tl.constexpr,
183
+ HAS_Z: tl.constexpr,
184
+ RECOMPUTE_OUTPUT: tl.constexpr,
185
+ BLOCK_N: tl.constexpr,
186
+ ):
187
+ # Map the program id to the elements of X, DX, and DY it should compute.
188
+ row_block_id = tl.program_id(0)
189
+ group = tl.program_id(1)
190
+ row_start = row_block_id * rows_per_program
191
+ cols = tl.arange(0, BLOCK_N)
192
+ mask = cols < N
193
+ X += row_start * stride_x_row + group * N
194
+ if HAS_Z:
195
+ Z += row_start * stride_z_row + group * N
196
+ DZ += row_start * stride_dz_row + group * N
197
+ DY += row_start * stride_dy_row + group * N
198
+ DX += row_start * stride_dx_row + group * N
199
+ if RECOMPUTE_OUTPUT:
200
+ Y += row_start * stride_y_row + group * N
201
+ if not IS_RMS_NORM:
202
+ Mean += group * M
203
+ Rstd += group * M
204
+ W += group * N
205
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
206
+ if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
207
+ B += group * N
208
+ b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
209
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
210
+ if HAS_BIAS:
211
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
212
+ row_end = min((row_block_id + 1) * rows_per_program, M)
213
+ for row in range(row_start, row_end):
214
+ # Load data to SRAM
215
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
216
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
217
+ if not IS_RMS_NORM:
218
+ mean = tl.load(Mean + row)
219
+ if HAS_Z and not NORM_BEFORE_GATE:
220
+ z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
221
+ x_og = x
222
+ x = x_og * z * tl.sigmoid(z)
223
+ rstd = tl.load(Rstd + row)
224
+ # Compute dx
225
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
226
+ xhat = tl.where(mask, xhat, 0.)
227
+ if HAS_Z and NORM_BEFORE_GATE:
228
+ z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
229
+ z_sigmoid = tl.sigmoid(z)
230
+ y = xhat * w + b if HAS_BIAS else xhat * w
231
+ if RECOMPUTE_OUTPUT:
232
+ tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
233
+ dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
234
+ tl.store(DZ + cols, dz, mask=mask)
235
+ dy *= z * z_sigmoid
236
+ else:
237
+ if RECOMPUTE_OUTPUT:
238
+ y = xhat * w + b if HAS_BIAS else xhat * w
239
+ tl.store(Y + cols, y, mask=mask)
240
+ wdy = w * dy
241
+ c1 = tl.sum(xhat * wdy, axis=0) / N
242
+ if not IS_RMS_NORM:
243
+ c2 = tl.sum(wdy, axis=0) / N
244
+ dx = (wdy - (xhat * c1 + c2)) * rstd
245
+ else:
246
+ dx = (wdy - xhat * c1) * rstd
247
+ dw += dy * xhat
248
+ if HAS_BIAS:
249
+ db += dy
250
+ if HAS_Z and not NORM_BEFORE_GATE:
251
+ z_sigmoid = tl.sigmoid(z)
252
+ dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
253
+ tl.store(DZ + cols, dz, mask=mask)
254
+ dx *= z * z_sigmoid
255
+ # Write dx
256
+ tl.store(DX + cols, dx, mask=mask)
257
+
258
+ X += stride_x_row
259
+ if HAS_Z:
260
+ Z += stride_z_row
261
+ DZ += stride_dz_row
262
+ if RECOMPUTE_OUTPUT:
263
+ Y += stride_y_row
264
+ DY += stride_dy_row
265
+ DX += stride_dx_row
266
+ tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
267
+ if HAS_BIAS:
268
+ tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
269
+
270
+
271
+ def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,
272
+ norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):
273
+ M, N = x.shape
274
+ if group_size is None:
275
+ group_size = N
276
+ assert N % group_size == 0
277
+ ngroups = N // group_size
278
+ assert x.stride(-1) == 1
279
+ assert dy.stride(-1) == 1
280
+ assert dy.shape == (M, N)
281
+ if z is not None:
282
+ assert z.stride(-1) == 1
283
+ assert z.shape == (M, N)
284
+ assert weight.shape == (N,)
285
+ assert weight.stride(-1) == 1
286
+ if bias is not None:
287
+ assert bias.stride(-1) == 1
288
+ assert bias.shape == (N,)
289
+ # allocate output
290
+ dx = torch.empty_like(x)
291
+ if dz is not None:
292
+ assert z is not None
293
+ assert dz.shape == z.shape
294
+ assert dz.stride(-1) == 1
295
+ else:
296
+ dz = torch.empty_like(z) if z is not None else None
297
+ if recompute_output:
298
+ if out is None:
299
+ out = torch.empty_like(x)
300
+ assert out.shape == x.shape
301
+
302
+ # Less than 64KB per feature: enqueue fused kernel
303
+ MAX_FUSED_SIZE = 65536 // x.element_size()
304
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
305
+ if group_size > BLOCK_N:
306
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
307
+ # heuristics for number of warps
308
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
309
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
310
+ # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
311
+ # would limit the occupancy.
312
+ nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
313
+ _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
314
+ _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
315
+ rows_per_program = math.ceil(M / nrow_groups)
316
+ grid = (nrow_groups, ngroups)
317
+ with torch.cuda.device(x.device.index):
318
+ _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,
319
+ dy, dx, _dw, _db, dz, mean, rstd,
320
+ x.stride(0),
321
+ z.stride(0) if z is not None else 0,
322
+ 0 if not recompute_output else out.stride(0),
323
+ dy.stride(0), dx.stride(0),
324
+ dz.stride(0) if dz is not None else 0,
325
+ _dw.stride(0),
326
+ _db.stride(0) if _db is not None else 0,
327
+ M, group_size, eps,
328
+ rows_per_program,
329
+ BLOCK_N=BLOCK_N,
330
+ NORM_BEFORE_GATE=norm_before_gate,
331
+ IS_RMS_NORM=is_rms_norm,
332
+ num_warps=num_warps)
333
+ dw = _dw.sum(0).to(weight.dtype)
334
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
335
+ return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
336
+
337
+
338
+ class LayerNormFn(torch.autograd.Function):
339
+
340
+ @staticmethod
341
+ def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
342
+ is_rms_norm=False):
343
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
344
+ """
345
+
346
+ x_shape_og = x.shape
347
+ # reshape input data into 2D tensor
348
+ x = x.reshape(-1, x.shape[-1])
349
+ if x.stride(-1) != 1:
350
+ x = x.contiguous()
351
+ if z is not None:
352
+ assert z.shape == x_shape_og
353
+ z = z.reshape(-1, z.shape[-1])
354
+ if z.stride(-1) != 1:
355
+ z = z.contiguous()
356
+ weight = weight.contiguous()
357
+ if bias is not None:
358
+ bias = bias.contiguous()
359
+ y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
360
+ ctx.save_for_backward(x, weight, bias, mean, rstd, z)
361
+ ctx.x_shape_og = x_shape_og
362
+ ctx.eps = eps
363
+ ctx.group_size = group_size
364
+ ctx.norm_before_gate = norm_before_gate
365
+ ctx.is_rms_norm = is_rms_norm
366
+ return y.reshape(x_shape_og)
367
+
368
+ @staticmethod
369
+ def backward(ctx, dy):
370
+ x, weight, bias, mean, rstd, z = ctx.saved_tensors
371
+ dy = dy.reshape(-1, dy.shape[-1])
372
+ if dy.stride(-1) != 1:
373
+ dy = dy.contiguous()
374
+ assert dy.shape == x.shape
375
+ dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
376
+ ctx.norm_before_gate, ctx.is_rms_norm)
377
+ return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
378
+
379
+
380
+ def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
381
+ return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
382
+
383
+
384
+ def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
385
+ return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
386
+
387
+
388
+ class LayerNorm(torch.nn.Module):
389
+
390
+ def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
391
+ """If group_size is not None, we do GroupNorm with each group having group_size elements.
392
+ group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
393
+ """
394
+
395
+ factory_kwargs = {"device": device, "dtype": dtype}
396
+ super().__init__()
397
+ self.eps = eps
398
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
399
+ self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
400
+ self.group_size = group_size
401
+ self.norm_before_gate = norm_before_gate
402
+ self.reset_parameters()
403
+
404
+ def reset_parameters(self):
405
+ torch.nn.init.ones_(self.weight)
406
+ torch.nn.init.zeros_(self.bias)
407
+
408
+ def forward(self, x, z=None):
409
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
410
+ """
411
+ return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
412
+ norm_before_gate=self.norm_before_gate)
413
+
414
+
415
+ class RMSNorm(torch.nn.Module):
416
+
417
+ def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
418
+ """If group_size is not None, we do GroupNorm with each group having group_size elements.
419
+ group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
420
+ """
421
+ factory_kwargs = {"device": device, "dtype": dtype}
422
+ super().__init__()
423
+ self.eps = eps
424
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
425
+ self.register_parameter("bias", None)
426
+ self.group_size = group_size
427
+ self.norm_before_gate = norm_before_gate
428
+ self.reset_parameters()
429
+
430
+ def reset_parameters(self):
431
+ torch.nn.init.ones_(self.weight)
432
+
433
+ def forward(self, x, z=None):
434
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
435
+ """
436
+ return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
437
+ norm_before_gate=self.norm_before_gate)
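A brief usage sketch of the gated norm defined above (illustrative only, not part of the upload; assumes CUDA and Triton). The pure-PyTorch rms_norm_ref at the top of the file doubles as a correctness check for the fused kernel:

import torch
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm, rms_norm_ref

x = torch.randn(4, 64, 256, device="cuda", dtype=torch.float32)
z = torch.randn_like(x)  # gate branch: output is norm(x) * silu(z) when norm_before_gate=True

norm = RMSNorm(256, eps=1e-5, group_size=64, norm_before_gate=True, device="cuda")
y = norm(x, z=z)

y_ref = rms_norm_ref(x, norm.weight, norm.bias, z=z, eps=1e-5,
                     group_size=64, norm_before_gate=True)
torch.testing.assert_close(y, y_ref, atol=1e-3, rtol=1e-3)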
mamba/build/lib/mamba_ssm/ops/triton/selective_state_update.py ADDED
@@ -0,0 +1,265 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from mamba_ssm.ops.triton.softplus import softplus
16
+
17
+
18
+ @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
19
+ @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
20
+ @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
21
+ @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
22
+ @triton.jit
23
+ def _selective_scan_update_kernel(
24
+ # Pointers to matrices
25
+ state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
26
+ # Matrix dimensions
27
+ batch, nheads, dim, dstate, nheads_ngroups_ratio,
28
+ # Strides
29
+ stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
30
+ stride_x_batch, stride_x_head, stride_x_dim,
31
+ stride_dt_batch, stride_dt_head, stride_dt_dim,
32
+ stride_dt_bias_head, stride_dt_bias_dim,
33
+ stride_A_head, stride_A_dim, stride_A_dstate,
34
+ stride_B_batch, stride_B_group, stride_B_dstate,
35
+ stride_C_batch, stride_C_group, stride_C_dstate,
36
+ stride_D_head, stride_D_dim,
37
+ stride_z_batch, stride_z_head, stride_z_dim,
38
+ stride_out_batch, stride_out_head, stride_out_dim,
39
+ # Meta-parameters
40
+ DT_SOFTPLUS: tl.constexpr,
41
+ TIE_HDIM: tl.constexpr,
42
+ BLOCK_SIZE_M: tl.constexpr,
43
+ HAS_DT_BIAS: tl.constexpr,
44
+ HAS_D: tl.constexpr,
45
+ HAS_Z: tl.constexpr,
46
+ BLOCK_SIZE_DSTATE: tl.constexpr,
47
+ ):
48
+ pid_m = tl.program_id(axis=0)
49
+ pid_b = tl.program_id(axis=1)
50
+ pid_h = tl.program_id(axis=2)
51
+ state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
52
+ x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
53
+ dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
54
+ if HAS_DT_BIAS:
55
+ dt_bias_ptr += pid_h * stride_dt_bias_head
56
+ A_ptr += pid_h * stride_A_head
57
+ B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
58
+ C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
59
+ if HAS_Z:
60
+ z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
61
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
62
+
63
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
64
+ offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
65
+ state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
66
+ x_ptrs = x_ptr + offs_m * stride_x_dim
67
+ dt_ptrs = dt_ptr + offs_m * stride_dt_dim
68
+ if HAS_DT_BIAS:
69
+ dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
70
+ if HAS_D:
71
+ D_ptr += pid_h * stride_D_head
72
+ A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
73
+ B_ptrs = B_ptr + offs_n * stride_B_dstate
74
+ C_ptrs = C_ptr + offs_n * stride_C_dstate
75
+ if HAS_D:
76
+ D_ptrs = D_ptr + offs_m * stride_D_dim
77
+ if HAS_Z:
78
+ z_ptrs = z_ptr + offs_m * stride_z_dim
79
+ out_ptrs = out_ptr + offs_m * stride_out_dim
80
+
81
+ state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
82
+ x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
83
+ if not TIE_HDIM:
84
+ dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
85
+ if HAS_DT_BIAS:
86
+ dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
87
+ if DT_SOFTPLUS:
88
+ dt = softplus(dt)
89
+ A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
90
+ dA = tl.exp(A * dt[:, None])
91
+ else:
92
+ dt = tl.load(dt_ptr).to(tl.float32)
93
+ if HAS_DT_BIAS:
94
+ dt += tl.load(dt_bias_ptr).to(tl.float32)
95
+ if DT_SOFTPLUS:
96
+ dt = softplus(dt)
97
+ A = tl.load(A_ptr).to(tl.float32)
98
+ dA = tl.exp(A * dt) # scalar, not a matrix
99
+
100
+ B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
101
+ C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
102
+ if HAS_D:
103
+ D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
104
+ if HAS_Z:
105
+ z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
106
+
107
+ if not TIE_HDIM:
108
+ dB = B[None, :] * dt[:, None]
109
+ else:
110
+ dB = B * dt # vector of size (dstate,)
111
+ state = state * dA + dB * x[:, None]
112
+ tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
113
+ out = tl.sum(state * C[None, :], axis=1)
114
+ if HAS_D:
115
+ out += x * D
116
+ if HAS_Z:
117
+ out *= z * tl.sigmoid(z)
118
+ tl.store(out_ptrs, out, mask=offs_m < dim)
119
+
120
+
121
+ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
122
+ """
123
+ Argument:
124
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
125
+ x: (batch, dim) or (batch, nheads, dim)
126
+ dt: (batch, dim) or (batch, nheads, dim)
127
+ A: (dim, dstate) or (nheads, dim, dstate)
128
+ B: (batch, dstate) or (batch, ngroups, dstate)
129
+ C: (batch, dstate) or (batch, ngroups, dstate)
130
+ D: (dim,) or (nheads, dim)
131
+ z: (batch, dim) or (batch, nheads, dim)
132
+ dt_bias: (dim,) or (nheads, dim)
133
+ Return:
134
+ out: (batch, dim) or (batch, nheads, dim)
135
+ """
136
+ has_heads = state.dim() > 3
137
+ if state.dim() == 3:
138
+ state = state.unsqueeze(1)
139
+ if x.dim() == 2:
140
+ x = x.unsqueeze(1)
141
+ if dt.dim() == 2:
142
+ dt = dt.unsqueeze(1)
143
+ if A.dim() == 2:
144
+ A = A.unsqueeze(0)
145
+ if B.dim() == 2:
146
+ B = B.unsqueeze(1)
147
+ if C.dim() == 2:
148
+ C = C.unsqueeze(1)
149
+ if D is not None and D.dim() == 1:
150
+ D = D.unsqueeze(0)
151
+ if z is not None and z.dim() == 2:
152
+ z = z.unsqueeze(1)
153
+ if dt_bias is not None and dt_bias.dim() == 1:
154
+ dt_bias = dt_bias.unsqueeze(0)
155
+ batch, nheads, dim, dstate = state.shape
156
+ assert x.shape == (batch, nheads, dim)
157
+ assert dt.shape == x.shape
158
+ assert A.shape == (nheads, dim, dstate)
159
+ ngroups = B.shape[1]
160
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
161
+ assert B.shape == (batch, ngroups, dstate)
162
+ assert C.shape == B.shape
163
+ if D is not None:
164
+ assert D.shape == (nheads, dim)
165
+ if z is not None:
166
+ assert z.shape == x.shape
167
+ if dt_bias is not None:
168
+ assert dt_bias.shape == (nheads, dim)
169
+ out = torch.empty_like(x)
170
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
171
+ z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
172
+ # We don't want autotune since it will overwrite the state
173
+ # We instead tune by hand.
174
+ BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
175
+ else ((16, 4) if dstate <= 32 else
176
+ ((8, 4) if dstate <= 64 else
177
+ ((4, 4) if dstate <= 128 else
178
+ ((4, 8))))))
179
+ tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
180
+ with torch.cuda.device(x.device.index):
181
+ _selective_scan_update_kernel[grid](
182
+ state, x, dt, dt_bias, A, B, C, D, z, out,
183
+ batch, nheads, dim, dstate, nheads // ngroups,
184
+ state.stride(0), state.stride(1), state.stride(2), state.stride(3),
185
+ x.stride(0), x.stride(1), x.stride(2),
186
+ dt.stride(0), dt.stride(1), dt.stride(2),
187
+ *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
188
+ A.stride(0), A.stride(1), A.stride(2),
189
+ B.stride(0), B.stride(1), B.stride(2),
190
+ C.stride(0), C.stride(1), C.stride(2),
191
+ *(D.stride(0), D.stride(1)) if D is not None else 0,
192
+ z_strides[0], z_strides[1], z_strides[2],
193
+ out.stride(0), out.stride(1), out.stride(2),
194
+ dt_softplus,
195
+ tie_hdim,
196
+ BLOCK_SIZE_M,
197
+ num_warps=num_warps,
198
+ )
199
+ if not has_heads:
200
+ out = out.squeeze(1)
201
+ return out
202
+
203
+
204
+ def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
205
+ """
206
+ Argument:
207
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
208
+ x: (batch, dim) or (batch, nheads, dim)
209
+ dt: (batch, dim) or (batch, nheads, dim)
210
+ A: (dim, dstate) or (nheads, dim, dstate)
211
+ B: (batch, dstate) or (batch, ngroups, dstate)
212
+ C: (batch, dstate) or (batch, ngroups, dstate)
213
+ D: (dim,) or (nheads, dim)
214
+ z: (batch, dim) or (batch, nheads, dim)
215
+ dt_bias: (dim,) or (nheads, dim)
216
+ Return:
217
+ out: (batch, dim) or (batch, nheads, dim)
218
+ """
219
+ has_heads = state.dim() > 3
220
+ if state.dim() == 3:
221
+ state = state.unsqueeze(1)
222
+ if x.dim() == 2:
223
+ x = x.unsqueeze(1)
224
+ if dt.dim() == 2:
225
+ dt = dt.unsqueeze(1)
226
+ if A.dim() == 2:
227
+ A = A.unsqueeze(0)
228
+ if B.dim() == 2:
229
+ B = B.unsqueeze(1)
230
+ if C.dim() == 2:
231
+ C = C.unsqueeze(1)
232
+ if D is not None and D.dim() == 1:
233
+ D = D.unsqueeze(0)
234
+ if z is not None and z.dim() == 2:
235
+ z = z.unsqueeze(1)
236
+ if dt_bias is not None and dt_bias.dim() == 1:
237
+ dt_bias = dt_bias.unsqueeze(0)
238
+ batch, nheads, dim, dstate = state.shape
239
+ assert x.shape == (batch, nheads, dim)
240
+ assert dt.shape == x.shape
241
+ assert A.shape == (nheads, dim, dstate)
242
+ ngroups = B.shape[1]
243
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
244
+ assert B.shape == (batch, ngroups, dstate)
245
+ assert C.shape == B.shape
246
+ if D is not None:
247
+ assert D.shape == (nheads, dim)
248
+ if z is not None:
249
+ assert z.shape == x.shape
250
+ if dt_bias is not None:
251
+ assert dt_bias.shape == (nheads, dim)
252
+ dt = dt + dt_bias
253
+ dt = F.softplus(dt) if dt_softplus else dt
254
+ dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate)
255
+ B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
256
+ C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
257
+ dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
258
+ state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1"))  # (batch, nheads, dim, dstate)
259
+ out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
260
+ if D is not None:
261
+ out += (x * D).to(out.dtype)
262
+ out = (out if z is None else out * F.silu(z)).to(x.dtype)
263
+ if not has_heads:
264
+ out = out.squeeze(1)
265
+ return out
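For orientation, a small self-contained call of the pure-PyTorch reference defined above; the shapes are made up for illustration, and the Triton kernel takes the same arguments but requires a CUDA device:

import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update_ref

batch, nheads, dim, dstate = 2, 4, 64, 16
state = torch.zeros(batch, nheads, dim, dstate)   # SSM state, updated in place
x = torch.randn(batch, nheads, dim)               # input for one decoding step
dt = torch.rand(batch, nheads, dim)
dt_bias = torch.zeros(nheads, dim)
A = -torch.rand(nheads, dim, dstate)              # negative entries keep the recurrence stable
B = torch.randn(batch, 1, dstate)                 # ngroups = 1, shared across heads
C = torch.randn(batch, 1, dstate)
D = torch.ones(nheads, dim)

out = selective_state_update_ref(state, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True)
print(out.shape)  # torch.Size([2, 4, 64])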
mamba/build/lib/mamba_ssm/ops/triton/softplus.py ADDED
@@ -0,0 +1,17 @@
+ import triton
+ import triton.language as tl
+ from packaging import version
+
+ TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
+
+
+ if TRITON3:
+     @triton.jit
+     def softplus(dt):
+         dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
+         return dt
+ else:
+     @triton.jit
+     def softplus(dt):
+         dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+         return dt
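The thresholded softplus above follows the usual formulation softplus(x) = log(1 + exp(x)), passing x through unchanged once x > 20 to avoid overflow. A quick CPU sanity check of the same formula against PyTorch (illustrative only, not part of the upload):

import torch
import torch.nn.functional as F

dt = torch.linspace(-5.0, 30.0, steps=8)
manual = torch.where(dt <= 20.0, torch.log1p(torch.exp(dt)), dt)
assert torch.allclose(manual, F.softplus(dt, threshold=20.0))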
mamba/build/lib/mamba_ssm/ops/triton/ssd_bmm.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ def init_to_zero(names):
17
+ return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
18
+
19
+
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
23
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
24
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
25
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
26
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
27
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
28
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
29
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
30
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
31
+ ],
32
+ key=['chunk_size', 'K', 'IS_CAUSAL'],
33
+ )
34
+ @triton.jit
35
+ def _bmm_chunk_fwd_kernel(
36
+ # Pointers to matrices
37
+ a_ptr, b_ptr, out_ptr, seq_idx_ptr,
38
+ # Matrix dimensions
39
+ seqlen, chunk_size, K, ngroups,
40
+ stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
41
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
42
+ stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
43
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
44
+ # Meta-parameters
45
+ IS_CAUSAL: tl.constexpr,
46
+ dot_dtype: tl.constexpr,
47
+ HAS_SEQ_IDX: tl.constexpr,
48
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
49
+ ):
50
+ pid_b = tl.program_id(axis=1)
51
+ pid_ch = tl.program_id(axis=2)
52
+ pid_c = pid_ch // ngroups
53
+ pid_h = pid_ch - pid_c * ngroups
54
+ num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
55
+ pid_m = tl.program_id(axis=0) // num_pid_n
56
+ pid_n = tl.program_id(axis=0) % num_pid_n
57
+ if IS_CAUSAL:
58
+ if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
59
+ return
60
+ a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
61
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
62
+ if HAS_SEQ_IDX:
63
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
64
+
65
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
66
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
67
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
68
+ a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
69
+ b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
70
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
71
+
72
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
73
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
74
+ a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
75
+ b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
76
+ acc += tl.dot(a, b)
77
+ a_ptrs += BLOCK_SIZE_K * stride_ak
78
+ b_ptrs += BLOCK_SIZE_K * stride_bk
79
+
80
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
81
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
82
+ if HAS_SEQ_IDX:
83
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
84
+ seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
85
+ seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
86
+ acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
87
+ out = acc.to(out_ptr.dtype.element_ty)
88
+
89
+ out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
90
+ out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
91
+ tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
92
+
93
+
94
+ @triton.autotune(
95
+ configs=[
96
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
97
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
98
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
99
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
100
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
101
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
102
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
103
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
104
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
105
+ ],
106
+ key=['chunk_size', 'K'],
107
+ )
108
+ @triton.jit
109
+ def _bmm_chunk_bwd_kernel(
110
+ # Pointers to matrices
111
+ a_ptr, dout_ptr, db_ptr, res_ptr,
112
+ # Matrix dimensions
113
+ seqlen, chunk_size, K, ngroups,
114
+ stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
115
+ stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
116
+ stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
117
+ stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
118
+ # Meta-parameters
119
+ dot_dtype: tl.constexpr,
120
+ HAS_RESIDUAL: tl.constexpr,
121
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
122
+ ):
123
+ pid_b = tl.program_id(axis=1)
124
+ pid_ch = tl.program_id(axis=2)
125
+ pid_c = pid_ch // ngroups
126
+ pid_h = pid_ch - pid_c * ngroups
127
+ num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
128
+ pid_m = tl.program_id(axis=0) // num_pid_n
129
+ pid_n = tl.program_id(axis=0) % num_pid_n
130
+
131
+ a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
132
+ dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head
133
+
134
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
135
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
136
+ offs_cs = tl.arange(0, BLOCK_SIZE_CS)
137
+ dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
138
+ a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
139
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
140
+
141
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
142
+ for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
143
+ dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
144
+ a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
145
+ acc += tl.dot(dout, a)
146
+ dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
147
+ a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen
148
+
149
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
150
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
151
+ if HAS_RESIDUAL:
152
+ res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
153
+ res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
154
+ res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
155
+ acc += res
156
+ db = acc.to(db_ptr.dtype.element_ty)
157
+
158
+ db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
159
+ db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
160
+ tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))
161
+
162
+
163
+ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
164
+ """
165
+ Argument:
166
+ a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
167
+ b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
168
+ seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
169
+ causal: if True, then out[i, j] for i > j will be arbitrary; only out[i, j] for i <= j are
170
+ guaranteed to be correct.
171
+ Return:
172
+ out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
173
+ """
174
+ # Check constraints.
175
+ has_groups = a.dim() == 4
176
+ if not has_groups:
177
+ batch, seqlen, k = a.shape
178
+ else:
179
+ batch, seqlen, ngroups, k = a.shape
180
+ assert b.shape == a.shape
181
+ if seq_idx is not None:
182
+ assert seq_idx.shape == (batch, seqlen)
183
+ if a.stride(-1) != 1 and a.stride(1) != 1:
184
+ a = a.contiguous()
185
+ if b.stride(-1) != 1 and b.stride(1) != 1:
186
+ b = b.contiguous()
187
+ nchunks = math.ceil(seqlen / chunk_size)
188
+ # Allocates output.
189
+ out_dtype = a.dtype if output_dtype is None else output_dtype
190
+ out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
191
+ device=a.device, dtype=out_dtype)
192
+ dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
193
+ (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
194
+ grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
195
+ batch, nchunks if not has_groups else nchunks * ngroups)
196
+ with torch.cuda.device(a.device.index):
197
+ _bmm_chunk_fwd_kernel[grid](
198
+ a, b, out, seq_idx,
199
+ seqlen, chunk_size, k, ngroups if has_groups else 1,
200
+ a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
201
+ b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
202
+ out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
203
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
204
+ causal,
205
+ dot_dtype,
206
+ HAS_SEQ_IDX=seq_idx is not None,
207
+ )
208
+ return out
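Shape-wise, the forward above produces the per-chunk contraction out[..., i, j] = sum_k a[..., i, k] * b[..., j, k] implied by the docstring. The sketch below is a minimal PyTorch reference of that contraction only (no groups, no seq_idx or causal handling, seqlen assumed divisible by chunk_size); the name bmm_chunk_fwd_ref is invented here and is not part of the package.

import torch
from einops import rearrange

def bmm_chunk_fwd_ref(a, b, chunk_size):
    # a, b: (batch, seqlen, k) -> out: (batch, nchunks, chunk_size, chunk_size)
    a_c = rearrange(a, "b (c l) k -> b c l k", l=chunk_size)
    b_c = rearrange(b, "b (c l) k -> b c l k", l=chunk_size)
    # out[..., i, j] = sum_k a[..., i, k] * b[..., j, k], computed chunk by chunk
    return torch.einsum("bclk,bcsk->bcls", a_c.float(), b_c.float())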
209
+
210
+
211
+ def _bmm_chunk_bwd(a, dout, residual=None, out=None):
212
+ """
213
+ Argument:
214
+ a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
215
+ dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
216
+ residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
217
+ Return:
218
+ out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
219
+
220
+ If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
221
+ zeroed out before calling this function.
222
+ """
223
+ # Check constraints.
224
+ has_groups = a.dim() == 4
225
+ if not has_groups:
226
+ batch, seqlen, k = a.shape
227
+ else:
228
+ batch, seqlen, ngroups, k = a.shape
229
+ nchunks, chunk_size = dout.shape[1], dout.shape[-1]
230
+ if a.stride(-1) != 1 and a.stride(-2) != 1:
231
+ a = a.contiguous()
232
+ if dout.stride(-1) != 1 and dout.stride(-2) != 1:
233
+ dout = dout.contiguous()
234
+ if residual is not None:
235
+ assert residual.shape == ((batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k))
236
+ if residual.stride(-1) != 1 and residual.stride(1) != 1:
237
+ residual = residual.contiguous()
238
+ # Allocates output.
239
+ if out is not None:
240
+ assert out.shape == a.shape
241
+ assert out.stride(-1) == 1 or out.stride(1) == 1
242
+ else:
243
+ out = torch.empty_like(a)
244
+ dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
245
+ (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
246
+ grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
247
+ nchunks if not has_groups else nchunks * ngroups)
248
+ residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
249
+ residual.stride(-1))
250
+ if residual is not None else (0, 0, 0, 0))
251
+ with torch.cuda.device(a.device.index):
252
+ _bmm_chunk_bwd_kernel[grid](
253
+ a, dout, out, residual,
254
+ seqlen, chunk_size, k, ngroups if has_groups else 1,
255
+ a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
256
+ dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
257
+ out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
258
+ residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
259
+ dot_dtype,
260
+ HAS_RESIDUAL=residual is not None,
261
+ )
262
+ return out
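For reference, the backward above is the gradient of the per-chunk a @ b^T product with respect to its second operand. A hedged PyTorch sketch (no groups, seqlen assumed divisible by chunk_size; bmm_chunk_bwd_ref is a name made up for this illustration):

import torch
from einops import rearrange

def bmm_chunk_bwd_ref(a, dout, residual=None):
    # a: (batch, seqlen, k), dout: (batch, nchunks, chunk_size, chunk_size)
    chunk_size = dout.shape[-1]
    a_c = rearrange(a, "b (c l) k -> b c l k", l=chunk_size)
    # db[..., s, k] = sum_l dout[..., l, s] * a[..., l, k]
    db = torch.einsum("bcls,bclk->bcsk", dout.float(), a_c.float())
    db = rearrange(db, "b c s k -> b (c s) k")
    return db if residual is None else db + residual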
mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_scan.py ADDED
The diff for this file is too large to render. See raw diff
 
mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_state.py ADDED
@@ -0,0 +1,988 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from mamba_ssm.ops.triton.softplus import softplus
16
+
17
+
18
+ def init_to_zero(names):
19
+ return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
20
+
21
+ @triton.autotune(
22
+ configs=[
23
+ triton.Config({'BLOCK_SIZE_H': 1}),
24
+ triton.Config({'BLOCK_SIZE_H': 2}),
25
+ triton.Config({'BLOCK_SIZE_H': 4}),
26
+ triton.Config({'BLOCK_SIZE_H': 8}),
27
+ triton.Config({'BLOCK_SIZE_H': 16}),
28
+ triton.Config({'BLOCK_SIZE_H': 32}),
29
+ triton.Config({'BLOCK_SIZE_H': 64}),
30
+ ],
31
+ key=['chunk_size', 'nheads'],
32
+ )
33
+ @triton.jit
34
+ def _chunk_cumsum_fwd_kernel(
35
+ # Pointers to matrices
36
+ dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,
37
+ # Matrix dimension
38
+ batch, seqlen, nheads, chunk_size,
39
+ dt_min, dt_max,
40
+ # Strides
41
+ stride_dt_batch, stride_dt_seqlen, stride_dt_head,
42
+ stride_A_head,
43
+ stride_dt_bias_head,
44
+ stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,
45
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
46
+ # Meta-parameters
47
+ DT_SOFTPLUS: tl.constexpr,
48
+ HAS_DT_BIAS: tl.constexpr,
49
+ BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
50
+ ):
51
+ pid_b = tl.program_id(axis=0)
52
+ pid_c = tl.program_id(axis=1)
53
+ pid_h = tl.program_id(axis=2)
54
+ dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
55
+ dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
56
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
57
+
58
+ offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
59
+ offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
60
+ dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
61
+ A_ptrs = A_ptr + offs_h * stride_A_head
62
+ dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)
63
+ dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)
64
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
65
+
66
+ dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
67
+ if HAS_DT_BIAS:
68
+ dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
69
+ dt += dt_bias[:, None]
70
+ if DT_SOFTPLUS:
71
+ dt = softplus(dt)
72
+ # As of Triton 2.2.0, tl.clamp is not available yet
73
+ # dt = tl.clamp(dt, dt_min, dt_max)
74
+ dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
75
+ dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
76
+ tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
77
+ A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
78
+ dA = dt * A[:, None]
79
+ dA_cs = tl.cumsum(dA, axis=1)
80
+ tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
81
+
82
+
83
+ @triton.autotune(
84
+ configs=[
85
+ triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
86
+ triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
87
+ triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
88
+ triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
89
+ triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
90
+ triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
91
+ triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
92
+ ],
93
+ key=['chunk_size', 'nheads'],
94
+ )
95
+ @triton.jit
96
+ def _chunk_cumsum_bwd_kernel(
97
+ # Pointers to matrices
98
+ ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,
99
+ ddt_ptr, dA_ptr, ddt_bias_ptr,
100
+ # Matrix dimensions
101
+ batch, seqlen, nheads, chunk_size,
102
+ dt_min, dt_max,
103
+ # Strides
104
+ stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,
105
+ stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,
106
+ stride_dt_batch, stride_dt_seqlen, stride_dt_head,
107
+ stride_A_head,
108
+ stride_dt_bias_head,
109
+ stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,
110
+ stride_dA_head,
111
+ stride_ddt_bias_head,
112
+ # Meta-parameters
113
+ DT_SOFTPLUS: tl.constexpr,
114
+ HAS_DT_BIAS: tl.constexpr,
115
+ BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
116
+ ):
117
+ pid_b = tl.program_id(axis=0)
118
+ pid_c = tl.program_id(axis=1)
119
+ pid_h = tl.program_id(axis=2)
120
+ ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
121
+ ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
122
+ dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
123
+ ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
124
+
125
+ offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
126
+ offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
127
+ ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)
128
+ ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize)
129
+ dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
130
+ ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)
131
+ A_ptrs = A_ptr + offs_h * stride_A_head
132
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
133
+
134
+ ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
135
+ ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
136
+ A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
137
+ ddt = ddA * A[:, None] + ddt_out
138
+ dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
139
+ if HAS_DT_BIAS:
140
+ dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
141
+ dt += dt_bias[:, None]
142
+ if DT_SOFTPLUS:
143
+ dt_presoftplus = dt
144
+ dt = softplus(dt)
145
+ clamp_mask = (dt < dt_min) | (dt > dt_max)
146
+ # As of Triton 2.2.0, tl.clamp is not available yet
147
+ # dt = tl.clamp(dt, dt_min, dt_max)
148
+ dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
149
+ dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
150
+ ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)
151
+ ddt = tl.where(clamp_mask, 0.0, ddt)
152
+ if DT_SOFTPLUS:
153
+ ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
154
+ tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))
155
+ dA = tl.sum(ddA * dt, axis=1)
156
+ tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
157
+ if HAS_DT_BIAS:
158
+ ddt_bias = tl.sum(ddt, axis=1)
159
+ tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)
160
+
161
+
162
+ @triton.autotune(
163
+ configs=[
164
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
165
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
166
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
167
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
168
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
169
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
170
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
171
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
172
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
173
+ ],
174
+ key=['hdim', 'dstate', 'chunk_size'],
175
+ )
176
+ @triton.jit
177
+ def _chunk_state_fwd_kernel(
178
+ # Pointers to matrices
179
+ x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
180
+ # Matrix dimensions
181
+ hdim, dstate, chunk_size,
182
+ batch, seqlen, nheads_ngroups_ratio,
183
+ # Strides
184
+ stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
185
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
186
+ stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
187
+ stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
188
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
189
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
190
+ # Meta-parameters
191
+ HAS_SEQ_IDX: tl.constexpr,
192
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
193
+ ):
194
+ pid_bc = tl.program_id(axis=1)
195
+ pid_c = pid_bc // batch
196
+ pid_b = pid_bc - pid_c * batch
197
+ pid_h = tl.program_id(axis=2)
198
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
199
+ pid_m = tl.program_id(axis=0) // num_pid_n
200
+ pid_n = tl.program_id(axis=0) % num_pid_n
201
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
202
+ x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
203
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
204
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
205
+ if HAS_SEQ_IDX:
206
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
207
+
208
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
209
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
210
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
211
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
212
+ b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
213
+ dt_ptrs = dt_ptr + offs_k * stride_dt_csize
214
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
215
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
216
+ if HAS_SEQ_IDX:
217
+ seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
218
+
219
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
220
+ if HAS_SEQ_IDX:
221
+ seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
222
+
223
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
224
+ for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
225
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0)
226
+ b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
227
+ dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
228
+ if HAS_SEQ_IDX:
229
+ seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
230
+ dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
231
+ if not HAS_SEQ_IDX:
232
+ scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
233
+ else:
234
+ scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
235
+ b *= scale[:, None]
236
+ b = b.to(x_ptr.dtype.element_ty)
237
+ acc += tl.dot(x, b)
238
+ x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
239
+ b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
240
+ dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
241
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
242
+ if HAS_SEQ_IDX:
243
+ seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
244
+ states = acc.to(states_ptr.dtype.element_ty)
245
+
246
+ states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
247
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
248
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
249
+ states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
250
+ c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
251
+ tl.store(states_ptrs, states, mask=c_mask)
252
+
253
+
254
+ @triton.autotune(
255
+ configs=[
256
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
257
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
258
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
259
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
260
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
261
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
262
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
263
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
264
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
265
+ ],
266
+ key=['chunk_size', 'hdim', 'dstate'],
267
+ )
268
+ @triton.jit
269
+ def _chunk_state_bwd_dx_kernel(
270
+ # Pointers to matrices
271
+ x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr,
272
+ dx_ptr, ddt_ptr, ddA_cumsum_ptr,
273
+ # Matrix dimensions
274
+ chunk_size, hdim, dstate,
275
+ batch, seqlen, nheads_ngroups_ratio,
276
+ # Strides
277
+ stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
278
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
279
+ stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
280
+ stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
281
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
282
+ stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
283
+ stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
284
+ stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
285
+ # Meta-parameters
286
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
287
+ BLOCK_SIZE_DSTATE: tl.constexpr,
288
+ ):
289
+ pid_bc = tl.program_id(axis=1)
290
+ pid_c = pid_bc // batch
291
+ pid_b = pid_bc - pid_c * batch
292
+ pid_h = tl.program_id(axis=2)
293
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
294
+ pid_m = tl.program_id(axis=0) // num_pid_n
295
+ pid_n = tl.program_id(axis=0) % num_pid_n
296
+ x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
297
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
298
+ dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
299
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
300
+ ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
301
+ ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
302
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
303
+
304
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
305
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
306
+
307
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
308
+ # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
309
+ offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
310
+ b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
311
+ dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
312
+ if BLOCK_SIZE_DSTATE <= 128:
313
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
314
+ dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
315
+ dstates = dstates.to(b_ptr.dtype.element_ty)
316
+ acc = tl.dot(b, dstates)
317
+ else:
318
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
319
+ for k in range(0, dstate, BLOCK_SIZE_K):
320
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
321
+ dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
322
+ dstates = dstates.to(b_ptr.dtype.element_ty)
323
+ acc += tl.dot(b, dstates)
324
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
325
+ dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
326
+
327
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
328
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
329
+
330
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
331
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
332
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
333
+ dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
334
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
335
+ acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
336
+
337
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
338
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
339
+ ddt = tl.sum(acc * x, axis=1)
340
+ ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
341
+ tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
342
+ ddA_cs = -(ddt * dt_m)
343
+ ddA_cs_last = -tl.sum(ddA_cs)
344
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
345
+ tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
346
+ tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
347
+
348
+ dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
349
+ dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
350
+ dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
351
+ tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
352
+
353
+
354
+ @triton.autotune(
355
+ configs=[
356
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
357
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
358
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
359
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
360
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
361
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
362
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
363
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
364
+ ],
365
+ key=['chunk_size', 'dstate', 'hdim'],
366
+ )
367
+ @triton.jit
368
+ def _chunk_state_bwd_db_kernel(
369
+ # Pointers to matrices
370
+ x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
371
+ db_ptr, ddA_cumsum_ptr,
372
+ # Matrix dimensions
373
+ chunk_size, dstate, hdim,
374
+ batch, seqlen, nheads, nheads_per_program, ngroups,
375
+ # Strides
376
+ stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
377
+ stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
378
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
379
+ stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
380
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
381
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
382
+ stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate,
383
+ stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
384
+ # Meta-parameters
385
+ HAS_DDA_CS: tl.constexpr,
386
+ HAS_SEQ_IDX: tl.constexpr,
387
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
388
+ ):
389
+ pid_bc = tl.program_id(axis=1)
390
+ pid_c = pid_bc // batch
391
+ pid_b = pid_bc - pid_c * batch
392
+ pid_sg = tl.program_id(axis=2)
393
+ pid_s = pid_sg // ngroups
394
+ pid_g = pid_sg - pid_s * ngroups
395
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
396
+ pid_m = tl.program_id(axis=0) // num_pid_n
397
+ pid_n = tl.program_id(axis=0) % num_pid_n
398
+ x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
399
+ db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split
400
+ dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head
401
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
402
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
403
+ if HAS_DDA_CS:
404
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head
405
+ ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head
406
+ if HAS_SEQ_IDX:
407
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
408
+
409
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
410
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
411
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
412
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim)
413
+ dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim)
414
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
415
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
416
+ if HAS_DDA_CS:
417
+ b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate)
418
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
419
+
420
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
421
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
422
+ if HAS_DDA_CS:
423
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
424
+ if HAS_SEQ_IDX:
425
+ seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
426
+ seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
427
+ nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
428
+ for h in range(nheads_iter):
429
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
430
+ dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
431
+ dstates = dstates.to(x_ptrs.dtype.element_ty)
432
+ db = tl.dot(x, dstates)
433
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
434
+ dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
435
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
436
+ if not HAS_SEQ_IDX:
437
+ scale = tl.exp(dA_cs_last - dA_cs_m)
438
+ else:
439
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
440
+ db *= (scale * dt_m)[:, None]
441
+ if HAS_DDA_CS:
442
+ # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
443
+ ddA_cs = tl.sum(db * b, axis=1)
444
+ tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
445
+ acc += db
446
+ x_ptrs += stride_x_head
447
+ dstates_ptrs += stride_states_head
448
+ dt_ptrs += stride_dt_head
449
+ dA_cumsum_ptr += stride_dA_cs_head
450
+ dA_cumsum_ptrs += stride_dA_cs_head
451
+ if HAS_DDA_CS:
452
+ ddA_cumsum_ptrs += stride_ddA_cs_head
453
+
454
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
455
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
456
+ # if HAS_SEQ_IDX:
457
+ # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
458
+ # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
459
+ # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
460
+ db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate)
461
+ tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))
462
+
463
+
464
+ @triton.autotune(
465
+ configs=[
466
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
467
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
468
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
469
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
470
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
471
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
472
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
473
+ # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
474
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
475
+ triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
476
+ triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
477
+ triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
478
+ triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
479
+ triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
480
+ triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
481
+ triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
482
+ triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
483
+ ],
484
+ key=['chunk_size', 'hdim', 'dstate'],
485
+ )
486
+ @triton.jit
487
+ def _chunk_state_bwd_ddAcs_stable_kernel(
488
+ # Pointers to matrices
489
+ x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
490
+ ddA_cumsum_ptr,
491
+ # Matrix dimensions
492
+ chunk_size, hdim, dstate,
493
+ batch, seqlen, nheads_ngroups_ratio,
494
+ # Strides
495
+ stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
496
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
497
+ stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
498
+ stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
499
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
500
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
501
+ stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
502
+ # Meta-parameters
503
+ HAS_SEQ_IDX: tl.constexpr,
504
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
505
+ BLOCK_SIZE_DSTATE: tl.constexpr,
506
+ ):
507
+ pid_bc = tl.program_id(axis=1)
508
+ pid_c = pid_bc // batch
509
+ pid_b = pid_bc - pid_c * batch
510
+ pid_h = tl.program_id(axis=2)
511
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
512
+ pid_m = tl.program_id(axis=0) // num_pid_n
513
+ pid_n = tl.program_id(axis=0) % num_pid_n
514
+ x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
515
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
516
+ dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
517
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
518
+ ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
519
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
520
+ if HAS_SEQ_IDX:
521
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
522
+
523
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
524
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
525
+
526
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
527
+ # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
528
+ offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
529
+ b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
530
+ dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
531
+ if BLOCK_SIZE_DSTATE <= 128:
532
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
533
+ dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
534
+ dstates = dstates.to(b_ptr.dtype.element_ty)
535
+ acc = tl.dot(b, dstates)
536
+ else:
537
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
538
+ for k in range(0, dstate, BLOCK_SIZE_K):
539
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
540
+ dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
541
+ dstates = dstates.to(b_ptr.dtype.element_ty)
542
+ acc += tl.dot(b, dstates)
543
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
544
+ dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
545
+
546
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
547
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
548
+
549
+ dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
550
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
551
+ if not HAS_SEQ_IDX:
552
+ scale = tl.exp(dA_cs_last - dA_cs_m)
553
+ else:
554
+ seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
555
+ seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
556
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
557
+ acc *= scale[:, None]
558
+
559
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
560
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
561
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
562
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
563
+ ddt = tl.sum(acc * x, axis=1)
564
+ # ddA_cs = -(ddt * dt_m)
565
+ # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
566
+ # then call torch.cumsum outside this kernel.
567
+ # ddA_cs = tl.cumsum(ddt * dt_m)
568
+ ddA_cs = ddt * dt_m
569
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
570
+ # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
571
+ tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
572
+
573
+
574
+ @triton.autotune(
575
+ configs=[
576
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
577
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
578
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
579
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
580
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
581
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
582
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
583
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
584
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
585
+ ],
586
+ key=['hdim', 'dstate', 'chunk_size'],
587
+ )
588
+ @triton.jit
589
+ def _chunk_state_varlen_kernel(
590
+ # Pointers to matrices
591
+ x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr,
592
+ # Matrix dimensions
593
+ hdim, dstate, chunk_size,
594
+ seqlen, nheads_ngroups_ratio,
595
+ # Strides
596
+ stride_x_seqlen, stride_x_head, stride_x_hdim,
597
+ stride_b_seqlen, stride_b_head, stride_b_dstate,
598
+ stride_dt_chunk, stride_dt_head, stride_dt_csize,
599
+ stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
600
+ stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate,
601
+ stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate,
602
+ # Meta-parameters
603
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
604
+ ):
605
+ pid_b = tl.program_id(axis=1)
606
+ pid_h = tl.program_id(axis=2)
607
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
608
+ pid_m = tl.program_id(axis=0) // num_pid_n
609
+ pid_n = tl.program_id(axis=0) % num_pid_n
610
+ end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
611
+ pid_c = (end_idx - 1) // chunk_size
612
+ b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
613
+ x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
614
+ dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
615
+ dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
616
+ chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
617
+
618
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
619
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
620
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
621
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
622
+ b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
623
+ dt_ptrs = dt_ptr + offs_k * stride_dt_csize
624
+ dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
625
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
626
+
627
+ chunk_size_limit = end_idx - pid_c * chunk_size
628
+ start_idx = tl.load(cu_seqlens_ptr + pid_b)
629
+ start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
630
+
631
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
632
+ for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
633
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0)
634
+ b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
635
+ dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
636
+ dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
637
+ scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
638
+ tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
639
+ b *= scale[:, None]
640
+ b = b.to(x_ptr.dtype.element_ty)
641
+ acc += tl.dot(x, b)
642
+ x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
643
+ b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
644
+ dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
645
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
646
+
647
+ # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
648
+ if start_idx < pid_c * chunk_size:
649
+ chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate)
650
+ chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
651
+ # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
652
+ scale = tl.exp(dA_cs_last)
653
+ acc += chunk_states * scale
654
+
655
+ states = acc.to(states_ptr.dtype.element_ty)
656
+
657
+ states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
658
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
659
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
660
+ states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
661
+ c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
662
+ tl.store(states_ptrs, states, mask=c_mask)
663
+
664
+
665
+ def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
666
+ batch, seqlen, nheads = dt.shape
667
+ assert A.shape == (nheads,)
668
+ if dt_bias is not None:
669
+ assert dt_bias.shape == (nheads,)
670
+ nchunks = math.ceil(seqlen / chunk_size)
671
+ dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
672
+ dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
673
+ grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
674
+ with torch.cuda.device(dt.device.index):
675
+ _chunk_cumsum_fwd_kernel[grid_chunk_cs](
676
+ dt, A, dt_bias, dt_out, dA_cumsum,
677
+ batch, seqlen, nheads, chunk_size,
678
+ dt_limit[0], dt_limit[1],
679
+ dt.stride(0), dt.stride(1), dt.stride(2),
680
+ A.stride(0),
681
+ dt_bias.stride(0) if dt_bias is not None else 0,
682
+ dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),
683
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
684
+ dt_softplus,
685
+ HAS_DT_BIAS=dt_bias is not None,
686
+ BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
687
+ )
688
+ return dA_cumsum, dt_out
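A rough PyTorch rendering of what _chunk_cumsum_fwd returns, for intuition only: per chunk, dt is bias-added, optionally softplus'd, clamped to dt_limit, and dA_cumsum is the running sum of dt * A along the chunk. It assumes seqlen is a multiple of chunk_size (the kernel additionally zero-pads the last partial chunk); chunk_cumsum_fwd_ref is a hypothetical name for this sketch.

import torch
import torch.nn.functional as F
from einops import rearrange

def chunk_cumsum_fwd_ref(dt, A, chunk_size, dt_bias=None, dt_softplus=False,
                         dt_limit=(0.0, float("inf"))):
    # dt: (batch, seqlen, nheads), A: (nheads,)
    dt = dt.float()
    if dt_bias is not None:
        dt = dt + dt_bias.float()
    if dt_softplus:
        dt = F.softplus(dt)
    dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
    # -> (batch, nheads, nchunks, chunk_size), matching the kernel's output layout
    dt_out = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dA_cumsum = torch.cumsum(dt_out * A.float()[None, :, None, None], dim=-1)
    return dA_cumsum, dt_out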
689
+
690
+
691
+ def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None):
692
+ batch, seqlen, nheads = dt.shape
693
+ _, _, nchunks, chunk_size = ddA.shape
694
+ assert ddA.shape == (batch, nheads, nchunks, chunk_size)
695
+ assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
696
+ assert A.shape == (nheads,)
697
+ if dt_bias is not None:
698
+ assert dt_bias.shape == (nheads,)
699
+ ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
700
+ else:
701
+ ddt_bias = None
702
+ if ddt is not None:
703
+ assert ddt.shape == dt.shape
704
+ else:
705
+ ddt = torch.empty_like(dt)
706
+ dA = torch.empty_like(A, dtype=torch.float32)
707
+ grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
708
+ with torch.cuda.device(dt.device.index):
709
+ _chunk_cumsum_bwd_kernel[grid_chunk_cs](
710
+ ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,
711
+ batch, seqlen, nheads, chunk_size,
712
+ dt_limit[0], dt_limit[1],
713
+ ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),
714
+ ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),
715
+ dt.stride(0), dt.stride(1), dt.stride(2),
716
+ A.stride(0),
717
+ dt_bias.stride(0) if dt_bias is not None else 0,
718
+ ddt.stride(0), ddt.stride(1), ddt.stride(2),
719
+ dA.stride(0),
720
+ ddt_bias.stride(0) if ddt_bias is not None else 0,
721
+ dt_softplus,
722
+ HAS_DT_BIAS=dt_bias is not None,
723
+ BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
724
+ )
725
+ return ddt, dA, ddt_bias
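The backward kernel above relies on the identity d/dx softplus(x) = sigmoid(x), treating the derivative as 1 once the pre-softplus value exceeds 20 (where softplus is numerically the identity). A quick autograd sanity check of that identity, for illustration only:

import torch
import torch.nn.functional as F

x = torch.randn(8, dtype=torch.float64, requires_grad=True)
F.softplus(x).sum().backward()
# gradient of softplus is the logistic sigmoid
assert torch.allclose(x.grad, torch.sigmoid(x.detach()))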
726
+
727
+
728
+ def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True):
729
+ batch, seqlen, nheads, headdim = x.shape
730
+ _, _, nchunks, chunk_size = dt.shape
731
+ _, _, ngroups, dstate = B.shape
732
+ assert nheads % ngroups == 0
733
+ assert B.shape == (batch, seqlen, ngroups, dstate)
734
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
735
+ assert dA_cumsum.shape == dt.shape
736
+ if seq_idx is not None:
737
+ assert seq_idx.shape == (batch, seqlen)
738
+ if states is not None:
739
+ assert states.shape == (batch, nchunks, nheads, headdim, dstate)
740
+ else:
741
+ states_dtype = torch.float32 if states_in_fp32 else B.dtype
742
+ states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype)
743
+ grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
744
+ batch * nchunks, nheads)
745
+ with torch.cuda.device(x.device.index):
746
+ _chunk_state_fwd_kernel[grid](
747
+ x, B, states, dt, dA_cumsum, seq_idx,
748
+ headdim, dstate, chunk_size,
749
+ batch, seqlen, nheads // ngroups,
750
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
751
+ B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
752
+ states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
753
+ dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
754
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
755
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
756
+ HAS_SEQ_IDX=seq_idx is not None,
757
+ )
758
+ return states
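Numerically, the chunk-state forward amounts to the einsum below, with the same decay factor exp(dA_cumsum[..., -1] - dA_cumsum) * dt that scales B inside the kernel. This is a reference sketch only (no seq_idx, seqlen divisible by chunk_size); chunk_state_ref is a made-up name.

import torch
from einops import rearrange, repeat

def chunk_state_ref(B, x, dt, dA_cumsum):
    # x: (b, s, h, p), B: (b, s, g, n), dt / dA_cumsum: (b, h, c, l)
    chunk_size = dt.shape[-1]
    nheads, ngroups = x.shape[2], B.shape[2]
    B = repeat(B, "b s g n -> b s (g r) n", r=nheads // ngroups)  # broadcast groups to heads
    B = rearrange(B, "b (c l) h n -> b c l h n", l=chunk_size)
    x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
    scale = torch.exp(dA_cumsum[:, :, :, -1:] - dA_cumsum) * dt   # (b, h, c, l)
    # states: (batch, nchunks, nheads, headdim, dstate)
    return torch.einsum("bclhn,bhcl,bclhp->bchpn",
                        B.float(), scale.float(), x.float())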
759
+
760
+
761
+ def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
762
+ batch, seqlen, nheads, headdim = x.shape
763
+ _, _, nchunks, chunk_size = dt.shape
764
+ _, _, ngroups, dstate = B.shape
765
+ assert nheads % ngroups == 0
766
+ assert B.shape == (batch, seqlen, ngroups, dstate)
767
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
768
+ assert dA_cumsum.shape == dt.shape
769
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
770
+ if dx is not None:
771
+ assert dx.shape == x.shape
772
+ else:
773
+ dx = torch.empty_like(x)
774
+ ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
775
+ ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32)
776
+ grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
777
+ batch * nchunks, nheads)
778
+ with torch.cuda.device(x.device.index):
779
+ _chunk_state_bwd_dx_kernel[grid_dx](
780
+ x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum,
781
+ chunk_size, headdim, dstate,
782
+ batch, seqlen, nheads // ngroups,
783
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
784
+ B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
785
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
786
+ dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
787
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
788
+ dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
789
+ ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
790
+ ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
791
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
792
+ )
793
+ return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
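For the dx path of the backward above, a reference einsum under the same assumptions (ddt and ddA_cumsum omitted, no seq_idx, seqlen divisible by chunk_size; chunk_state_bwd_dx_ref is invented for this sketch):

import torch
from einops import rearrange, repeat

def chunk_state_bwd_dx_ref(B, dt, dA_cumsum, dstates):
    # B: (b, s, g, n), dt / dA_cumsum: (b, h, c, l), dstates: (b, c, h, p, n)
    chunk_size = dt.shape[-1]
    nheads, ngroups = dstates.shape[2], B.shape[2]
    B = repeat(B, "b s g n -> b s (g r) n", r=nheads // ngroups)
    B = rearrange(B, "b (c l) h n -> b c l h n", l=chunk_size)
    scale = torch.exp(dA_cumsum[:, :, :, -1:] - dA_cumsum) * dt   # (b, h, c, l)
    dx = torch.einsum("bclhn,bchpn,bhcl->bclhp",
                      B.float(), dstates.float(), scale.float())
    return rearrange(dx, "b c l h p -> b (c l) h p")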
794
+
795
+
796
+ def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
797
+ batch, seqlen, nheads, headdim = x.shape
798
+ _, _, nchunks, chunk_size = dt.shape
799
+ dstate = dstates.shape[-1]
800
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
801
+ assert dA_cumsum.shape == dt.shape
802
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
803
+ if seq_idx is not None:
804
+ assert seq_idx.shape == (batch, seqlen)
805
+ if B is not None:
806
+ assert B.shape == (batch, seqlen, ngroups, dstate)
807
+ B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
808
+ # Use torch.empty since the Triton kernel will call init_to_zero
809
+ ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
810
+ ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3))
811
+ else:
812
+ B_strides = (0, 0, 0, 0)
813
+ ddA_cumsum = None
814
+ ddA_cumsum_strides = (0, 0, 0, 0)
815
+ nheads_ngroups_ratio = nheads // ngroups
816
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
817
+ nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
818
+ nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
819
+ dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32)
820
+ grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
821
+ batch * nchunks, nsplits * ngroups)
822
+ with torch.cuda.device(x.device.index):
823
+ _chunk_state_bwd_db_kernel[grid_db](
824
+ x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum,
825
+ chunk_size, dstate, headdim,
826
+ batch, seqlen, nheads, nheads_per_program, ngroups,
827
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
828
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
829
+ *B_strides,
830
+ dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
831
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
832
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
833
+ dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4),
834
+ *ddA_cumsum_strides,
835
+ HAS_DDA_CS=ddA_cumsum is not None,
836
+ HAS_SEQ_IDX=seq_idx is not None,
837
+ BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
838
+ )
839
+ dB = dB.sum(2)
840
+ if ddA_cumsum is not None:
841
+ # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
842
+ # to the state of the chunk.
843
+ # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
844
+ # But it's easier to just do the cumsum for all elements, the result will be the same.
845
+ torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
846
+ return dB if B is None else (dB, ddA_cumsum)
847
+
848
+
849
+ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
850
+ batch, seqlen, nheads, headdim = x.shape
851
+ _, _, nchunks, chunk_size = dt.shape
852
+ _, _, ngroups, dstate = B.shape
853
+ assert nheads % ngroups == 0
854
+ assert B.shape == (batch, seqlen, ngroups, dstate)
855
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
856
+ assert dA_cumsum.shape == dt.shape
857
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
858
+ if seq_idx is not None:
859
+ assert seq_idx.shape == (batch, seqlen)
860
+ # Use torch.empty since the Triton kernel will call init_to_zero
861
+ ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
862
+ grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
863
+ batch * nchunks, nheads)
864
+ with torch.cuda.device(x.device.index):
865
+ _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
866
+ x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum,
867
+ chunk_size, headdim, dstate,
868
+ batch, seqlen, nheads // ngroups,
869
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
870
+ B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
871
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
872
+ dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
873
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
874
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
875
+ ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
876
+ HAS_SEQ_IDX=seq_idx is not None,
877
+ BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
878
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
879
+ )
880
+ torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
881
+ return ddA_cumsum
882
+
883
+
884
+ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
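+     """
+     Argument:
+         B: (total_seqlen, ngroups, dstate)
+         x: (total_seqlen, nheads, headdim)
+         dt: (nheads, nchunks, chunk_size)
+         dA_cumsum: (nheads, nchunks, chunk_size)
+         cu_seqlens: (batch + 1,), cumulative sequence lengths
+         chunk_states: (nchunks, nheads, headdim, dstate)
+     Return:
+         states: (batch, nheads, headdim, dstate), final state for each sequence delimited by cu_seqlens
+     """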
885
+ total_seqlen, nheads, headdim = x.shape
886
+ _, nchunks, chunk_size = dt.shape
887
+ _, ngroups, dstate = B.shape
888
+ batch = cu_seqlens.shape[0] - 1
889
+ cu_seqlens = cu_seqlens.contiguous()
890
+ assert nheads % ngroups == 0
891
+ assert B.shape == (total_seqlen, ngroups, dstate)
892
+ assert dt.shape == (nheads, nchunks, chunk_size)
893
+ assert dA_cumsum.shape == dt.shape
894
+ assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
895
+ states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device)
896
+ grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
897
+ batch, nheads)
898
+ with torch.cuda.device(x.device.index):
899
+ _chunk_state_varlen_kernel[grid](
900
+ x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states,
901
+ headdim, dstate, chunk_size,
902
+ total_seqlen, nheads // ngroups,
903
+ x.stride(0), x.stride(1), x.stride(2),
904
+ B.stride(0), B.stride(1), B.stride(2),
905
+ dt.stride(1), dt.stride(0), dt.stride(2),
906
+ dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2),
907
+ chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3),
908
+ states.stride(0), states.stride(1), states.stride(2), states.stride(3),
909
+ )
910
+ return states
911
+
912
+
913
+ class ChunkStateFn(torch.autograd.Function):
914
+
915
+ @staticmethod
916
+ def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
917
+ batch, seqlen, nheads, headdim = x.shape
918
+ _, _, nchunks, chunk_size = dt.shape
919
+ assert seqlen <= nchunks * chunk_size
920
+ _, _, ngroups, dstate = B.shape
921
+ assert B.shape == (batch, seqlen, ngroups, dstate)
922
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
923
+ assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
924
+ if B.stride(-1) != 1:
925
+ B = B.contiguous()
926
+ if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
927
+ x = x.contiguous()
928
+ states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
929
+ ctx.save_for_backward(B, x, dt, dA_cumsum)
930
+ return states
931
+
932
+ @staticmethod
933
+ def backward(ctx, dstates):
934
+ B, x, dt, dA_cumsum = ctx.saved_tensors
935
+ batch, seqlen, nheads, headdim = x.shape
936
+ _, _, nchunks, chunk_size = dt.shape
937
+ _, _, ngroups, dstate = B.shape
938
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
939
+ if dstates.stride(-1) != 1:
940
+ dstates = dstates.contiguous()
941
+ dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
942
+ dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
943
+ dB = dB.to(B.dtype)
944
+ return dB, dx, ddt, ddA_cumsum, None
945
+
946
+
947
+ def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
948
+ """
949
+ Argument:
950
+ B: (batch, seqlen, ngroups, dstate)
951
+ x: (batch, seqlen, nheads, headdim)
952
+ dt: (batch, nheads, nchunks, chunk_size)
953
+ dA_cumsum: (batch, nheads, nchunks, chunk_size)
954
+ Return:
955
+ states: (batch, nchunks, nheads, headdim, dstate)
956
+ """
957
+ return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
958
+
959
+
960
+ def chunk_state_ref(B, x, dt, dA_cumsum):
961
+ """
962
+ Argument:
963
+ B: (batch, seqlen, ngroups, dstate)
964
+ x: (batch, seqlen, nheads, headdim)
965
+ dt: (batch, nheads, nchunks, chunk_size)
966
+ dA_cumsum: (batch, nheads, nchunks, chunk_size)
967
+ Return:
968
+ states: (batch, nchunks, nheads, headdim, dstate)
969
+ """
970
+ # Check constraints.
971
+ batch, seqlen, nheads, headdim = x.shape
972
+ dstate = B.shape[-1]
973
+ _, _, nchunks, chunk_size = dt.shape
974
+ assert seqlen <= nchunks * chunk_size
975
+ assert x.shape == (batch, seqlen, nheads, headdim)
976
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
977
+ ngroups = B.shape[2]
978
+ assert nheads % ngroups == 0
979
+ assert B.shape == (batch, seqlen, ngroups, dstate)
980
+ B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
981
+ assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
982
+ if seqlen < nchunks * chunk_size:
983
+ x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
984
+ B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
985
+ x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
986
+ B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
987
+ decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
988
+ return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)
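+
+
+ # A minimal sanity-check sketch (not part of the original file): it compares the
+ # Triton-backed chunk_state above against the einsum reference chunk_state_ref on
+ # random inputs. The shapes, device, and tolerances below are illustrative
+ # assumptions, not values prescribed by the library.
+ def _check_chunk_state_against_ref(batch=2, seqlen=512, nheads=8, headdim=64,
+                                    ngroups=1, dstate=16, chunk_size=128, device="cuda"):
+     nchunks = seqlen // chunk_size
+     x = torch.randn(batch, seqlen, nheads, headdim, device=device)
+     B = torch.randn(batch, seqlen, ngroups, dstate, device=device)
+     # dt > 0 and A < 0 mimic the usual SSM parameterization.
+     dt = F.softplus(torch.randn(batch, nheads, nchunks, chunk_size, device=device) - 4)
+     A = -torch.exp(torch.rand(nheads, device=device))
+     dA_cumsum = torch.cumsum(dt * rearrange(A, "h -> h 1 1"), dim=-1)
+     out = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
+     out_ref = chunk_state_ref(B, x, dt, dA_cumsum)
+     assert torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2)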
mamba/build/lib/mamba_ssm/ops/triton/ssd_combined.py ADDED
@@ -0,0 +1,981 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ from typing import Optional
7
+
8
+ import math
9
+ from packaging import version
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ from torch.cuda.amp import custom_bwd, custom_fwd
15
+
16
+ import triton
17
+ import triton.language as tl
18
+
19
+ from einops import rearrange, repeat
20
+
21
+ try:
22
+ from causal_conv1d import causal_conv1d_fn
23
+ import causal_conv1d_cuda
24
+ except ImportError:
25
+ causal_conv1d_fn, causal_conv1d_cuda = None, None
26
+
27
+ from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
28
+ from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
29
+ from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
30
+ from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
31
+ from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
32
+ from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen
33
+ from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd
34
+ from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref
35
+ from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
36
+ from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
37
+ from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
38
+ from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
39
+ from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
40
+ from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
41
+ from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd
42
+
43
+ TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
44
+
45
+
46
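+ # Returns a callable suitable as a Triton autotune pre_hook: it zeroes the named
+ # kernel arguments (skipping ones that are None) before the kernel is launched.
+ # This is needed because kernels such as _chunk_scan_chunk_state_bwd_dx_kernel
+ # accumulate into ddt_ptr with tl.atomic_add.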
+ def init_to_zero(names):
47
+ return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
48
+
49
+
50
+ @triton.autotune(
51
+ configs=[
52
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
53
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
54
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
55
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
56
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
57
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
58
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
59
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
60
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
61
+ ],
62
+ key=['chunk_size', 'hdim', 'dstate'],
63
+ )
64
+ @triton.jit
65
+ def _chunk_scan_chunk_state_bwd_dx_kernel(
66
+ # Pointers to matrices
67
+ x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,
68
+ b_ptr, dstates_ptr,
69
+ dx_ptr, ddt_ptr, dD_ptr,
70
+ # Matrix dimensions
71
+ chunk_size, hdim, dstate,
72
+ batch, seqlen, nheads_ngroups_ratio,
73
+ # Strides
74
+ stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
75
+ stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
76
+ stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
77
+ stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
78
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
79
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
80
+ stride_D_head,
81
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
82
+ stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,
83
+ stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
84
+ stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
85
+ stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
86
+ # Meta-parameters
87
+ HAS_D: tl.constexpr,
88
+ D_HAS_HDIM: tl.constexpr,
89
+ HAS_SEQ_IDX: tl.constexpr,
90
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
91
+ BLOCK_SIZE_DSTATE: tl.constexpr,
92
+ IS_TRITON_22: tl.constexpr,
93
+ ):
94
+ pid_bc = tl.program_id(axis=1)
95
+ pid_c = pid_bc // batch
96
+ pid_b = pid_bc - pid_c * batch
97
+ pid_h = tl.program_id(axis=2)
98
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
99
+ pid_m = tl.program_id(axis=0) // num_pid_n
100
+ pid_n = tl.program_id(axis=0) % num_pid_n
101
+ x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
102
+ cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
103
+ dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
104
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
105
+ ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
106
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
107
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
108
+ dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head
109
+ if HAS_SEQ_IDX:
110
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
111
+
112
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
113
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
114
+
115
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
116
+
117
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
118
+
119
+ dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
120
+
121
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
122
+ if not HAS_SEQ_IDX:
123
+ scale = tl.exp(dA_cs_last - dA_cs_m)
124
+ else:
125
+ seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
126
+ seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
127
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
128
+ # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
129
+ # However, we're getting error with the Triton compiler 2.1.0 for that code path:
130
+ # Unexpected mma -> mma layout conversion
131
+ # Triton 2.2.0 fixes this
132
+ offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
133
+ b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)
134
+ dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)
135
+ if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
136
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)
137
+ dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
138
+ dstates = dstates.to(b_ptr.dtype.element_ty)
139
+ acc = tl.dot(b, dstates) * scale[:, None]
140
+ else:
141
+ for k in range(0, dstate, BLOCK_SIZE_K):
142
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)
143
+ dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
144
+ dstates = dstates.to(b_ptr.dtype.element_ty)
145
+ acc += tl.dot(b, dstates)
146
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
147
+ dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
148
+ acc *= scale[:, None]
149
+
150
+ # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
151
+ # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
152
+ # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
153
+ # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
154
+ # ddt = tl.sum(acc * x, axis=1) * dt_m
155
+ # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
156
+ # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
157
+
158
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
159
+ cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
160
+ dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
161
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
162
+ K_MAX = chunk_size_limit
163
+ K_MIN = pid_m * BLOCK_SIZE_M
164
+ cb_ptrs += K_MIN * stride_cb_csize_k
165
+ dout_ptrs += K_MIN * stride_dout_seqlen
166
+ dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
167
+ for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
168
+ k = tl.multiple_of(k, BLOCK_SIZE_K)
169
+ # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
170
+ cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
171
+ dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
172
+ dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
173
+ cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
174
+ # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
175
+ # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
176
+ # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
177
+ # This will cause NaN in acc, and hence NaN in dx and ddt.
178
+ mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
179
+ cb = tl.where(mask, cb, 0.0)
180
+ cb = cb.to(dout_ptr.dtype.element_ty)
181
+ acc += tl.dot(cb, dout)
182
+ cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
183
+ dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
184
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
185
+
186
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
187
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
188
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
189
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
190
+ dx = acc * dt_m[:, None]
191
+ dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
192
+ dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
193
+ if HAS_D:
194
+ dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
195
+ dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
196
+ if D_HAS_HDIM:
197
+ D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
198
+ else:
199
+ D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
200
+ dx += dout_res * D
201
+ tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
202
+
203
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
204
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
205
+ if HAS_D:
206
+ dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
207
+ if D_HAS_HDIM:
208
+ dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
209
+ dD = tl.sum(dout_res * x, axis=0)
210
+ tl.store(dD_ptrs, dD, mask=offs_n < hdim)
211
+ else:
212
+ dD = tl.sum(dout_res * x)
213
+ tl.store(dD_ptr, dD)
214
+ ddt = tl.sum(acc * x, axis=1)
215
+ ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
216
+ tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
217
+
218
+
219
+ def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):
220
+ batch, seqlen, nheads, headdim = x.shape
221
+ _, _, nchunks, chunk_size = dt.shape
222
+ _, _, ngroups, dstate = B.shape
223
+ assert nheads % ngroups == 0
224
+ assert B.shape == (batch, seqlen, ngroups, dstate)
225
+ assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
226
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
227
+ assert dA_cumsum.shape == dt.shape
228
+ assert dout.shape == x.shape
229
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
230
+ if seq_idx is not None:
231
+ assert seq_idx.shape == (batch, seqlen)
232
+ if D is not None:
233
+ assert D.shape == (nheads, headdim) or D.shape == (nheads,)
234
+ assert D.stride(-1) == 1
235
+ BLOCK_SIZE_min = 32
236
+ dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
237
+ headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
238
+ else:
239
+ dD = None
240
+ dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
241
+ if D is not None else (0, 0, 0, 0, 0))
242
+ if dx is None:
243
+ dx = torch.empty_like(x)
244
+ else:
245
+ assert dx.shape == x.shape
246
+ ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
247
+ grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
248
+ batch * nchunks, nheads)
249
+ with torch.cuda.device(x.device.index):
250
+ _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
251
+ x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,
252
+ chunk_size, headdim, dstate,
253
+ batch, seqlen, nheads // ngroups,
254
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
255
+ CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),
256
+ dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
257
+ dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
258
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
259
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
260
+ D.stride(0) if D is not None else 0,
261
+ B.stride(0), B.stride(1), B.stride(2), B.stride(3),
262
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
263
+ dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
264
+ ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
265
+ dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
266
+ D is not None,
267
+ D.dim() == 2 if D is not None else True,
268
+ HAS_SEQ_IDX=seq_idx is not None,
269
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
270
+ IS_TRITON_22=TRITON_22
271
+ )
272
+ if D is not None:
273
+ BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"]
274
+ n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
275
+ dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
276
+ if D.dim() == 1:
277
+ dD = rearrange(dD, "h 1 -> h")
278
+ return dx, ddt.to(dtype=dt.dtype), dD
279
+
280
+
281
+ def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
282
+ batch, seqlen, nheads, headdim = x.shape
283
+ _, _, ngroups, dstate = B.shape
284
+ assert nheads % ngroups == 0
285
+ assert B.shape == (batch, seqlen, ngroups, dstate)
286
+ assert x.shape == (batch, seqlen, nheads, headdim)
287
+ assert dt.shape == (batch, seqlen, nheads)
288
+ assert A.shape == (nheads,)
289
+ assert C.shape == B.shape
290
+ if z is not None:
291
+ assert z.shape == x.shape
292
+ if D is not None:
293
+ assert D.shape == (nheads, headdim) or D.shape == (nheads,)
294
+ if seq_idx is not None:
295
+ assert seq_idx.shape == (batch, seqlen)
296
+ if B.stride(-1) != 1:
297
+ B = B.contiguous()
298
+ if C.stride(-1) != 1:
299
+ C = C.contiguous()
300
+ if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
301
+ x = x.contiguous()
302
+ if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous
303
+ z = z.contiguous()
304
+ if D is not None and D.stride(-1) != 1:
305
+ D = D.contiguous()
306
+ if initial_states is not None:
307
+ assert initial_states.shape == (batch, nheads, headdim, dstate)
308
+ # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
309
+ # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
310
+ # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
311
+ # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
312
+ dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
313
+ states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
314
+ # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
315
+ # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
316
+ # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
317
+ states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
318
+ initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
319
+ seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)
320
+ states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]]
321
+ # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
322
+ # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
323
+ CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
324
+ out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
325
+ if cu_seqlens is None:
326
+ return out, out_x, dt, dA_cumsum, states, final_states
327
+ else:
328
+ assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
329
+ varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0),
330
+ cu_seqlens, states.squeeze(0))
331
+ return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
332
+
333
+
334
+ def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None,
335
+ dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False,
336
+ dt_limit=(0.0, float("inf")),
337
+ dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False):
338
+ if dout.stride(-1) != 1:
339
+ dout = dout.contiguous()
340
+ batch, seqlen, nheads, headdim = x.shape
341
+ nchunks = math.ceil(seqlen / chunk_size)
342
+ _, _, ngroups, dstate = B.shape
343
+ assert dout.shape == (batch, seqlen, nheads, headdim)
344
+ assert dt.shape == (batch, seqlen, nheads)
345
+ assert A.shape == (nheads,)
346
+ assert nheads % ngroups == 0
347
+ assert B.shape == (batch, seqlen, ngroups, dstate)
348
+ assert C.shape == B.shape
349
+ assert out.shape == x.shape
350
+ if initial_states is not None:
351
+ assert initial_states.shape == (batch, nheads, headdim, dstate)
352
+ if seq_idx is not None:
353
+ assert seq_idx.shape == (batch, seqlen)
354
+ if dx is not None:
355
+ assert dx.shape == x.shape
356
+ if dB is not None:
357
+ assert dB.shape == B.shape
358
+ dB_given = dB
359
+ else:
360
+ dB_given = torch.empty_like(B)
361
+ if dC is not None:
362
+ assert dC.shape == C.shape
363
+ dC_given = dC
364
+ else:
365
+ dC_given = torch.empty_like(C)
366
+ if dz is not None:
367
+ assert z is not None
368
+ assert dz.shape == z.shape
369
+ if ddt is not None:
370
+ assert ddt.shape == dt.shape
371
+ ddt_given = ddt
372
+ else:
373
+ ddt_given = torch.empty_like(dt)
374
+ # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
375
+ # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why.
376
+ dt_in = dt.clone()
377
+ dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus,
378
+ dt_limit=dt_limit)
379
+ CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
380
+ states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
381
+ states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
382
+ initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
383
+ seq_idx=seq_idx, chunk_size=chunk_size)
384
+ states = rearrange(states, "... (p n) -> ... p n", n=dstate)
385
+ if z is not None:
386
+ dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output)
387
+ outz = rest[0] if recompute_output else out
388
+ else:
389
+ dz = None
390
+ outz = out
391
+ dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype)
392
+ # dstates has length nchunks, containing the gradient to initial states at index 0 and
393
+ # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
394
+ # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
395
+ # will be used in matmul in the next kernels.
396
+ dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
397
+ rearrange(states, "... p n -> ... (p n)"),
398
+ dA_cumsum[:, :, :, -1],
399
+ rearrange(dstates, "... p n -> ... (p n)"),
400
+ dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None,
401
+ seq_idx=seq_idx,
402
+ has_initial_states=initial_states is not None,
403
+ dstates_dtype=x.dtype,
404
+ states_dtype=x.dtype,
405
+ chunk_size=chunk_size,
406
+ )
407
+ # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
408
+ # gradient to the final states at index (nchunks - 1)
409
+ # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
410
+ # The final states is not stored.
411
+ states = rearrange(states, "... (p n) -> ... p n", n=dstate)
412
+ dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
413
+ dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None
414
+ dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
415
+ # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
416
+ dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups)
417
+ # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
418
+ dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups)
419
+ # Computing ddA with the dcb kernel is much slower, so we're not using it for now
420
+ dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
421
+ # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
422
+ dCB = dCB.to(CB.dtype)
423
+ _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
424
+ _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
425
+ # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
426
+ # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
427
+ if z is None:
428
+ dD = dD_from_x
429
+ # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
430
+ # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
431
+ # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
432
+ # be a lot of underflow.
433
+
434
+ # This is already done as part of bwd_dC kernel
435
+ # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
436
+ ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
437
+ ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
438
+ # This is already done as part of bwd_dB kernel
439
+ # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
440
+ # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
441
+ ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
442
+ ddA += ddA_next + ddA_prev
443
+
444
+ ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given)
445
+
446
+ # These 2 lines are just to test ddt and dA being computed by old code
447
+ # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
448
+ # ddt_given.copy_(ddt)
449
+
450
+ return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states)
451
+ return return_vals if not recompute_output else (*return_vals, outz)
452
+
453
+
454
+ def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
455
+ """
456
+ Argument:
457
+ dout: (batch, seqlen, nheads, headdim)
458
+ x: (batch, seqlen, nheads, headdim)
459
+ dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
460
+ A: (nheads) or (dim, dstate)
461
+ B: (batch, seqlen, ngroups, dstate)
462
+ C: (batch, seqlen, ngroups, dstate)
463
+ D: (nheads, headdim) or (nheads,)
464
+ z: (batch, seqlen, nheads, headdim)
465
+ Return:
466
+ ddt: same shape as dt, gradient with respect to dt
+ dA: same shape as A, gradient with respect to A
467
+ """
468
+ import selective_scan
469
+
470
+ batch, seqlen, nheads, headdim = x.shape
471
+ chunk_size = dt.shape[-1]
472
+ _, _, ngroups, dstate = B.shape
473
+ assert nheads % ngroups == 0
474
+ x = rearrange(x, "b l h p -> b (h p) l")
475
+ squeeze_dt = dt.dim() == 4
476
+ if dt.dim() == 4:
477
+ dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
478
+ dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
479
+ squeeze_A = A.dim() == 1
480
+ if A.dim() == 1:
481
+ A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
482
+ else:
483
+ A = A.to(dtype=torch.float32)
484
+ B = rearrange(B, "b l g n -> b g n l")
485
+ C = rearrange(C, "b l g n -> b g n l")
486
+ if D is not None:
487
+ if D.dim() == 2:
488
+ D = rearrange(D, "h p -> (h p)")
489
+ else:
490
+ D = repeat(D, "h -> (h p)", p=headdim)
491
+ if z is not None:
492
+ z = rearrange(z, "b l h p -> b (h p) l")
493
+
494
+ if x.stride(-1) != 1:
495
+ x = x.contiguous()
496
+ if dt.stride(-1) != 1:
497
+ dt = dt.contiguous()
498
+ if D is not None:
499
+ D = D.contiguous()
500
+ if B.stride(-1) != 1:
501
+ B = B.contiguous()
502
+ if C.stride(-1) != 1:
503
+ C = C.contiguous()
504
+ if z is not None and z.stride(-1) != 1:
505
+ z = z.contiguous()
506
+ _, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False)
507
+ if z is not None:
508
+ out = rest[0]
509
+ else:
510
+ out = None
511
+
512
+ dout = rearrange(dout, "b l h p -> b (h p) l")
513
+
514
+ if dout.stride(-1) != 1:
515
+ dout = dout.contiguous()
516
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
517
+ # backward of selective_scan with the backward of chunk).
518
+ # Here we just pass in None and dz will be allocated in the C++ code.
519
+ _, ddt, dA, *rest = selective_scan.bwd(
520
+ x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False,
521
+ False # option to recompute out_z, not used here
522
+ )
523
+ ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
524
+ if squeeze_dt:
525
+ ddt = ddt.float().sum(dim=2)
526
+ if squeeze_A:
527
+ dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
528
+ return ddt, dA
529
+
530
+
531
+ class MambaChunkScanCombinedFn(torch.autograd.Function):
532
+
533
+ @staticmethod
534
+ def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):
535
+ ctx.dt_dtype = dt.dtype
536
+ if not return_varlen_states:
537
+ cu_seqlens = None
538
+ else:
539
+ assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
540
+ out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
541
+ ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx)
542
+ ctx.dt_softplus = dt_softplus
543
+ ctx.chunk_size = chunk_size
544
+ ctx.dt_limit = dt_limit
545
+ ctx.return_final_states = return_final_states
546
+ ctx.return_varlen_states = return_varlen_states
547
+ if not return_varlen_states:
548
+ return out if not return_final_states else (out, final_states)
549
+ else:
550
+ varlen_states = rest[0]
551
+ return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states)
552
+
553
+ @staticmethod
554
+ def backward(ctx, dout, *args):
555
+ out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors
556
+ assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward"
557
+ dfinal_states = args[0] if ctx.return_final_states else None
558
+ dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)
559
+ return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None
560
+
561
+
562
+ def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):
563
+ """
564
+ Argument:
565
+ x: (batch, seqlen, nheads, headdim)
566
+ dt: (batch, seqlen, nheads)
567
+ A: (nheads)
568
+ B: (batch, seqlen, ngroups, dstate)
569
+ C: (batch, seqlen, ngroups, dstate)
570
+ chunk_size: int
571
+ D: (nheads, headdim) or (nheads,)
572
+ z: (batch, seqlen, nheads, headdim)
573
+ dt_bias: (nheads,)
574
+ initial_states: (batch, nheads, headdim, dstate)
575
+ seq_idx: (batch, seqlen)
576
+ cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
577
+ dt_softplus: Whether to apply softplus to dt
578
+ Return:
579
+ out: (batch, seqlen, nheads, headdim)
+ If return_final_states is True, final_states: (batch, nheads, headdim, dstate) is also returned.
+ If return_varlen_states is True, varlen_states (one final state per sequence in cu_seqlens) is also returned.
580
+ """
581
+ return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states)
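+
+
+ # A minimal usage sketch (not part of the original file), assuming a CUDA device.
+ # The sizes and dtypes below are illustrative assumptions; dt is pre-activated with
+ # softplus here instead of passing dt_softplus=True.
+ def _example_mamba_chunk_scan_combined(device="cuda", dtype=torch.bfloat16):
+     batch, seqlen, nheads, headdim = 2, 1024, 16, 64
+     ngroups, dstate, chunk_size = 1, 64, 256
+     x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
+     dt = F.softplus(torch.randn(batch, seqlen, nheads, device=device) - 4)
+     A = -torch.exp(torch.rand(nheads, device=device))
+     B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
+     C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
+     D = torch.randn(nheads, device=device)
+     out, final_states = mamba_chunk_scan_combined(
+         x, dt, A, B, C, chunk_size, D=D, return_final_states=True
+     )
+     assert out.shape == x.shape
+     assert final_states.shape == (batch, nheads, headdim, dstate)
+     return out, final_states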
582
+
583
+
584
+ def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
585
+ """
586
+ Argument:
587
+ x: (batch, seqlen, nheads, headdim)
588
+ dt: (batch, seqlen, nheads)
589
+ A: (nheads)
590
+ B: (batch, seqlen, ngroups, dstate)
591
+ C: (batch, seqlen, ngroups, dstate)
592
+ D: (nheads, headdim) or (nheads,)
593
+ z: (batch, seqlen, nheads, headdim)
594
+ dt_bias: (nheads,)
595
+ Return:
596
+ out: (batch, seqlen, nheads, headdim)
597
+ """
598
+ batch, seqlen, nheads, headdim = x.shape
599
+ dstate = B.shape[-1]
600
+ if seqlen % chunk_size != 0:
601
+ dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
602
+ dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
603
+ dt = dt.float() # We want high precision for this before cumsum
604
+ if dt_bias is not None:
605
+ dt = dt + rearrange(dt_bias, "h -> h 1 1")
606
+ if dt_softplus:
607
+ dt = F.softplus(dt)
608
+ dA = dt * rearrange(A, "h -> h 1 1")
609
+ dA = dt * rearrange(A, "h -> h 1 1")
610
+ dA_cumsum = torch.cumsum(dA, dim=-1)
611
+ # 1. Compute the state for each chunk
612
+ states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
613
+ # 2. Pass the state to all the chunks by weighted cumsum.
614
+ states = rearrange(state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
615
+ "... (p n) -> ... p n", n=dstate)
616
+ # 3. Compute the output for each chunk
617
+ out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
618
+ return out
619
+
620
+
621
+ def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
622
+ """
623
+ Argument:
624
+ x: (batch, seqlen, nheads, headdim)
625
+ dt: (batch, seqlen, nheads)
626
+ A: (nheads)
627
+ B: (batch, seqlen, ngroups, dstate)
628
+ C: (batch, seqlen, ngroups, dstate)
629
+ D: (nheads, headdim) or (nheads,)
630
+ z: (batch, seqlen, nheads, headdim)
631
+ dt_bias: (nheads,)
632
+ Return:
633
+ out: (batch, seqlen, nheads, headdim)
634
+ """
635
+ batch, seqlen, nheads, headdim = x.shape
636
+ dstate = B.shape[-1]
637
+ if seqlen % chunk_size != 0:
638
+ dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
639
+ dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
640
+ dt = dt.float() # We want high precision for this before cumsum
641
+ if dt_bias is not None:
642
+ dt = dt + rearrange(dt_bias, "h -> h 1 1")
643
+ if dt_softplus:
644
+ dt = F.softplus(dt)
645
+ dA = dt * rearrange(A, "h -> h 1 1")
646
+ dA_cumsum = torch.cumsum(dA, dim=-1)
647
+ # 1. Compute the state for each chunk
648
+ states = chunk_state_ref(B, x, dt, dA_cumsum)
649
+ states_dtype = states.dtype
650
+ if states.dtype not in [torch.float32, torch.float64]:
651
+ states = states.to(torch.float32)
652
+ # 2. Pass the state to all the chunks by weighted cumsum.
653
+ # state_passing_ref is much less numerically stable
654
+ states = rearrange(state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
655
+ "... (p n) -> ... p n", n=dstate)
656
+ states = states.to(states_dtype)
657
+ # 3. Compute the output for each chunk
658
+ out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
659
+ return out
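+
+
+ # A minimal consistency-check sketch (not part of the original file), assuming a CUDA
+ # device: it compares the fused Triton path (mamba_chunk_scan_combined) against the
+ # reference ssd_chunk_scan_combined_ref in float32. Shapes and tolerances are
+ # illustrative assumptions.
+ def _check_combined_against_ref(device="cuda"):
+     batch, seqlen, nheads, headdim = 1, 512, 4, 64
+     ngroups, dstate, chunk_size = 1, 16, 128
+     x = torch.randn(batch, seqlen, nheads, headdim, device=device)
+     dt = F.softplus(torch.randn(batch, seqlen, nheads, device=device) - 4)
+     A = -torch.exp(torch.rand(nheads, device=device))
+     B = torch.randn(batch, seqlen, ngroups, dstate, device=device)
+     C = torch.randn(batch, seqlen, ngroups, dstate, device=device)
+     out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size)
+     out_ref = ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size)
+     assert torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2)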
660
+
661
+
662
+ def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
663
+ """
664
+ Argument:
665
+ x: (batch, seqlen, nheads, headdim)
666
+ dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
667
+ A: (nheads) or (dim, dstate)
668
+ B: (batch, seqlen, ngroups, dstate)
669
+ C: (batch, seqlen, ngroups, dstate)
670
+ D: (nheads, headdim) or (nheads,)
671
+ z: (batch, seqlen, nheads, headdim)
672
+ dt_bias: (nheads,) or (nheads, headdim)
673
+ Return:
674
+ out: (batch, seqlen, nheads, headdim)
675
+ """
676
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
677
+
678
+ batch, seqlen, nheads, headdim = x.shape
679
+ _, _, ngroups, dstate = B.shape
680
+ x = rearrange(x, "b l h p -> b (h p) l")
681
+ if dt.dim() == 3:
682
+ dt = repeat(dt, "b l h -> b l h p", p=headdim)
683
+ dt = rearrange(dt, "b l h p -> b (h p) l")
684
+ if A.dim() == 1:
685
+ A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
686
+ else:
687
+ A = A.to(dtype=torch.float32)
688
+ B = rearrange(B, "b l g n -> b g n l")
689
+ C = rearrange(C, "b l g n -> b g n l")
690
+ if D is not None:
691
+ if D.dim() == 2:
692
+ D = rearrange(D, "h p -> (h p)")
693
+ else:
694
+ D = repeat(D, "h -> (h p)", p=headdim)
695
+ if z is not None:
696
+ z = rearrange(z, "b l h p -> b (h p) l")
697
+ if dt_bias is not None:
698
+ if dt_bias.dim() == 1:
699
+ dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
700
+ dt_bias = rearrange(dt_bias, "h p -> (h p)")
701
+ if dt_limit != (0.0, float("inf")):
702
+ if dt_bias is not None:
703
+ dt = dt + rearrange(dt_bias, "d -> d 1")
704
+ if dt_softplus:
705
+ dt = F.softplus(dt)
706
+ dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
707
+ dt_bias = None
708
+ dt_softplus = None
709
+ out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus)
710
+ return rearrange(out, "b (h p) l -> b l h p", p=headdim)
711
+
712
+
713
+ def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None,
714
+ dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")),
715
+ activation="silu", headdim=None, ngroups=1):
716
+ """
717
+ Argument:
718
+ xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
719
+ conv1d_weight: (dim + 2 * ngroups * dstate, width)
720
+ conv1d_bias: (dim + 2 * ngroups * dstate,)
721
+ dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
722
+ A: (nheads)
723
+ D: (nheads, headdim) or (nheads,)
724
+ z: (batch, seqlen, dim)
725
+ dt_bias: (nheads) or (nheads, headdim)
726
+ headdim: if D is 1D and z is None, headdim must be passed in
727
+ Return:
728
+ out: (batch, seqlen, dim)
729
+ """
730
+ batch, seqlen, nheads = dt.shape[:3]
731
+ assert nheads % ngroups == 0
732
+ if z is not None:
733
+ dim = z.shape[-1]
734
+ assert dim % nheads == 0
735
+ headdim = dim // nheads
736
+ else:
737
+ if D.dim() == 1:
738
+ assert headdim is not None
739
+ else:
740
+ headdim = D.shape[1]
741
+ dim = nheads * headdim
742
+ xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
743
+ "b d s -> b s d")
744
+ dstate = (xBC.shape[-1] - dim) // ngroups // 2
745
+ x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
746
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
747
+ B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
748
+ C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
749
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
750
+ out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
751
+ return rearrange(out, "b s h p -> b s (h p)")
752
+
753
+
754
+ class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
755
+
756
+ @staticmethod
757
+ @custom_fwd
758
+ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
759
+ rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,
760
+ ngroups=1, norm_before_gate=True):
761
+ assert activation in [None, "silu", "swish"]
762
+ if D.dim() == 1:
763
+ assert headdim is not None
764
+ nheads, = D.shape
765
+ else:
766
+ nheads, headdim = D.shape
767
+ batch, seqlen, _ = zxbcdt.shape
768
+ dim = nheads * headdim
769
+ assert nheads % ngroups == 0
770
+ dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
771
+ d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
772
+ assert d_nonssm >= 0
773
+ assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads)
774
+ assert dt_bias.shape == (nheads,)
775
+ assert A.shape == (nheads,)
776
+ zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
777
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
778
+ xBC_conv = rearrange(
779
+ causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
780
+ conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
781
+ "b d s -> b s d"
782
+ )
783
+ x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
784
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
785
+ B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
786
+ C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
787
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
788
+ if rmsnorm_weight is None:
789
+ out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
790
+ out = rearrange(out, "b s h p -> b s (h p)")
791
+ rstd = None
792
+ if d_nonssm > 0:
793
+ out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
794
+ else:
795
+ out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
796
+ # reshape input data into 2D tensor
797
+ x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
798
+ z_rms = rearrange(z, "b s h p -> (b s) (h p)")
799
+ rmsnorm_weight = rmsnorm_weight.contiguous()
800
+ if d_nonssm == 0:
801
+ out = None
802
+ else:
803
+ out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device)
804
+ out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
805
+ _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
806
+ out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out,
807
+ group_size=dim // ngroups,
808
+ norm_before_gate=norm_before_gate, is_rms_norm=True)
809
+ if d_nonssm == 0:
810
+ out = rearrange(out, "(b s) d -> b s d", b=batch)
811
+ else:
812
+ out = out01
813
+ ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None
814
+ if outproj_weight is not None:
815
+ if torch.is_autocast_enabled():
816
+ dtype = torch.get_autocast_gpu_dtype()
817
+ out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
818
+ outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None
819
+ out = F.linear(out, outproj_weight, outproj_bias)
820
+ else:
821
+ assert outproj_bias is None
822
+ ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias,
823
+ out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias)
824
+ ctx.dt_limit = dt_limit
825
+ ctx.return_final_states = return_final_states
826
+ ctx.activation = activation
827
+ ctx.rmsnorm_eps = rmsnorm_eps
828
+ ctx.norm_before_gate = norm_before_gate
829
+ ctx.chunk_size = chunk_size
830
+ ctx.headdim = headdim
831
+ ctx.ngroups = ngroups
832
+ return out if not return_final_states else (out, final_states)
833
+
834
+ @staticmethod
835
+ @custom_bwd
836
+ def backward(ctx, dout, *args):
837
+ zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
838
+ dfinal_states = args[0] if ctx.return_final_states else None
839
+ headdim = ctx.headdim
840
+ nheads = D.shape[0]
841
+ dim = nheads * headdim
842
+ assert nheads % ctx.ngroups == 0
843
+ dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
844
+ d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
845
+ assert d_nonssm >= 0
846
+ recompute_output = outproj_weight is not None
847
+ if recompute_output:
848
+ out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype)
849
+ out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1)
850
+ zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
851
+ # Recompute x, B, C
852
+ xBC_conv = rearrange(
853
+ causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
854
+ conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
855
+ "b d s -> b s d"
856
+ )
857
+ x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
858
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
859
+ B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
860
+ C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
861
+ dzxbcdt = torch.empty_like(zxbcdt)
862
+ dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
863
+ dxBC = torch.empty_like(xBC)
864
+ dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
865
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
866
+ dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
867
+ dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
868
+ dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
869
+ if outproj_weight is not None:
870
+ dout_og = dout
871
+ dout = F.linear(dout, outproj_weight.t())
872
+ if d_nonssm > 0:
873
+ dout0, dout = dout.split([d_nonssm, dim], dim=-1)
874
+ _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
875
+ dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
876
+ if rmsnorm_weight is None:
877
+ dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
878
+ dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd(
879
+ dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output
880
+ )
881
+ out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
882
+ drmsnorm_weight = None
883
+ else:
884
+ batch = dout.shape[0]
885
+ dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
886
+ dz = rearrange(dz, "b l d -> (b l) d")
887
+ x_rms = rearrange(out, "b s h p -> (b s) (h p)")
888
+ z_rms = rearrange(z, "b s h p -> (b s) (h p)")
889
+ out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None
890
+ dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
891
+ out_for_linear = out_recompute if recompute_output else None
892
+ dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
893
+ dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
894
+ dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC
895
+ )
896
+
897
+ if outproj_weight is not None:
898
+ doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
899
+ doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
900
+ else:
901
+ doutproj_weight, doutproj_bias = None, None
902
+ dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
903
+ dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
904
+ rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
905
+ rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
906
+ )
907
+ dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
908
+ return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None
909
+
910
+
911
+ def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
912
+ """
913
+ Argument:
914
+ zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
915
+ conv1d_weight: (dim + 2 * ngroups * dstate, width)
916
+ conv1d_bias: (dim + 2 * ngroups * dstate,)
917
+ dt_bias: (nheads,)
918
+ A: (nheads,)
919
+ D: (nheads, headdim) or (nheads,)
920
+ initial_states: (batch, nheads, headdim, dstate)
921
+ seq_idx: (batch, seqlen), int32
922
+ rmsnorm_weight: (dim,)
923
+ outproj_weight: (out_dim, dim)
924
+ outproj_bias: (out_dim,)
925
+ headdim: if D is 1D, headdim must be passed in
926
+ norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
927
+ Return:
928
+ out: (batch, seqlen, dim)
929
+ """
930
+ return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
931
+
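# A minimal usage sketch for the fused entry point above, assuming a CUDA device with
# triton and the causal-conv1d package installed; shapes follow the docstring, and the
# concrete sizes here are illustrative only.
import torch

batch, seqlen, chunk_size = 2, 256, 64
nheads, headdim, ngroups, dstate, width = 8, 64, 1, 16, 4
dim = nheads * headdim
zxbcdt = torch.randn(batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads, device="cuda")
conv1d_weight = torch.randn(dim + 2 * ngroups * dstate, width, device="cuda")
conv1d_bias = torch.randn(dim + 2 * ngroups * dstate, device="cuda")
dt_bias = torch.randn(nheads, device="cuda")
A = -torch.rand(nheads, device="cuda")   # negative A keeps the scan stable
D = torch.randn(nheads, device="cuda")   # 1D D, so headdim must be passed explicitly

out = mamba_split_conv1d_scan_combined(
    zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
    activation="silu", headdim=headdim, ngroups=ngroups,
)
print(out.shape)  # (batch, seqlen, dim)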
932
+
933
+ def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float("inf")), activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
934
+ """
935
+ Argument:
936
+ zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
937
+ conv1d_weight: (dim + 2 * ngroups * dstate, width)
938
+ conv1d_bias: (dim + 2 * ngroups * dstate,)
939
+ dt_bias: (nheads,)
940
+ A: (nheads,)
941
+ D: (nheads, headdim) or (nheads,)
942
+ rmsnorm_weight: (dim,)
943
+ outproj_weight: (out_dim, dim)
944
+ outproj_bias: (out_dim,)
945
+ headdim: if D is 1D, headdim must be passed in
946
+ norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
947
+ Return:
948
+ out: (batch, seqlen, dim)
949
+ """
950
+ if D.dim() == 1:
951
+ assert headdim is not None
952
+ nheads, = D.shape
953
+ else:
954
+ nheads, headdim = D.shape
955
+ assert nheads % ngroups == 0
956
+ batch, seqlen, _ = zxbcdt.shape
957
+ dim = nheads * headdim
958
+ dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
959
+ assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
960
+ assert dt_bias.shape == (nheads,)
961
+ assert A.shape == (nheads,)
962
+ if rmsnorm_weight is not None:
963
+ assert rmsnorm_weight.shape == (dim,)
964
+ z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
965
+ xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
966
+ "b d s -> b s d")
967
+ x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
968
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
969
+ B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
970
+ C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
971
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
972
+ out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(),
973
+ z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit)
974
+ out = rearrange(out, "b s h p -> b s (h p)")
975
+ if rmsnorm_weight is not None:
976
+ out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, "b l h p -> b l (h p)"), eps=rmsnorm_eps,
977
+ norm_before_gate=norm_before_gate)
978
+ if outproj_weight is not None:
979
+ out = F.linear(out, outproj_weight, outproj_bias)
980
+ return out
981
+
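# Sketch of a consistency check between the fused entry point and the reference
# implementation above, reusing the tensors from the previous sketch; the tolerance is
# an assumption (fused low-precision kernels typically need a loose one).
out_fused = mamba_split_conv1d_scan_combined(
    zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
    activation="silu", headdim=headdim, ngroups=ngroups,
)
out_ref = mamba_split_conv1d_scan_ref(
    zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
    activation="silu", headdim=headdim, ngroups=ngroups,
)
torch.testing.assert_close(out_fused, out_ref, rtol=1e-2, atol=1e-2, check_dtype=False)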
mamba/build/lib/mamba_ssm/ops/triton/ssd_state_passing.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ @triton.autotune(
17
+ configs=[
18
+ triton.Config({'BLOCK_SIZE': 64}),
19
+ triton.Config({'BLOCK_SIZE': 128}),
20
+ triton.Config({'BLOCK_SIZE': 256}),
21
+ triton.Config({'BLOCK_SIZE': 512}),
22
+ triton.Config({'BLOCK_SIZE': 1024}),
23
+ triton.Config({'BLOCK_SIZE': 2048}),
24
+ ],
25
+ key=['dim'],
26
+ )
27
+ @triton.jit
28
+ def _state_passing_fwd_kernel(
29
+ # Pointers to matrices
30
+ states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
31
+ # Matrix dimensions
32
+ dim, nchunks, seqlen, chunk_size,
33
+ # Strides
34
+ stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
35
+ stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
36
+ stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
37
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
38
+ stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
39
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
40
+ # Meta-parameters
41
+ HAS_INITSTATES: tl.constexpr,
42
+ HAS_SEQ_IDX: tl.constexpr,
43
+ BLOCK_SIZE: tl.constexpr,
44
+ ):
45
+ pid_b = tl.program_id(axis=1)
46
+ pid_h = tl.program_id(axis=2)
47
+ pid_m = tl.program_id(axis=0)
48
+ states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
49
+ dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
50
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
51
+ final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
52
+ if HAS_INITSTATES:
53
+ initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
54
+ if HAS_SEQ_IDX:
55
+ seq_idx_ptr += pid_b * stride_seq_idx_batch
56
+
57
+ offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
58
+ states_ptrs = states_ptr + offs_m * stride_states_dim
59
+ out_ptrs = out_ptr + offs_m * stride_out_dim
60
+ final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
61
+
62
+ if not HAS_INITSTATES:
63
+ states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
64
+ else:
65
+ initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
66
+ states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
67
+ tl.store(out_ptrs, states, mask=offs_m < dim)
68
+ out_ptrs += stride_out_chunk
69
+ seq_idx = 0
70
+ for c in range(nchunks):
71
+ new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
72
+ dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
73
+ scale = tl.exp(dA_cs)
74
+ if HAS_SEQ_IDX:
75
+ seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
76
+ scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
77
+ seq_idx = seq_idx_new
78
+ states = scale * states + new_states
79
+ if c < nchunks - 1:
80
+ tl.store(out_ptrs, states, mask=offs_m < dim)
81
+ else:
82
+ tl.store(final_states_ptrs, states, mask=offs_m < dim)
83
+ states_ptrs += stride_states_chunk
84
+ dA_cs_ptr += stride_dA_cs_chunk
85
+ out_ptrs += stride_out_chunk
86
+
87
+
88
+ @triton.autotune(
89
+ configs=[
90
+ triton.Config({'BLOCK_SIZE': 64}),
91
+ triton.Config({'BLOCK_SIZE': 128}),
92
+ triton.Config({'BLOCK_SIZE': 256}),
93
+ triton.Config({'BLOCK_SIZE': 512}),
94
+ triton.Config({'BLOCK_SIZE': 1024}),
95
+ triton.Config({'BLOCK_SIZE': 2048}),
96
+ ],
97
+ key=['dim'],
98
+ )
99
+ @triton.jit
100
+ def _state_passing_bwd_kernel(
101
+ # Pointers to matrices
102
+ dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,
103
+ dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,
104
+ # Matrix dimensions
105
+ dim, nchunks, seqlen, chunk_size,
106
+ # Strides
107
+ stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,
108
+ stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
109
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
110
+ stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,
111
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
112
+ stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,
113
+ stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,
114
+ stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,
115
+ # Meta-parameters
116
+ CONVERT_STATES: tl.constexpr,
117
+ HAS_DFINAL_STATES: tl.constexpr,
118
+ HAS_DINITSTATES: tl.constexpr,
119
+ HAS_SEQ_IDX: tl.constexpr,
120
+ BLOCK_SIZE: tl.constexpr,
121
+ ):
122
+ pid_b = tl.program_id(axis=1)
123
+ pid_h = tl.program_id(axis=2)
124
+ pid_m = tl.program_id(axis=0)
125
+ dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
126
+ dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
127
+ ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
128
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
129
+ dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
130
+ if CONVERT_STATES:
131
+ states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
132
+ if HAS_DFINAL_STATES:
133
+ dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
134
+ if HAS_DINITSTATES:
135
+ dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
136
+ if HAS_SEQ_IDX:
137
+ seq_idx_ptr += pid_b * stride_seq_idx_batch
138
+
139
+ offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
140
+ dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
141
+ out_ptrs = out_ptr + offs_m * stride_out_dim
142
+ dout_ptrs = dout_ptr + offs_m * stride_dout_dim
143
+ if CONVERT_STATES:
144
+ states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim
145
+
146
+ if HAS_DFINAL_STATES:
147
+ dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)
148
+ else:
149
+ dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
150
+ tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
151
+ if HAS_SEQ_IDX:
152
+ seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
153
+ dstates_ptrs -= stride_dstates_chunk
154
+ for c in range(nchunks - 1):
155
+ dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
156
+ scale = tl.exp(dA_cs)
157
+ if HAS_SEQ_IDX:
158
+ seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
159
+ scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
160
+ seq_idx = seq_idx_new
161
+ out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
162
+ if CONVERT_STATES:
163
+ tl.store(states_converted_ptrs, out, mask=offs_m < dim)
164
+ ddA = tl.sum(out * dstates) * scale
165
+ tl.store(ddA_cs_ptr, ddA)
166
+ dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
167
+ dstates = scale * dstates + dout
168
+ tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
169
+ dout_ptrs -= stride_dout_chunk
170
+ dstates_ptrs -= stride_dstates_chunk
171
+ dA_cs_ptr -= stride_dA_cs_chunk
172
+ ddA_cs_ptr -= stride_ddA_cs_chunk
173
+ out_ptrs -= stride_out_chunk
174
+ if CONVERT_STATES:
175
+ states_converted_ptrs -= stride_out_chunk
176
+ if CONVERT_STATES:
177
+ out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
178
+ tl.store(states_converted_ptrs, out, mask=offs_m < dim)
179
+ if not HAS_DINITSTATES:
180
+ tl.store(ddA_cs_ptr, 0.0)
181
+ else:
182
+ dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
183
+ scale = tl.exp(dA_cs)
184
+ if HAS_SEQ_IDX:
185
+ scale = tl.where(seq_idx == 0, scale, 0.0)
186
+ out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
187
+ ddA = tl.sum(out * dstates) * scale
188
+ tl.store(ddA_cs_ptr, ddA)
189
+ dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
190
+ dstates = scale * dstates + dout
191
+ tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)
192
+
193
+
194
+ def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
195
+ out_dtype=None):
196
+ batch, nchunks, nheads, dim = states.shape
197
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
198
+ if initial_states is not None:
199
+ assert initial_states.shape == (batch, nheads, dim)
200
+ if seq_idx is not None:
201
+ assert chunk_size is not None
202
+ seqlen = seq_idx.shape[-1]
203
+ assert seq_idx.shape == (batch, seqlen)
204
+ out_dtype = states.dtype if out_dtype is None else out_dtype
205
+ out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
206
+ final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
207
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
208
+ with torch.cuda.device(states.device.index):
209
+ _state_passing_fwd_kernel[grid](
210
+ states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
211
+ dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
212
+ states.stride(0), states.stride(1), states.stride(2), states.stride(3),
213
+ out.stride(0), out.stride(1), out.stride(2), out.stride(3),
214
+ final_states.stride(0), final_states.stride(1), final_states.stride(2),
215
+ dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
216
+ *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
217
+ if initial_states is not None else (0, 0, 0)),
218
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
219
+ HAS_INITSTATES=initial_states is not None,
220
+ HAS_SEQ_IDX=seq_idx is not None,
221
+ )
222
+ return out, final_states
223
+
224
+
225
+ def _state_passing_bwd(
226
+ states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,
227
+ dstates_dtype=None, states_dtype=None, chunk_size=None
228
+ ):
229
+ """
230
+ states contains the initial_states at index 0. The final states are not included in states.
231
+ """
232
+ batch, nchunks, nheads, dim = states.shape
233
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
234
+ assert dout.shape == (batch, nchunks, nheads, dim)
235
+ if seq_idx is not None:
236
+ assert chunk_size is not None
237
+ seqlen = seq_idx.shape[-1]
238
+ assert seq_idx.shape == (batch, seqlen)
239
+ dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
240
+ if states_dtype is not None and states_dtype != states.dtype:
241
+ states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
242
+ assert states_converted.stride() == states.stride()
243
+ else:
244
+ states_converted = None
245
+ if has_initial_states:
246
+ dinitstates = torch.empty_like(dstates[:, 0])
247
+ else:
248
+ dinitstates = None
249
+ if dfinal_states is not None:
250
+ assert dfinal_states.shape == (batch, nheads, dim)
251
+ BLOCK_SIZE_min = 64
252
+ n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
253
+ ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,
254
+ dtype=torch.float32, device=dA_chunk_cumsum.device)
255
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
256
+ with torch.cuda.device(dout.device.index):
257
+ _state_passing_bwd_kernel[grid](
258
+ dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,
259
+ dstates, ddA_chunk_cumsum, dinitstates, states_converted,
260
+ dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
261
+ dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
262
+ states.stride(0), states.stride(1), states.stride(2), states.stride(3),
263
+ dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
264
+ *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))
265
+ if dfinal_states is not None else (0, 0, 0)),
266
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
267
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),
268
+ ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),
269
+ *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))
270
+ if dinitstates is not None else (0, 0, 0)),
271
+ CONVERT_STATES=states_converted is not None,
272
+ HAS_DFINAL_STATES=dfinal_states is not None,
273
+ HAS_DINITSTATES=dinitstates is not None,
274
+ HAS_SEQ_IDX=seq_idx is not None,
275
+ )
276
+ BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
277
+ n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
278
+ ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)
279
+ if states_dtype is not None and states_dtype == states.dtype:
280
+ states_converted = states
281
+ return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)
282
+
283
+
284
+ class StatePassingFn(torch.autograd.Function):
285
+
286
+ @staticmethod
287
+ def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
288
+ batch, nchunks, nheads, dim = states.shape
289
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
290
+ if states.stride(-1) != 1:
291
+ states = states.contiguous()
292
+ out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
293
+ ctx.save_for_backward(out, dA_chunk_cumsum)
294
+ ctx.has_initial_states = initial_states is not None
295
+ return out, final_states
296
+
297
+ @staticmethod
298
+ def backward(ctx, dout, dfinal_states):
299
+ out, dA_chunk_cumsum = ctx.saved_tensors
300
+ batch, nchunks, nheads, dim = out.shape
301
+ assert dout.shape == (batch, nchunks, nheads, dim)
302
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
303
+ assert dfinal_states.shape == (batch, nheads, dim)
304
+ if dout.stride(-1) != 1:
305
+ dout = dout.contiguous()
306
+ dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
307
+ out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states, has_initial_states=ctx.has_initial_states
308
+ )
309
+ return dstates, ddA_chunk_cumsum, dinitstates
310
+
311
+
312
+ def state_passing(states, dA_chunk_cumsum, initial_states=None):
313
+ """
314
+ Argument:
315
+ states: (batch, nchunks, nheads, dim)
316
+ dA_chunk_cumsum: (batch, nheads, nchunks)
317
+ initial_states: (batch, nheads, dim)
318
+ Return:
319
+ out: (batch, nchunks, nheads, dim)
320
+ final_states: (batch, nheads, dim)
321
+ """
322
+ return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)
323
+
324
+
325
+ def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
326
+ """
327
+ Argument:
328
+ states: (batch, nchunks, nheads, dim)
329
+ dA_chunk_cumsum: (batch, nheads, nchunks)
330
+ initial_states: (batch, nheads, dim)
331
+ Return:
332
+ out: (batch, nchunks, nheads, dim)
333
+ final_states: (batch, nheads, dim)
334
+ """
335
+ if initial_states is None:
336
+ initial_states = torch.zeros_like(states[:, 0])
337
+ states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1)
338
+ dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
339
+ dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
340
+ nchunks = dA_chunk_cumsum.shape[-1]
341
+ # (batch, nheads, nchunks, nchunks)
342
+ dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
343
+ # (batch, nheads, nchunks, nchunks)
344
+ decay_chunk = torch.exp(dt_chunk_segment_sum)
345
+ causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)
346
+ decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)
347
+ out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
348
+ return out[:, :-1], out[:, -1]
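# A small self-check sketch for the chunked state-passing recurrence, assuming a CUDA
# device with triton available: the kernel path is compared against state_passing_ref.
torch.manual_seed(0)
batch, nchunks, nheads, dim = 2, 8, 4, 64
states = torch.randn(batch, nchunks, nheads, dim, device="cuda")
dA_chunk_cumsum = -torch.rand(batch, nheads, nchunks, device="cuda")
out, final_states = state_passing(states, dA_chunk_cumsum)
out_ref, final_states_ref = state_passing_ref(states, dA_chunk_cumsum)
torch.testing.assert_close(out, out_ref, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(final_states, final_states_ref, rtol=1e-4, atol=1e-4)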
mamba/build/lib/mamba_ssm/utils/__init__.py ADDED
File without changes
mamba/build/lib/mamba_ssm/utils/generation.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+ import gc
3
+ import time
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass, field
6
+ from functools import partial
7
+ from typing import Callable, Optional, Sequence, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+ from torch import Tensor
13
+ from torch.profiler import ProfilerActivity, profile, record_function
14
+ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
15
+
16
+
17
+ @dataclass
18
+ class InferenceParams:
19
+ """Inference parameters that are passed to the main model in order
20
+ to efficiently calculate and store the context during inference."""
21
+
22
+ max_seqlen: int
23
+ max_batch_size: int
24
+ seqlen_offset: int = 0
25
+ batch_size_offset: int = 0
26
+ key_value_memory_dict: dict = field(default_factory=dict)
27
+ lengths_per_sample: Optional[Tensor] = None
28
+
29
+ def reset(self, max_seqlen, max_batch_size):
30
+ self.max_seqlen = max_seqlen
31
+ self.max_batch_size = max_batch_size
32
+ self.seqlen_offset = 0
33
+ if self.lengths_per_sample is not None:
34
+ self.lengths_per_sample.zero_()
35
+
36
+
37
+ def modify_logits_for_min_p_filtering(logits, min_p):
38
+ """Set the logits for none min_p values to -inf. Done in-place."""
39
+ if min_p <= 0.0 or min_p >= 1.0:
40
+ return
41
+ indices_to_remove = logits < min_p
42
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
43
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
44
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
45
+ def modify_logits_for_top_k_filtering(logits, top_k):
46
+ """Set the logits for none top-k values to -inf. Done in-place."""
47
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
48
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
49
+
50
+
51
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
52
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
53
+ def modify_logits_for_top_p_filtering(logits, top_p):
54
+ """Set the logits for none top-p values to -inf. Done in-place."""
55
+ if top_p <= 0.0 or top_p >= 1.0:
56
+ return
57
+ # First sort and calculate cumulative sum of probabilities.
58
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
59
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
60
+ # Remove tokens whose cumulative probability falls at or below the (1 - top_p) threshold (tokens with 0 probability are kept)
61
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
62
+ # scatter sorted tensors to original indexing
63
+ indices_to_remove = sorted_indices_to_remove.scatter(
64
+ 1, sorted_indices, sorted_indices_to_remove
65
+ )
66
+ logits.masked_fill_(indices_to_remove, float("-inf"))
67
+
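# Worked mini-example (illustrative): with probabilities [0.1, 0.2, 0.3, 0.4] and
# top_p=0.6, the ascending cumulative mass of the two smallest tokens is 0.3 <= 0.4,
# so their logits are set to -inf and sampling is restricted to the 0.3 / 0.4 tokens.
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
logits = probs.log()
modify_logits_for_top_p_filtering(logits, top_p=0.6)
# logits is now [[-inf, -inf, log(0.3), log(0.4)]]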
68
+
69
+ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
70
+ """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
71
+ logits: (batch_size, vocab_size)
72
+ prev_output_tokens: (batch_size, seq_len)
73
+ """
74
+ if repetition_penalty == 1.0:
75
+ return logits
76
+ score = torch.gather(logits, 1, prev_output_tokens)
77
+ # if score < 0 then the repetition penalty has to be multiplied in (rather than divided) to reduce the previous token's probability
78
+ score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
79
+ logits.scatter_(1, prev_output_tokens, score)
80
+ return logits
81
+
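# Worked mini-example (illustrative): with a penalty of 2.0, a previously generated
# token with a positive logit is halved and one with a negative logit is doubled,
# making both less likely to be sampled again.
logits = torch.tensor([[2.0, -1.0, 0.5]])
prev_tokens = torch.tensor([[0, 1]])
modify_logit_for_repetition_penalty(logits, prev_tokens, repetition_penalty=2.0)
# logits is now [[1.0, -2.0, 0.5]]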
82
+
83
+ def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
84
+ """Sample from top-k logits.
85
+ Arguments:
86
+ logits: Tensor of shape (batch_size, vocab_size)
87
+ """
88
+ if top_k == 1: # Short-circuit for greedy decoding
89
+ return logits.argmax(dim=-1)
90
+ else:
91
+ if top_p > 0.0:
92
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
93
+ if top_k > 0:
94
+ top_k = min(top_k, logits.size(-1)) # Safety check
95
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
96
+ if temperature != 1.0:
97
+ logits_top /= temperature
98
+ modify_logits_for_top_p_filtering(logits_top, top_p)
99
+ return indices[
100
+ torch.arange(indices.shape[0], device=indices.device),
101
+ torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
102
+ ]
103
+ else:
104
+ if min_p > 0.0:
105
+ logits_top = logits.clone()
106
+ max_prob = logits_top[..., 0].item()
107
+ min_prob = max_prob * min_p
108
+ modify_logits_for_min_p_filtering(logits_top, min_prob)
109
+ if temperature != 1.0:
110
+ logits_top /= temperature
111
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
112
+ # Clone so that when we modify for top_p we don't change the original logits
113
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
114
+ modify_logits_for_top_p_filtering(logits_top, top_p)
115
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
116
+ dim=-1
117
+ )
118
+
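# Hedged usage sketch for sample(): one token id per batch element, with greedy
# decoding and a combined top-k / top-p / temperature draw (sizes are illustrative).
logits = torch.randn(4, 50257)  # (batch_size, vocab_size)
greedy_ids = sample(logits, top_k=1)                                # argmax decoding
sampled_ids = sample(logits, top_k=50, top_p=0.9, temperature=0.8)  # filtered sampling
print(greedy_ids.shape, sampled_ids.shape)  # torch.Size([4]) torch.Size([4])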
119
+
120
+ @torch.inference_mode()
121
+ def decode(
122
+ input_ids,
123
+ model,
124
+ max_length,
125
+ top_k=1,
126
+ top_p=0.0,
127
+ min_p=0.0,
128
+ temperature=1.0,
129
+ repetition_penalty=1.0,
130
+ eos_token_id=None,
131
+ teacher_outputs=None,
132
+ vocab_size=None,
133
+ cg=False,
134
+ enable_timing=False,
135
+ streamer: Optional[TextStreamer] = None
136
+ ):
137
+ """Decoding, either greedy or with top-k or top-p sampling.
138
+ If top-k = 0, don't limit the number of candidates (pure sampling).
139
+ Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
140
+ then top-p.
141
+ We assume that all sequences in the same batch have the same length.
142
+
143
+ Arguments:
144
+ input_ids: (batch, seq_len)
145
+ max_length: int
146
+ teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
147
+ logits, the next token is taken from the teacher_outputs. Useful for testing.
148
+ Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
149
+ sequences: (batch, max_length)
150
+ scores: tuple of tensors, each of shape (batch, vocab_size)
151
+ """
152
+ if streamer is not None:
153
+ streamer.put(input_ids.cpu())
154
+
155
+ batch_size, seqlen_og = input_ids.shape
156
+ teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
157
+ if cg:
158
+ if not hasattr(model, "_decoding_cache"):
159
+ model._decoding_cache = None
160
+ model._decoding_cache = update_graph_cache(
161
+ model,
162
+ model._decoding_cache,
163
+ batch_size,
164
+ seqlen_og,
165
+ max_length,
166
+ )
167
+ inference_params = model._decoding_cache.inference_params
168
+ inference_params.reset(max_length, batch_size)
169
+ else:
170
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
171
+
172
+ def get_logits(input_ids, inference_params):
173
+ decoding = inference_params.seqlen_offset > 0
174
+ if decoding:
175
+ position_ids = torch.full(
176
+ (batch_size, 1),
177
+ inference_params.seqlen_offset,
178
+ dtype=torch.long,
179
+ device=input_ids.device,
180
+ )
181
+ else:
182
+ position_ids = None
183
+ if not cg or not decoding:
184
+ logits = model(
185
+ input_ids,
186
+ position_ids=position_ids,
187
+ inference_params=inference_params,
188
+ num_last_tokens=1,
189
+ ).logits.squeeze(dim=1)
190
+ else:
191
+ logits = model._decoding_cache.run(
192
+ input_ids, position_ids, inference_params.seqlen_offset
193
+ ).squeeze(dim=1)
194
+ return logits[..., :vocab_size] if vocab_size is not None else logits
195
+
196
+ def sample_tokens(logits, inference_params):
197
+ if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
198
+ token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
199
+ else:
200
+ token = teacher_outputs[:, inference_params.seqlen_offset]
201
+ # return rearrange(token, "b -> b 1")
202
+ return token.unsqueeze(1)
203
+
204
+ def should_stop(current_token, inference_params):
205
+ if inference_params.seqlen_offset == 0:
206
+ return False
207
+ if eos_token_id is not None and (current_token == eos_token_id).all():
208
+ return True
209
+ if inference_params.seqlen_offset >= max_length - 1:
210
+ return True
211
+ return False
212
+
213
+ start = torch.cuda.Event(enable_timing=enable_timing)
214
+ end = torch.cuda.Event(enable_timing=enable_timing)
215
+
216
+ if enable_timing:
217
+ start.record()
218
+ scores, sequences = [], [input_ids]
219
+ sequences_cat = input_ids
220
+ while not should_stop(sequences[-1], inference_params):
221
+ scores.append(get_logits(sequences[-1], inference_params))
222
+ inference_params.seqlen_offset += sequences[-1].shape[1]
223
+ if repetition_penalty == 1.0:
224
+ sampled_tokens = sample_tokens(scores[-1], inference_params)
225
+ else:
226
+ logits = modify_logit_for_repetition_penalty(
227
+ scores[-1].clone(), sequences_cat, repetition_penalty
228
+ )
229
+ sampled_tokens = sample_tokens(logits, inference_params)
230
+ sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
231
+ sequences.append(sampled_tokens)
232
+ if streamer is not None:
233
+ streamer.put(sampled_tokens.cpu())
234
+ if streamer is not None:
235
+ streamer.end()
236
+ if enable_timing:
237
+ end.record()
238
+ torch.cuda.synchronize()
239
+ print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
240
+ output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
241
+ return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
242
+
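# Usage sketch for decode(), assuming a CUDA device and a pretrained Mamba LM; the
# checkpoint and tokenizer names are illustrative and not pinned by this file.
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
input_ids = tokenizer("Mamba is a", return_tensors="pt").input_ids.cuda()
out = decode(input_ids, model, max_length=64, top_k=1, cg=True, enable_timing=True)
print(tokenizer.batch_decode(out.sequences))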
243
+
244
+ class GenerationMixin:
245
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
246
+ raise NotImplementedError
247
+
248
+ def generate(
249
+ self,
250
+ input_ids,
251
+ max_length,
252
+ top_k=1,
253
+ top_p=0.0,
254
+ min_p=0.0,
255
+ temperature=1.0,
256
+ return_dict_in_generate=False,
257
+ output_scores=False,
258
+ **kwargs,
259
+ ):
260
+ output = decode(
261
+ input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, **kwargs
262
+ )
263
+ if not output_scores:
264
+ output.scores = None
265
+ return output if return_dict_in_generate else output.sequences
266
+
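# Assuming the model and input_ids from the previous sketch (MambaLMHeadModel mixes in
# GenerationMixin), the same decoding loop can be driven through generate(), with
# cg=True forwarded to decode() to enable the CUDA-graph cache defined below.
out = model.generate(
    input_ids, max_length=64, top_k=50, top_p=0.9, temperature=0.8,
    return_dict_in_generate=True, output_scores=True, cg=True,
)
print(out.sequences.shape, len(out.scores))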
267
+
268
+ @dataclass
269
+ class DecodingCGCache:
270
+ max_batch_size: int = 0
271
+ max_seqlen: int = 0
272
+ device = None
273
+ dtype = None
274
+ callables: dict = field(default_factory=dict)
275
+ mempool = None
276
+ inference_params: Optional[InferenceParams] = None
277
+ run: Optional[Callable] = None
278
+
279
+
280
+ @torch.inference_mode()
281
+ def update_graph_cache(
282
+ model,
283
+ cache,
284
+ batch_size,
285
+ seqlen_og,
286
+ max_seqlen,
287
+ decoding_seqlens=(1,),
288
+ dtype=None,
289
+ n_warmups=2,
290
+ ):
291
+ if cache is None:
292
+ cache = DecodingCGCache()
293
+ param_example = next(iter(model.parameters()))
294
+ device = param_example.device
295
+ if dtype is None:
296
+ dtype = param_example.dtype
297
+ if (
298
+ (device, dtype) != (cache.device, cache.dtype)
299
+ or batch_size > cache.max_batch_size
300
+ or max_seqlen > cache.max_seqlen
301
+ ): # Invalidate the cache
302
+ cache.callables = {}
303
+ cache.mempool = None
304
+ cache.inference_params = None
305
+ gc.collect()
306
+ cache.device, cache.dtype = device, dtype
307
+ cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
308
+ assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
309
+ inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
310
+ lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
311
+ cache.inference_params = InferenceParams(
312
+ max_seqlen=max_seqlen,
313
+ max_batch_size=batch_size,
314
+ seqlen_offset=seqlen_og,
315
+ key_value_memory_dict=inf_cache,
316
+ lengths_per_sample=lengths_per_sample,
317
+ )
318
+ cache.mempool = torch.cuda.graphs.graph_pool_handle()
319
+ for decoding_seqlen in decoding_seqlens:
320
+ if (batch_size, decoding_seqlen) not in cache.callables:
321
+ cache.callables[batch_size, decoding_seqlen] = capture_graph(
322
+ model,
323
+ cache.inference_params,
324
+ batch_size,
325
+ max_seqlen,
326
+ decoding_seqlen=decoding_seqlen,
327
+ mempool=cache.mempool,
328
+ n_warmups=n_warmups,
329
+ )
330
+
331
+ def dispatch(input_ids, position_ids, seqlen):
332
+ batch_size, decoding_seqlen = input_ids.shape[:2]
333
+ return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
334
+
335
+ cache.run = dispatch
336
+ cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
337
+ return cache
338
+
339
+
340
+ def capture_graph(
341
+ model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
342
+ ):
343
+ device = next(iter(model.parameters())).device
344
+ input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
345
+ position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
346
+ seqlen_offset_og = inference_params.seqlen_offset
347
+ inference_params.seqlen_offset = max_seqlen - decoding_seqlen
348
+ inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
349
+
350
+ # Warmup before capture
351
+ s = torch.cuda.Stream()
352
+ s.wait_stream(torch.cuda.current_stream())
353
+ with torch.cuda.stream(s):
354
+ for _ in range(n_warmups):
355
+ logits = model(
356
+ input_ids,
357
+ position_ids=position_ids,
358
+ inference_params=inference_params,
359
+ num_last_tokens=decoding_seqlen,
360
+ ).logits
361
+ s.synchronize()
362
+ # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
363
+ # which requires that graph launches and non-captured launches do not overlap (I think,
364
+ # that's how I interpret the documentation). I'm not sure if this is required.
365
+ if torch.distributed.is_initialized():
366
+ torch.distributed.barrier()
367
+ torch.cuda.current_stream().wait_stream(s)
368
+ # Captures the graph
369
+ # To allow capture, torch.cuda.graph automatically sets a side stream as the current stream in this context
370
+ graph = torch.cuda.CUDAGraph()
371
+ with torch.cuda.graph(graph, pool=mempool):
372
+ logits = model(
373
+ input_ids,
374
+ position_ids=position_ids,
375
+ inference_params=inference_params,
376
+ num_last_tokens=decoding_seqlen,
377
+ ).logits
378
+
379
+ def run(new_input_ids, new_position_ids, seqlen):
380
+ inference_params.lengths_per_sample[:] = seqlen
381
+ input_ids.copy_(new_input_ids)
382
+ position_ids.copy_(new_position_ids)
383
+ graph.replay()
384
+ return logits.clone()
385
+
386
+ inference_params.seqlen_offset = seqlen_offset_og
387
+ return run
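# A stripped-down sketch of the warmup / capture / replay pattern used above, on a toy
# linear layer (illustrative only): inputs are copied into the captured static buffer,
# the graph is replayed, and the static output is cloned out.
static_in = torch.zeros(8, 16, device="cuda")
layer = torch.nn.Linear(16, 16).cuda()

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):              # warmup runs, as in capture_graph above
        static_out = layer(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = layer(static_in)

def replay(new_in):
    static_in.copy_(new_in)         # write into the captured input buffer
    graph.replay()                  # re-launch the recorded kernels
    return static_out.clone()

print(replay(torch.randn(8, 16, device="cuda")).shape)  # torch.Size([8, 16])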
mamba/build/lib/mamba_ssm/utils/hf.py ADDED
@@ -0,0 +1,23 @@
1
+ import json
2
+
3
+ import torch
4
+
5
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6
+ from transformers.utils.hub import cached_file
7
+
8
+
9
+ def load_config_hf(model_name):
10
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11
+ return json.load(open(resolved_archive_file))
12
+
13
+
14
+ def load_state_dict_hf(model_name, device=None, dtype=None):
15
+ # If not fp32, then we don't want to load directly to the GPU
16
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18
+ return torch.load(resolved_archive_file, map_location=mapped_device)
19
+ # Convert dtype before moving to GPU to save memory
20
+ if dtype is not None:
21
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23
+ return state_dict
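# Usage sketch for the two helpers above (the checkpoint name is illustrative): the
# returned config dict and state dict can then be fed to a model class such as
# mamba_ssm's MambaLMHeadModel.
config = load_config_hf("state-spaces/mamba-130m")
state_dict = load_state_dict_hf("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
print(sorted(config)[:3], len(state_dict))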
mamba/csrc/selective_scan/reverse_scan.cuh ADDED
@@ -0,0 +1,415 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #ifndef USE_ROCM
8
+ #include <cub/config.cuh>
9
+
10
+ #include <cub/util_ptx.cuh>
11
+ #include <cub/util_type.cuh>
12
+ #include <cub/block/block_raking_layout.cuh>
13
+ // #include <cub/detail/uninitialized_copy.cuh>
14
+ #else
15
+ #include <hipcub/hipcub.hpp>
16
+ namespace cub = hipcub;
17
+ #endif
18
+ #include "uninitialized_copy.cuh"
19
+
20
+ /**
21
+ * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
22
+ */
23
+ template <
24
+ int LENGTH,
25
+ typename T,
26
+ typename ReductionOp>
27
+ __device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
28
+ static_assert(LENGTH > 0);
29
+ T retval = input[LENGTH - 1];
30
+ #pragma unroll
31
+ for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
32
+ return retval;
33
+ }
34
+
35
+ /**
36
+ * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
37
+ */
38
+ template <
39
+ int LENGTH,
40
+ typename T,
41
+ typename ScanOp>
42
+ __device__ __forceinline__ T ThreadReverseScanInclusive(
43
+ const T (&input)[LENGTH],
44
+ T (&output)[LENGTH],
45
+ ScanOp scan_op,
46
+ const T postfix)
47
+ {
48
+ T inclusive = postfix;
49
+ #pragma unroll
50
+ for (int i = LENGTH - 1; i >= 0; --i) {
51
+ inclusive = scan_op(inclusive, input[i]);
52
+ output[i] = inclusive;
53
+ }
54
+ return inclusive;
55
+ }
56
+
57
+ /**
58
+ * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
59
+ */
60
+ template <
61
+ int LENGTH,
62
+ typename T,
63
+ typename ScanOp>
64
+ __device__ __forceinline__ T ThreadReverseScanExclusive(
65
+ const T (&input)[LENGTH],
66
+ T (&output)[LENGTH],
67
+ ScanOp scan_op,
68
+ const T postfix)
69
+ {
70
+ // Careful, output may be aliased to input
71
+ T exclusive = postfix;
72
+ T inclusive;
73
+ #pragma unroll
74
+ for (int i = LENGTH - 1; i >= 0; --i) {
75
+ inclusive = scan_op(exclusive, input[i]);
76
+ output[i] = exclusive;
77
+ exclusive = inclusive;
78
+ }
79
+ return inclusive;
80
+ }
81
+
82
+
83
+ /**
84
+ * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
85
+ *
86
+ * LOGICAL_WARP_THREADS must be a power-of-two
87
+ */
88
+ template <
89
+ typename T, ///< Data type being scanned
90
+ int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
91
+ >
92
+ struct WarpReverseScan {
93
+ //---------------------------------------------------------------------
94
+ // Constants and type definitions
95
+ //---------------------------------------------------------------------
96
+
97
+ /// Whether the logical warp size and the PTX warp size coincide
98
+
99
+ // In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size()
100
+ // While in cub, it's defined as a macro that takes a redundant unused argument.
101
+ #ifndef USE_ROCM
102
+ #define WARP_THREADS CUB_WARP_THREADS(0)
103
+ #else
104
+ #define WARP_THREADS HIPCUB_WARP_THREADS
105
+ #endif
106
+ static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS);
107
+ /// The number of warp scan steps
108
+ static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
109
+ static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
110
+
111
+
112
+ //---------------------------------------------------------------------
113
+ // Thread fields
114
+ //---------------------------------------------------------------------
115
+
116
+ /// Lane index in logical warp
117
+ unsigned int lane_id;
118
+
119
+ /// Logical warp index in 32-thread physical warp
120
+ unsigned int warp_id;
121
+
122
+ /// 32-thread physical warp member mask of logical warp
123
+ unsigned int member_mask;
124
+
125
+ //---------------------------------------------------------------------
126
+ // Construction
127
+ //---------------------------------------------------------------------
128
+
129
+ /// Constructor
130
+ explicit __device__ __forceinline__
131
+ WarpReverseScan()
132
+ : lane_id(cub::LaneId())
133
+ , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
134
+ , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
135
+ {
136
+ if (!IS_ARCH_WARP) {
137
+ lane_id = lane_id % LOGICAL_WARP_THREADS;
138
+ }
139
+ }
140
+
141
+
142
+ /// Broadcast
143
+ __device__ __forceinline__ T Broadcast(
144
+ T input, ///< [in] The value to broadcast
145
+ int src_lane) ///< [in] Which warp lane is to do the broadcasting
146
+ {
147
+ return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
148
+ }
149
+
150
+
151
+ /// Inclusive scan
152
+ template <typename ScanOpT>
153
+ __device__ __forceinline__ void InclusiveReverseScan(
154
+ T input, ///< [in] Calling thread's input item.
155
+ T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
156
+ ScanOpT scan_op) ///< [in] Binary scan operator
157
+ {
158
+ inclusive_output = input;
159
+ #pragma unroll
160
+ for (int STEP = 0; STEP < STEPS; STEP++) {
161
+ int offset = 1 << STEP;
162
+ T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
163
+ inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
164
+ );
165
+ // Perform scan op if from a valid peer
166
+ inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
167
+ ? inclusive_output : scan_op(temp, inclusive_output);
168
+ }
169
+ }
170
+
171
+ /// Exclusive scan
172
+ // Get exclusive from inclusive
173
+ template <typename ScanOpT>
174
+ __device__ __forceinline__ void ExclusiveReverseScan(
175
+ T input, ///< [in] Calling thread's input item.
176
+ T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
177
+ ScanOpT scan_op, ///< [in] Binary scan operator
178
+ T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
179
+ {
180
+ T inclusive_output;
181
+ InclusiveReverseScan(input, inclusive_output, scan_op);
182
+ warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
183
+ // initial value unknown
184
+ exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
185
+ inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
186
+ );
187
+ }
188
+
189
+ /**
190
+ * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
191
+ */
192
+ template <typename ScanOpT>
193
+ __device__ __forceinline__ void ReverseScan(
194
+ T input, ///< [in] Calling thread's input item.
195
+ T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item.
196
+ T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
197
+ ScanOpT scan_op) ///< [in] Binary scan operator
198
+ {
199
+ InclusiveReverseScan(input, inclusive_output, scan_op);
200
+ // initial value unknown
201
+ exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
202
+ inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
203
+ );
204
+ }
205
+
206
+ };
207
+
208
+ /**
209
+ * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
210
+ */
211
+ template <
212
+ typename T, ///< Data type being scanned
213
+ int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
214
+ bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
215
+ >
216
+ struct BlockReverseScan {
217
+ //---------------------------------------------------------------------
218
+ // Types and constants
219
+ //---------------------------------------------------------------------
220
+
221
+ /// Constants
222
+ /// The thread block size in threads
223
+ static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
224
+
225
+ /// Layout type for padded thread block raking grid
226
+ using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
227
+ // The number of reduction elements is not a multiple of the number of raking threads for now
228
+ static_assert(BlockRakingLayout::UNGUARDED);
229
+
230
+ /// Number of raking threads
231
+ static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
232
+ /// Number of raking elements per warp synchronous raking thread
233
+ static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
234
+ /// Cooperative work can be entirely warp synchronous
235
+ static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
236
+
237
+ /// WarpReverseScan utility type
238
+ using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
239
+
240
+ /// Shared memory storage layout type
241
+ struct _TempStorage {
242
+ typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
243
+ };
244
+
245
+
246
+ /// Alias wrapper allowing storage to be unioned
247
+ struct TempStorage : cub::Uninitialized<_TempStorage> {};
248
+
249
+
250
+ //---------------------------------------------------------------------
251
+ // Per-thread fields
252
+ //---------------------------------------------------------------------
253
+
254
+ // Thread fields
255
+ _TempStorage &temp_storage;
256
+ unsigned int linear_tid;
257
+ T cached_segment[SEGMENT_LENGTH];
258
+
259
+
260
+ //---------------------------------------------------------------------
261
+ // Utility methods
262
+ //---------------------------------------------------------------------
263
+
264
+ /// Performs upsweep raking reduction, returning the aggregate
265
+ template <typename ScanOp>
266
+ __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
267
+ T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
268
+ // Read data into registers
269
+ #pragma unroll
270
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
271
+ T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
272
+ #pragma unroll
273
+ for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
274
+ raking_partial = scan_op(raking_partial, cached_segment[i]);
275
+ }
276
+ return raking_partial;
277
+ }
278
+
279
+
280
+ /// Performs exclusive downsweep raking scan
281
+ template <typename ScanOp>
282
+ __device__ __forceinline__ void ExclusiveDownsweep(
283
+ ScanOp scan_op,
284
+ T raking_partial)
285
+ {
286
+ T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
287
+ // Read data back into registers
288
+ if (!MEMOIZE) {
289
+ #pragma unroll
290
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
291
+ }
292
+ ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
293
+ // Write data back to smem
294
+ #pragma unroll
295
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
296
+ }
297
+
298
+
299
+ //---------------------------------------------------------------------
300
+ // Constructors
301
+ //---------------------------------------------------------------------
302
+
303
+ /// Constructor
304
+ __device__ __forceinline__ BlockReverseScan(
305
+ TempStorage &temp_storage)
306
+ :
307
+ temp_storage(temp_storage.Alias()),
308
+ linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
309
+ {}
310
+
311
+
312
+ /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
313
+ template <
314
+ typename ScanOp,
315
+ typename BlockPostfixCallbackOp>
316
+ __device__ __forceinline__ void ExclusiveReverseScan(
317
+ T input, ///< [in] Calling thread's input item
318
+ T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
319
+ ScanOp scan_op, ///< [in] Binary scan operator
320
+ BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
321
+ {
322
+ if (WARP_SYNCHRONOUS) {
323
+ // Short-circuit directly to warp-synchronous scan
324
+ T block_aggregate;
325
+ WarpReverseScan warp_scan;
326
+ warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
327
+ // Obtain warp-wide postfix in lane0, then broadcast to other lanes
328
+ T block_postfix = block_postfix_callback_op(block_aggregate);
329
+ block_postfix = warp_scan.Broadcast(block_postfix, 0);
330
+ exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
331
+ } else {
332
+ // Place thread partial into shared memory raking grid
333
+ T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
334
+ detail::uninitialized_copy(placement_ptr, input);
335
+ cub::CTA_SYNC();
336
+ // Reduce parallelism down to just raking threads
337
+ if (linear_tid < RAKING_THREADS) {
338
+ WarpReverseScan warp_scan;
339
+ // Raking upsweep reduction across shared partials
340
+ T upsweep_partial = Upsweep(scan_op);
341
+ // Warp-synchronous scan
342
+ T exclusive_partial, block_aggregate;
343
+ warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
344
+ // Obtain block-wide postfix in lane0, then broadcast to other lanes
345
+ T block_postfix = block_postfix_callback_op(block_aggregate);
346
+ block_postfix = warp_scan.Broadcast(block_postfix, 0);
347
+ // Update postfix with warpscan exclusive partial
348
+ T downsweep_postfix = linear_tid == RAKING_THREADS - 1
349
+ ? block_postfix : scan_op(block_postfix, exclusive_partial);
350
+ // Exclusive raking downsweep scan
351
+ ExclusiveDownsweep(scan_op, downsweep_postfix);
352
+ }
353
+ cub::CTA_SYNC();
354
+ // Grab thread postfix from shared memory
355
+ exclusive_output = *placement_ptr;
356
+
357
+ // // Compute warp scan in each warp.
358
+ // // The exclusive output from the last lane in each warp is invalid.
359
+ // T inclusive_output;
360
+ // WarpReverseScan warp_scan;
361
+ // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
362
+
363
+ // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
364
+ // T block_aggregate;
365
+ // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
366
+
367
+ // // Apply warp postfix to our lane's partial
368
+ // if (warp_id != 0) {
369
+ // exclusive_output = scan_op(warp_postfix, exclusive_output);
370
+ // if (lane_id == 0) { exclusive_output = warp_postfix; }
371
+ // }
372
+
373
+ // // Use the first warp to determine the thread block postfix, returning the result in lane0
374
+ // if (warp_id == 0) {
375
+ // T block_postfix = block_postfix_callback_op(block_aggregate);
376
+ // if (lane_id == 0) {
377
+ // // Share the postfix with all threads
378
+ // detail::uninitialized_copy(&temp_storage.block_postfix,
379
+ // block_postfix);
380
+
381
+ // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
382
+ // }
383
+ // }
384
+
385
+ // cub::CTA_SYNC();
386
+
387
+ // // Incorporate thread block postfix into outputs
388
+ // T block_postfix = temp_storage.block_postfix;
389
+ // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
390
+ }
391
+ }
392
+
393
+
394
+ /**
395
+ * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
396
+ */
397
+ template <
398
+ int ITEMS_PER_THREAD,
399
+ typename ScanOp,
400
+ typename BlockPostfixCallbackOp>
401
+ __device__ __forceinline__ void InclusiveReverseScan(
402
+ T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
403
+ T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
404
+ ScanOp scan_op, ///< [in] Binary scan functor
405
+ BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
406
+ {
407
+ // Reduce consecutive thread items in registers
408
+ T thread_postfix = ThreadReverseReduce(input, scan_op);
409
+ // Exclusive thread block-scan
410
+ ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
411
+ // Inclusive scan in registers with postfix as seed
412
+ ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
413
+ }
414
+
415
+ };
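For intuition, the exclusive reverse scan above is the mirror image of a prefix scan: each position receives the combination of every *later* input, seeded with the block-wide postfix returned by the callback. A minimal sequential reference of that semantics, written in plain Python rather than CUDA and using illustrative names, might look like:

```python
def exclusive_reverse_scan(inputs, block_postfix, scan_op):
    """Sequential reference: out[i] combines everything *after* position i,
    seeded with block_postfix; out[-1] is block_postfix itself."""
    n = len(inputs)
    out = [None] * n
    running = block_postfix
    for i in range(n - 1, -1, -1):
        out[i] = running                      # exclusive: element i itself is excluded
        # Accumulated "later" value goes first, matching the
        # scan_op(block_postfix, partial) convention in the kernel above.
        running = scan_op(running, inputs[i])
    return out

# Suffix sums of [1, 2, 3, 4] seeded with 10 -> [19, 17, 14, 10]
print(exclusive_reverse_scan([1, 2, 3, 4], 10, lambda a, b: a + b))
```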
mamba/csrc/selective_scan/selective_scan.cpp ADDED
@@ -0,0 +1,497 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <torch/extension.h>
8
+ #include <vector>
9
+
10
+ #include "selective_scan.h"
11
+
12
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
13
+
14
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
15
+ if (ITYPE == at::ScalarType::Half) { \
16
+ using input_t = at::Half; \
17
+ __VA_ARGS__(); \
18
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
19
+ using input_t = at::BFloat16; \
20
+ __VA_ARGS__(); \
21
+ } else if (ITYPE == at::ScalarType::Float) { \
22
+ using input_t = float; \
23
+ __VA_ARGS__(); \
24
+ } else { \
25
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
26
+ }
27
+
28
+ #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
29
+ if (WTYPE == at::ScalarType::Half) { \
30
+ using weight_t = at::Half; \
31
+ __VA_ARGS__(); \
32
+ } else if (WTYPE == at::ScalarType::BFloat16) { \
33
+ using weight_t = at::BFloat16; \
34
+ __VA_ARGS__(); \
35
+ } else if (WTYPE == at::ScalarType::Float) { \
36
+ using weight_t = float; \
37
+ __VA_ARGS__(); \
38
+ } else { \
39
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
40
+ }
41
+
42
+ #define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \
43
+ if (WTYPE == at::ScalarType::Float) { \
44
+ using weight_t = float; \
45
+ __VA_ARGS__(); \
46
+ } else if (WTYPE == at::ScalarType::ComplexFloat) { \
47
+ using weight_t = c10::complex<float>; \
48
+ __VA_ARGS__(); \
49
+ } else { \
50
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
51
+ }
52
+
53
+ template<typename input_t, typename weight_t>
54
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
55
+
56
+ template <typename input_t, typename weight_t>
57
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
58
+
59
+ void set_ssm_params_fwd(SSMParamsBase &params,
60
+ // sizes
61
+ const size_t batch,
62
+ const size_t dim,
63
+ const size_t seqlen,
64
+ const size_t dstate,
65
+ const size_t n_groups,
66
+ const size_t n_chunks,
67
+ const bool is_variable_B,
68
+ const bool is_variable_C,
69
+ // device pointers
70
+ const at::Tensor u,
71
+ const at::Tensor delta,
72
+ const at::Tensor A,
73
+ const at::Tensor B,
74
+ const at::Tensor C,
75
+ const at::Tensor out,
76
+ const at::Tensor z,
77
+ const at::Tensor out_z,
78
+ void* D_ptr,
79
+ void* delta_bias_ptr,
80
+ void* x_ptr,
81
+ bool has_z,
82
+ bool delta_softplus) {
83
+
84
+ // Reset the parameters
85
+ memset(&params, 0, sizeof(params));
86
+
87
+ params.batch = batch;
88
+ params.dim = dim;
89
+ params.seqlen = seqlen;
90
+ params.dstate = dstate;
91
+ params.n_groups = n_groups;
92
+ params.n_chunks = n_chunks;
93
+ params.dim_ngroups_ratio = dim / n_groups;
94
+
95
+ params.delta_softplus = delta_softplus;
96
+
97
+ params.is_variable_B = is_variable_B;
98
+ params.is_variable_C = is_variable_C;
99
+
100
+ // Set the pointers and strides.
101
+ params.u_ptr = u.data_ptr();
102
+ params.delta_ptr = delta.data_ptr();
103
+ params.A_ptr = A.data_ptr();
104
+ params.B_ptr = B.data_ptr();
105
+ params.C_ptr = C.data_ptr();
106
+ params.D_ptr = D_ptr;
107
+ params.delta_bias_ptr = delta_bias_ptr;
108
+ params.out_ptr = out.data_ptr();
109
+ params.x_ptr = x_ptr;
110
+ params.z_ptr = has_z ? z.data_ptr() : nullptr;
111
+ params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
112
+ // All strides are in elements, not bytes.
113
+ params.A_d_stride = A.stride(0);
114
+ params.A_dstate_stride = A.stride(1);
115
+ if (!is_variable_B) {
116
+ params.B_d_stride = B.stride(0);
117
+ } else {
118
+ params.B_batch_stride = B.stride(0);
119
+ params.B_group_stride = B.stride(1);
120
+ }
121
+ params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
122
+ if (!is_variable_C) {
123
+ params.C_d_stride = C.stride(0);
124
+ } else {
125
+ params.C_batch_stride = C.stride(0);
126
+ params.C_group_stride = C.stride(1);
127
+ }
128
+ params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
129
+ params.u_batch_stride = u.stride(0);
130
+ params.u_d_stride = u.stride(1);
131
+ params.delta_batch_stride = delta.stride(0);
132
+ params.delta_d_stride = delta.stride(1);
133
+ if (has_z) {
134
+ params.z_batch_stride = z.stride(0);
135
+ params.z_d_stride = z.stride(1);
136
+ params.out_z_batch_stride = out_z.stride(0);
137
+ params.out_z_d_stride = out_z.stride(1);
138
+ }
139
+ params.out_batch_stride = out.stride(0);
140
+ params.out_d_stride = out.stride(1);
141
+ }
142
+
143
+ void set_ssm_params_bwd(SSMParamsBwd &params,
144
+ // sizes
145
+ const size_t batch,
146
+ const size_t dim,
147
+ const size_t seqlen,
148
+ const size_t dstate,
149
+ const size_t n_groups,
150
+ const size_t n_chunks,
151
+ const bool is_variable_B,
152
+ const bool is_variable_C,
153
+ // device pointers
154
+ const at::Tensor u,
155
+ const at::Tensor delta,
156
+ const at::Tensor A,
157
+ const at::Tensor B,
158
+ const at::Tensor C,
159
+ const at::Tensor z,
160
+ const at::Tensor out,
161
+ const at::Tensor out_z,
162
+ void* D_ptr,
163
+ void* delta_bias_ptr,
164
+ void* x_ptr,
165
+ const at::Tensor dout,
166
+ const at::Tensor du,
167
+ const at::Tensor ddelta,
168
+ const at::Tensor dA,
169
+ const at::Tensor dB,
170
+ const at::Tensor dC,
171
+ const at::Tensor dz,
172
+ void* dD_ptr,
173
+ void* ddelta_bias_ptr,
174
+ bool has_z,
175
+ bool delta_softplus,
176
+ bool recompute_out_z) {
177
+ // Pass in "dout" instead of "out"; we're not going to use "out" unless we have z.
178
+ set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
179
+ u, delta, A, B, C, has_z ? out : dout,
180
+ has_z ? z : dout,
181
+ // If not recompute_out_z, pass dout instead of out_z.
182
+ // This won't be used by the bwd kernel
183
+ recompute_out_z ? out_z : dout,
184
+ D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
185
+ if (!recompute_out_z) { params.out_z_ptr = nullptr; }
186
+
187
+ // Set the pointers and strides.
188
+ params.dout_ptr = dout.data_ptr();
189
+ params.du_ptr = du.data_ptr();
190
+ params.dA_ptr = dA.data_ptr();
191
+ params.dB_ptr = dB.data_ptr();
192
+ params.dC_ptr = dC.data_ptr();
193
+ params.dD_ptr = dD_ptr;
194
+ params.ddelta_ptr = ddelta.data_ptr();
195
+ params.ddelta_bias_ptr = ddelta_bias_ptr;
196
+ params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
197
+ // All strides are in elements, not bytes.
198
+ params.dout_batch_stride = dout.stride(0);
199
+ params.dout_d_stride = dout.stride(1);
200
+ params.dA_d_stride = dA.stride(0);
201
+ params.dA_dstate_stride = dA.stride(1);
202
+ if (!is_variable_B) {
203
+ params.dB_d_stride = dB.stride(0);
204
+ } else {
205
+ params.dB_batch_stride = dB.stride(0);
206
+ params.dB_group_stride = dB.stride(1);
207
+ }
208
+ params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
209
+ if (!is_variable_C) {
210
+ params.dC_d_stride = dC.stride(0);
211
+ } else {
212
+ params.dC_batch_stride = dC.stride(0);
213
+ params.dC_group_stride = dC.stride(1);
214
+ }
215
+ params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
216
+ params.du_batch_stride = du.stride(0);
217
+ params.du_d_stride = du.stride(1);
218
+ params.ddelta_batch_stride = ddelta.stride(0);
219
+ params.ddelta_d_stride = ddelta.stride(1);
220
+ if (has_z) {
221
+ params.dz_batch_stride = dz.stride(0);
222
+ params.dz_d_stride = dz.stride(1);
223
+ }
224
+ }
225
+
226
+ std::vector<at::Tensor>
227
+ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
228
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
229
+ const c10::optional<at::Tensor> &D_,
230
+ const c10::optional<at::Tensor> &z_,
231
+ const c10::optional<at::Tensor> &delta_bias_,
232
+ bool delta_softplus) {
233
+ auto input_type = u.scalar_type();
234
+ auto weight_type = A.scalar_type();
235
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
236
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
237
+
238
+ const bool is_variable_B = B.dim() >= 3;
239
+ const bool is_variable_C = C.dim() >= 3;
240
+ const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
241
+
242
+ TORCH_CHECK(delta.scalar_type() == input_type);
243
+ TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
244
+ TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
245
+
246
+ TORCH_CHECK(u.is_cuda());
247
+ TORCH_CHECK(delta.is_cuda());
248
+ TORCH_CHECK(A.is_cuda());
249
+ TORCH_CHECK(B.is_cuda());
250
+ TORCH_CHECK(C.is_cuda());
251
+
252
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
253
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
254
+
255
+ const auto sizes = u.sizes();
256
+ const int batch_size = sizes[0];
257
+ const int dim = sizes[1];
258
+ const int seqlen = sizes[2];
259
+ const int dstate = A.size(1);
260
+ const int n_groups = is_variable_B ? B.size(1) : 1;
261
+
262
+ TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
263
+
264
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
265
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
266
+ CHECK_SHAPE(A, dim, dstate);
267
+ if (!is_variable_B) {
268
+ CHECK_SHAPE(B, dim, dstate);
269
+ } else {
270
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
271
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
272
+ }
273
+ if (!is_variable_C) {
274
+ CHECK_SHAPE(C, dim, dstate);
275
+ } else {
276
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
277
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
278
+ }
279
+
280
+ if (D_.has_value()) {
281
+ auto D = D_.value();
282
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
283
+ TORCH_CHECK(D.is_cuda());
284
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
285
+ CHECK_SHAPE(D, dim);
286
+ }
287
+
288
+ if (delta_bias_.has_value()) {
289
+ auto delta_bias = delta_bias_.value();
290
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
291
+ TORCH_CHECK(delta_bias.is_cuda());
292
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
293
+ CHECK_SHAPE(delta_bias, dim);
294
+ }
295
+
296
+ at::Tensor z, out_z;
297
+ const bool has_z = z_.has_value();
298
+ if (has_z) {
299
+ z = z_.value();
300
+ TORCH_CHECK(z.scalar_type() == input_type);
301
+ TORCH_CHECK(z.is_cuda());
302
+ TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
303
+ CHECK_SHAPE(z, batch_size, dim, seqlen);
304
+ out_z = torch::empty_like(z);
305
+ }
306
+
307
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
308
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
309
+ // at::Tensor out = torch::empty_like(u);
310
+ // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
311
+ at::Tensor out = torch::empty_like(delta);
312
+ at::Tensor x;
313
+ x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
314
+
315
+ SSMParamsBase params;
316
+ set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
317
+ u, delta, A, B, C, out, z, out_z,
318
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
319
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
320
+ x.data_ptr(),
321
+ has_z,
322
+ delta_softplus);
323
+
324
+ // Otherwise the kernel will be launched from the cuda:0 device
325
+ // Cast to char to avoid compiler warning about narrowing
326
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
327
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
328
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
329
+ DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
330
+ selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
331
+ });
332
+ });
333
+ std::vector<at::Tensor> result = {out, x};
334
+ if (has_z) { result.push_back(out_z); }
335
+ return result;
336
+ }
337
+
338
+ std::vector<at::Tensor>
339
+ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
340
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
341
+ const c10::optional<at::Tensor> &D_,
342
+ const c10::optional<at::Tensor> &z_,
343
+ const c10::optional<at::Tensor> &delta_bias_,
344
+ const at::Tensor &dout,
345
+ const c10::optional<at::Tensor> &x_,
346
+ const c10::optional<at::Tensor> &out_,
347
+ c10::optional<at::Tensor> &dz_,
348
+ bool delta_softplus,
349
+ bool recompute_out_z) {
350
+ auto input_type = u.scalar_type();
351
+ auto weight_type = A.scalar_type();
352
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
353
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
354
+
355
+ const bool is_variable_B = B.dim() >= 3;
356
+ const bool is_variable_C = C.dim() >= 3;
357
+ const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
358
+
359
+ TORCH_CHECK(delta.scalar_type() == input_type);
360
+ TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
361
+ TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
362
+ TORCH_CHECK(dout.scalar_type() == input_type);
363
+
364
+ TORCH_CHECK(u.is_cuda());
365
+ TORCH_CHECK(delta.is_cuda());
366
+ TORCH_CHECK(A.is_cuda());
367
+ TORCH_CHECK(B.is_cuda());
368
+ TORCH_CHECK(C.is_cuda());
369
+ TORCH_CHECK(dout.is_cuda());
370
+
371
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
372
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
373
+ TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
374
+
375
+ const auto sizes = u.sizes();
376
+ const int batch_size = sizes[0];
377
+ const int dim = sizes[1];
378
+ const int seqlen = sizes[2];
379
+ const int dstate = A.size(1);
380
+ const int n_groups = is_variable_B ? B.size(1) : 1;
381
+
382
+ TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
383
+
384
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
385
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
386
+ CHECK_SHAPE(A, dim, dstate);
387
+ if (!is_variable_B) {
388
+ CHECK_SHAPE(B, dim, dstate);
389
+ } else {
390
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
391
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
392
+ }
393
+ if (!is_variable_C) {
394
+ CHECK_SHAPE(C, dim, dstate);
395
+ } else {
396
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
397
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
398
+ }
399
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
400
+
401
+ if (D_.has_value()) {
402
+ auto D = D_.value();
403
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
404
+ TORCH_CHECK(D.is_cuda());
405
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
406
+ CHECK_SHAPE(D, dim);
407
+ }
408
+
409
+ if (delta_bias_.has_value()) {
410
+ auto delta_bias = delta_bias_.value();
411
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
412
+ TORCH_CHECK(delta_bias.is_cuda());
413
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
414
+ CHECK_SHAPE(delta_bias, dim);
415
+ }
416
+
417
+ at::Tensor z, out, dz, out_z;
418
+ const bool has_z = z_.has_value();
419
+ if (has_z) {
420
+ z = z_.value();
421
+ TORCH_CHECK(z.scalar_type() == input_type);
422
+ TORCH_CHECK(z.is_cuda());
423
+ TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
424
+ CHECK_SHAPE(z, batch_size, dim, seqlen);
425
+
426
+ TORCH_CHECK(out_.has_value());
427
+ out = out_.value();
428
+ TORCH_CHECK(out.scalar_type() == input_type);
429
+ TORCH_CHECK(out.is_cuda());
430
+ TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1);
431
+ CHECK_SHAPE(out, batch_size, dim, seqlen);
432
+
433
+ if (dz_.has_value()) {
434
+ dz = dz_.value();
435
+ TORCH_CHECK(dz.scalar_type() == input_type);
436
+ TORCH_CHECK(dz.is_cuda());
437
+ TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1);
438
+ CHECK_SHAPE(dz, batch_size, dim, seqlen);
439
+ } else {
440
+ dz = torch::empty_like(z);
441
+ }
442
+ if (recompute_out_z) {
443
+ out_z = torch::empty_like(out);
444
+ }
445
+ }
446
+
447
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
448
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
449
+ if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
450
+ if (x_.has_value()) {
451
+ auto x = x_.value();
452
+ TORCH_CHECK(x.scalar_type() == weight_type);
453
+ TORCH_CHECK(x.is_cuda());
454
+ TORCH_CHECK(x.is_contiguous());
455
+ CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
456
+ }
457
+
458
+ at::Tensor du = torch::empty_like(u);
459
+ at::Tensor ddelta = torch::empty_like(delta);
460
+ at::Tensor dA = torch::zeros_like(A);
461
+ at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
462
+ at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
463
+ at::Tensor dD;
464
+ if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
465
+ at::Tensor ddelta_bias;
466
+ if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
467
+
468
+ SSMParamsBwd params;
469
+ set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
470
+ u, delta, A, B, C, z, out, out_z,
471
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
472
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
473
+ x_.has_value() ? x_.value().data_ptr() : nullptr,
474
+ dout, du, ddelta, dA, dB, dC, dz,
475
+ D_.has_value() ? dD.data_ptr() : nullptr,
476
+ delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
477
+ has_z, delta_softplus, recompute_out_z);
478
+
479
+ // Otherwise the kernel will be launched from the cuda:0 device
480
+ // Cast to char to avoid compiler warning about narrowing
481
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
482
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
483
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
484
+ DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
485
+ selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
486
+ });
487
+ });
488
+ std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
489
+ if (has_z) { result.push_back(dz); }
490
+ if (recompute_out_z) { result.push_back(out_z); }
491
+ return result;
492
+ }
493
+
494
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
495
+ m.def("fwd", &selective_scan_fwd, "Selective scan forward");
496
+ m.def("bwd", &selective_scan_bwd, "Selective scan backward");
497
+ }
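Taken together, selective_scan_fwd expects u and delta of shape (batch, dim, seqlen), A of shape (dim, dstate), and, in the variable-B/C case, B and C of shape (batch, n_groups, dstate, seqlen); it returns out, the saved chunk states x, and out_z when z is given. A rough sketch of invoking the binding from Python is below; the extension module name selective_scan_cuda is an assumption (it comes from the package's build configuration, not from this file):

```python
# Hedged sketch: tensor shapes follow the CHECK_SHAPE calls in selective_scan_fwd;
# the module name "selective_scan_cuda" is an assumption, not defined in this file.
import torch
import selective_scan_cuda

batch, dim, seqlen, dstate, n_groups = 2, 64, 512, 16, 1

u     = torch.randn(batch, dim, seqlen, device="cuda")
delta = torch.randn(batch, dim, seqlen, device="cuda")
A     = torch.randn(dim, dstate, device="cuda")                       # real float weights
B     = torch.randn(batch, n_groups, dstate, seqlen, device="cuda")   # "variable B" path (dim() >= 3)
C     = torch.randn(batch, n_groups, dstate, seqlen, device="cuda")
D = z = delta_bias = None                                             # optional arguments
delta_softplus = True

out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
# out: (batch, dim, seqlen); x: (batch, dim, n_chunks, 2 * dstate), reused by the backward pass.
# rest holds out_z only when z is passed.
```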
mamba/csrc/selective_scan/selective_scan.h ADDED
@@ -0,0 +1,101 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
8
+
9
+ struct SSMScanParamsBase {
10
+ using index_t = uint32_t;
11
+
12
+ int batch, seqlen, n_chunks;
13
+ index_t a_batch_stride;
14
+ index_t b_batch_stride;
15
+ index_t out_batch_stride;
16
+
17
+ // Common data pointers.
18
+ void *__restrict__ a_ptr;
19
+ void *__restrict__ b_ptr;
20
+ void *__restrict__ out_ptr;
21
+ void *__restrict__ x_ptr;
22
+ };
23
+
24
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
25
+
26
+ struct SSMParamsBase {
27
+ using index_t = uint32_t;
28
+
29
+ int batch, dim, seqlen, dstate, n_groups, n_chunks;
30
+ int dim_ngroups_ratio;
31
+ bool is_variable_B;
32
+ bool is_variable_C;
33
+
34
+ bool delta_softplus;
35
+
36
+ index_t A_d_stride;
37
+ index_t A_dstate_stride;
38
+ index_t B_batch_stride;
39
+ index_t B_d_stride;
40
+ index_t B_dstate_stride;
41
+ index_t B_group_stride;
42
+ index_t C_batch_stride;
43
+ index_t C_d_stride;
44
+ index_t C_dstate_stride;
45
+ index_t C_group_stride;
46
+ index_t u_batch_stride;
47
+ index_t u_d_stride;
48
+ index_t delta_batch_stride;
49
+ index_t delta_d_stride;
50
+ index_t z_batch_stride;
51
+ index_t z_d_stride;
52
+ index_t out_batch_stride;
53
+ index_t out_d_stride;
54
+ index_t out_z_batch_stride;
55
+ index_t out_z_d_stride;
56
+
57
+ // Common data pointers.
58
+ void *__restrict__ A_ptr;
59
+ void *__restrict__ B_ptr;
60
+ void *__restrict__ C_ptr;
61
+ void *__restrict__ D_ptr;
62
+ void *__restrict__ u_ptr;
63
+ void *__restrict__ delta_ptr;
64
+ void *__restrict__ delta_bias_ptr;
65
+ void *__restrict__ out_ptr;
66
+ void *__restrict__ x_ptr;
67
+ void *__restrict__ z_ptr;
68
+ void *__restrict__ out_z_ptr;
69
+ };
70
+
71
+ struct SSMParamsBwd: public SSMParamsBase {
72
+ index_t dout_batch_stride;
73
+ index_t dout_d_stride;
74
+ index_t dA_d_stride;
75
+ index_t dA_dstate_stride;
76
+ index_t dB_batch_stride;
77
+ index_t dB_group_stride;
78
+ index_t dB_d_stride;
79
+ index_t dB_dstate_stride;
80
+ index_t dC_batch_stride;
81
+ index_t dC_group_stride;
82
+ index_t dC_d_stride;
83
+ index_t dC_dstate_stride;
84
+ index_t du_batch_stride;
85
+ index_t du_d_stride;
86
+ index_t dz_batch_stride;
87
+ index_t dz_d_stride;
88
+ index_t ddelta_batch_stride;
89
+ index_t ddelta_d_stride;
90
+
91
+ // Common data pointers.
92
+ void *__restrict__ dout_ptr;
93
+ void *__restrict__ dA_ptr;
94
+ void *__restrict__ dB_ptr;
95
+ void *__restrict__ dC_ptr;
96
+ void *__restrict__ dD_ptr;
97
+ void *__restrict__ du_ptr;
98
+ void *__restrict__ dz_ptr;
99
+ void *__restrict__ ddelta_ptr;
100
+ void *__restrict__ ddelta_bias_ptr;
101
+ };
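The n_chunks and x_ptr fields correspond to the per-chunk state checkpoints allocated by selective_scan_fwd above: the sequence is split into chunks of 2048 steps and x is allocated as a (batch, dim, n_chunks, 2 * dstate) tensor in the weight dtype. A small sketch of that bookkeeping (the helper name is made up for illustration):

```python
# Mirrors the chunking and state-buffer allocation in selective_scan_fwd above;
# the helper name is illustrative only.
import math

def chunk_state_shape(batch, dim, seqlen, dstate, chunk=2048):
    n_chunks = math.ceil(seqlen / chunk)      # (seqlen + 2048 - 1) / 2048 in the C++ code
    return n_chunks, (batch, dim, n_chunks, 2 * dstate)

n_chunks, x_shape = chunk_state_shape(batch=2, dim=64, seqlen=4097, dstate=16)
print(n_chunks, x_shape)   # 3 (2, 64, 3, 32): 4097 steps need a third, partial chunk
```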
mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
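Each of the six selective_scan_bwd_*.cu files explicitly instantiates one <input_t, weight_t> combination of selective_scan_bwd_cuda (fp32, fp16, and bf16 inputs crossed with real and complex float weights), so the translation units can build in parallel and cover the cases dispatched in selective_scan.cpp. A quick, purely illustrative enumeration of that grid:

```python
# Purely illustrative: enumerates the <input_t, weight_t> pairs instantiated by
# the six selective_scan_bwd_*.cu translation units above.
input_types  = ["float", "at::Half", "at::BFloat16"]
weight_types = ["float", "complex_t"]

for itype in input_types:
    for wtype in weight_types:
        print(f"template void selective_scan_bwd_cuda<{itype}, {wtype}>(SSMParamsBwd &, cudaStream_t);")
```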