Spaces:
Running
Running
Upload 28 files
Browse files- app.py +62 -0
- causal-conv1d/.github/workflows/publish.yaml +209 -0
- causal-conv1d/.gitignore +6 -0
- causal-conv1d/AUTHORS +1 -0
- causal-conv1d/LICENSE +29 -0
- causal-conv1d/README.md +43 -0
- causal-conv1d/build/lib/causal_conv1d/__init__.py +3 -0
- causal-conv1d/build/lib/causal_conv1d/causal_conv1d_interface.py +239 -0
- causal-conv1d/build/lib/causal_conv1d/causal_conv1d_varlen.py +86 -0
- causal-conv1d/causal_conv1d.egg-info/PKG-INFO +62 -0
- causal-conv1d/causal_conv1d.egg-info/SOURCES.txt +12 -0
- causal-conv1d/causal_conv1d.egg-info/dependency_links.txt +1 -0
- causal-conv1d/causal_conv1d.egg-info/requires.txt +3 -0
- causal-conv1d/causal_conv1d.egg-info/top_level.txt +1 -0
- causal-conv1d/causal_conv1d/__init__.py +3 -0
- causal-conv1d/causal_conv1d/causal_conv1d_interface.py +239 -0
- causal-conv1d/causal_conv1d/causal_conv1d_varlen.py +86 -0
- causal-conv1d/csrc/causal_conv1d.cpp +464 -0
- causal-conv1d/csrc/causal_conv1d.h +77 -0
- causal-conv1d/csrc/causal_conv1d_bwd.cu +627 -0
- causal-conv1d/csrc/causal_conv1d_common.h +98 -0
- causal-conv1d/csrc/causal_conv1d_fwd.cu +399 -0
- causal-conv1d/csrc/causal_conv1d_update.cu +130 -0
- causal-conv1d/csrc/static_switch.h +25 -0
- causal-conv1d/dist/causal_conv1d-1.4.0-py3.9.egg +0 -0
- causal-conv1d/rocm_patch/rocm6_0.patch +56 -0
- causal-conv1d/setup.py +296 -0
- causal-conv1d/tests/test_causal_conv1d.py +301 -0
app.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
import time
|
4 |
+
|
5 |
+
def _collapse_newlines(text):
    """Strip *text*, unify CRLF to LF and squeeze every run of newlines to one."""
    unified = text.strip().replace('\r\n', '\n')
    # Dropping the empty segments removes *all* blank lines. A single
    # ``.replace('\n\n', '\n')`` pass would turn "\n\n\n" into "\n\n",
    # which is exactly the turn delimiter used by the templates below.
    return '\n'.join(segment for segment in unified.split('\n') if segment)


def generate_prompt(instruction, input=""):
    """Build the text prompt fed to the model.

    Args:
        instruction: The user's request. Surrounding whitespace is stripped
            and blank lines are removed so the text cannot contain the
            ``"\\n\\n"`` separator used between prompt sections.
        input: Optional extra context, normalised the same way. The name
            shadows the builtin but is kept for keyword-argument callers.

    Returns:
        The ``Instruction / Input / Response`` template when ``input`` is
        non-empty after normalisation, otherwise the two-turn
        ``User / Lover`` chat template ending with an open ``Lover:`` turn.
    """
    instruction = _collapse_newlines(instruction)
    input = _collapse_newlines(input)
    if input:
        return f"Instruction: {instruction}\n\nInput: {input}\n\nResponse:"
    return (
        "User: hi\n\n"
        "Lover: Hi. I am your assistant and I will provide expert full response in full details. "
        "Please feel free to ask any question and I will always answer it.\n\n"
        f"User: {instruction}\n\n"
        "Lover:"
    )
|
22 |
+
|
23 |
+
model_path = "models/rwkv-6-world-1b6/"  # Path to your local model directory

# Load the causal LM from the local directory and run it in float32.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    use_flash_attention_2=False  # Explicitly disable Flash Attention
).to(torch.float32)

# NOTE(review): eos_token is "</ s>" (with a space) while bos_token is "</s>";
# confirm this is intentional and not a typo.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    bos_token="</s>",
    eos_token="</ s>",
    unk_token="<unk>",
    pad_token="<pad>",
    trust_remote_code=True,
    padding_side='left',
    clean_up_tokenization_spaces=False  # Or set to True if you prefer
)

print(tokenizer.special_tokens_map)

text = "Hi"

prompt = generate_prompt(text)

input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Collect every end-of-sequence id known to the tokenizer or the model's
# generation config so the loop below can actually stop once the model is done.
stop_token_ids = set()
for _candidate in (
    tokenizer.eos_token_id,
    getattr(getattr(model, "generation_config", None), "eos_token_id", None),
):
    if _candidate is None:
        continue
    if isinstance(_candidate, (list, tuple)):
        stop_token_ids.update(_candidate)
    else:
        stop_token_ids.add(_candidate)

# Generate text token by token, streaming each piece to stdout, and stop early
# when an end-of-sequence token is sampled.
generated_text = ""
for _ in range(333):  # Generate up to 333 tokens
    output = model.generate(input_ids, max_new_tokens=1, do_sample=True, temperature=1.0, top_p=0.3, top_k=0)

    # Without this check the loop would keep sampling past the end of the reply.
    if int(output[0][-1]) in stop_token_ids:
        break

    new_word = tokenizer.decode(output[0][-1:], skip_special_tokens=True)

    print(new_word, end="", flush=True)  # Print word-by-word
    generated_text += new_word

    input_ids = output  # Update input_ids for next iteration

print()  # Add a newline at the end
|
causal-conv1d/.github/workflows/publish.yaml
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This workflow will:
|
2 |
+
# - Create a new Github release
|
3 |
+
# - Build wheels for supported architectures
|
4 |
+
# - Deploy the wheels to the Github release
|
5 |
+
# - Release the static code to PyPi
|
6 |
+
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
7 |
+
|
8 |
+
# Display name of the workflow.
name: Build wheels and deploy

# Trigger: a Git ref was created, restricted here to tags matching `v*`.
# NOTE(review): branch/tag filters are documented for `push`; confirm that the
# `tags` filter is honored for the `create` event — TODO verify.
on:
  create:
    tags:
      - v*
|
14 |
+
|
15 |
+
jobs:

  # Job 1: create the GitHub release that the wheel jobs later attach assets to.
  setup_release:
    name: Create Release
    runs-on: ubuntu-latest
    steps:
      - name: Get the tag version
        id: extract_branch
        # Strip the "refs/tags/" prefix from GITHUB_REF and expose the result as the
        # `branch` step output. Written to $GITHUB_OUTPUT because the `::set-output`
        # workflow command is deprecated.
        run: echo "branch=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
        shell: bash

      - name: Create Release
        id: create_release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          tag_name: ${{ steps.extract_branch.outputs.branch }}
          release_name: ${{ steps.extract_branch.outputs.branch }}

  # Job 2: build one wheel per (python, torch, cuda, cxx11_abi) combination and
  # upload it to the release created above.
  build_wheels:
    name: Build Wheel
    needs: setup_release
    runs-on: ${{ matrix.os }}

    strategy:
      fail-fast: false
      matrix:
        # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
        # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
        os: [ubuntu-20.04]
        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
        torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240505']
        cuda-version: ['11.8.0', '12.2.2']
        # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
        # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
        # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
        # when building without C++11 ABI and using it on nvcr images.
        cxx11_abi: ['FALSE', 'TRUE']
        exclude:
          # Pytorch < 2.2 does not support Python 3.12
          - torch-version: '2.0.1'
            python-version: '3.12'
          - torch-version: '2.1.2'
            python-version: '3.12'
          # Pytorch <= 2.0 only supports CUDA <= 11.8
          - torch-version: '2.0.1'
            cuda-version: '12.2.2'

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      # Derive compact version strings (e.g. "118" and "2.1") used by later steps.
      - name: Set CUDA and PyTorch versions
        run: |
          echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
          echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV

      - name: Free up disk space
        if: ${{ runner.os == 'Linux' }}
        # https://github.com/easimon/maximize-build-space/blob/master/action.yml
        # https://github.com/easimon/maximize-build-space/tree/test-report
        run: |
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL

      - name: Set up swap space
        if: runner.os == 'Linux'
        # NOTE(review): the action reference below was mangled by e-mail obfuscation in
        # this copy of the file; restore the real `owner/action@version` before use.
        uses: pierotofy/[email protected]
        with:
          swap-size-gb: 10

      - name: Install CUDA ${{ matrix.cuda-version }}
        if: ${{ matrix.cuda-version != 'cpu' }}
        # NOTE(review): the action reference below was mangled by e-mail obfuscation in
        # this copy of the file; restore the real `owner/action@version` before use.
        uses: Jimver/[email protected]
        id: cuda-toolkit
        with:
          cuda: ${{ matrix.cuda-version }}
          linux-local-args: '["--toolkit"]'
          # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
          # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
          method: 'network'
          # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
          # not just nvcc
          # sub-packages: '["nvcc"]'

      - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
        run: |
          pip install --upgrade pip
          # If we don't install before installing Pytorch, we get error for torch 2.0.1
          # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
          pip install lit
          # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
          pip install setuptools==68.0.0
          # We want to figure out the CUDA version to download pytorch
          # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
          # This code is ugly, maybe there's a better way to do this.
          # The system CUDA version is clamped into the [minv, maxv] range listed per torch version.
          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
            minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
            maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
            print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
          )
          if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
            pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
          else
            pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          fi
          nvcc --version
          python --version
          python -c "import torch; print('PyTorch:', torch.__version__)"
          python -c "import torch; print('CUDA:', torch.version.cuda)"
          python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
        shell: bash

      - name: Build wheel
        run: |
          # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
          # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
          # However this still fails so I'm using a newer version of setuptools
          pip install setuptools==68.0.0
          pip install ninja packaging wheel
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
          # Limit MAX_JOBS otherwise the github runner goes OOM
          MAX_JOBS=2 CAUSAL_CONV1D_FORCE_BUILD="TRUE" CAUSAL_CONV1D_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
          # Insert "+cu<cuda>torch<torch>cxx11abi<flag>" before the second "-" of the wheel
          # file name, then export the new name for the upload step.
          tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
          ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV

      - name: Log Built Wheels
        run: |
          ls dist

      - name: Get the tag version
        id: extract_branch
        # Same tag extraction as in setup_release; uses $GITHUB_OUTPUT instead of the
        # deprecated `::set-output` workflow command.
        run: echo "branch=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"

      - name: Get Release with tag
        id: get_current_release
        uses: joutvhu/get-release@v1
        with:
          tag_name: ${{ steps.extract_branch.outputs.branch }}
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Upload Release Asset
        id: upload_release_asset
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ steps.get_current_release.outputs.upload_url }}
          asset_path: ./dist/${{env.wheel_name}}
          asset_name: ${{env.wheel_name}}
          asset_content_type: application/*

  # Job 3: build the source distribution (CUDA build skipped) and upload it with twine.
  publish_package:
    name: Publish package
    needs: [build_wheels]

    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3

      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Install dependencies
        run: |
          pip install ninja packaging setuptools wheel twine
          # We don't want to download anything CUDA-related here
          pip install torch --index-url https://download.pytorch.org/whl/cpu

      - name: Build core package
        env:
          CAUSAL_CONV1D_SKIP_CUDA_BUILD: "TRUE"
        run: |
          python setup.py sdist --dist-dir=dist

      - name: Deploy
        env:
          TWINE_USERNAME: "__token__"
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          python -m twine upload dist/*
|
causal-conv1d/.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*__pycache__/
|
2 |
+
*.egg-info/
|
3 |
+
build/
|
4 |
+
**.so
|
5 |
+
*.hip
|
6 |
+
*_hip.*
|
causal-conv1d/AUTHORS
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Tri Dao, [email protected]
|
causal-conv1d/LICENSE
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without
|
7 |
+
modification, are permitted provided that the following conditions are met:
|
8 |
+
|
9 |
+
* Redistributions of source code must retain the above copyright notice, this
|
10 |
+
list of conditions and the following disclaimer.
|
11 |
+
|
12 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
13 |
+
this list of conditions and the following disclaimer in the documentation
|
14 |
+
and/or other materials provided with the distribution.
|
15 |
+
|
16 |
+
* Neither the name of the copyright holder nor the names of its
|
17 |
+
contributors may be used to endorse or promote products derived from
|
18 |
+
this software without specific prior written permission.
|
19 |
+
|
20 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
21 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
22 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
23 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
24 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
25 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
26 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
27 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
28 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
29 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
causal-conv1d/README.md
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Causal depthwise conv1d in CUDA with a PyTorch interface
|
2 |
+
|
3 |
+
Features:
|
4 |
+
- Support fp32, fp16, bf16.
|
5 |
+
- Kernel size 2, 3, 4.
|
6 |
+
|
7 |
+
## How to use
|
8 |
+
|
9 |
+
```
|
10 |
+
from causal_conv1d import causal_conv1d_fn
|
11 |
+
```
|
12 |
+
|
13 |
+
```
|
14 |
+
def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
15 |
+
"""
|
16 |
+
x: (batch, dim, seqlen)
|
17 |
+
weight: (dim, width)
|
18 |
+
bias: (dim,)
|
19 |
+
activation: either None or "silu" or "swish"
|
20 |
+
|
21 |
+
out: (batch, dim, seqlen)
|
22 |
+
"""
|
23 |
+
```
|
24 |
+
|
25 |
+
Equivalent to:
|
26 |
+
```
|
27 |
+
import torch.nn.functional as F
|
28 |
+
|
29 |
+
F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
|
30 |
+
```
|
31 |
+
|
32 |
+
## Additional Prerequisites for AMD cards
|
33 |
+
|
34 |
+
### Patching ROCm
|
35 |
+
|
36 |
+
If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
|
37 |
+
|
38 |
+
1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
|
39 |
+
|
40 |
+
2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
|
41 |
+
```bash
|
42 |
+
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
|
43 |
+
```
|
causal-conv1d/build/lib/causal_conv1d/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "1.4.0"
|
2 |
+
|
3 |
+
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
|
causal-conv1d/build/lib/causal_conv1d/causal_conv1d_interface.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
import causal_conv1d_cuda
|
8 |
+
|
9 |
+
|
10 |
+
class CausalConv1dFn(torch.autograd.Function):
    """Autograd wrapper around the ``causal_conv1d_cuda`` forward/backward kernels."""

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        """Validate/normalise the inputs and call ``causal_conv1d_fwd``.

        Returns ``out``, or ``(out, final_states_out)`` when
        ``return_final_states`` is true.

        Raises:
            NotImplementedError: if ``activation`` is not None, "silu" or "swish".
            AssertionError: if ``seq_idx`` is combined with ``initial_states`` or
                ``return_final_states``, or if final states are requested with an
                unsupported memory layout.
        """
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # Only copy when neither dim 1 nor dim 2 has unit stride.
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        if seq_idx is not None:
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        if initial_states is not None and (
            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
        ):
            initial_states = initial_states.contiguous()
        if return_final_states:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is not None:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
            else:
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                # Allocate as (batch, width - 1, dim) and transpose so the returned
                # (batch, dim, width - 1) view has unit stride along dim 1.
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
        else:
            # A caller-supplied buffer is ignored unless final states were requested.
            final_states_out = None
        # The kernel takes a boolean flag: True for "silu"/"swish", False for None.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        return out if not return_final_states else (out, final_states_out)

    @staticmethod
    def backward(ctx, dout, *args):
        """Call ``causal_conv1d_bwd`` and return one gradient slot per forward input.

        ``args[0]`` is the gradient w.r.t. the final states when the forward pass
        returned them. The slots for ``seq_idx``, ``return_final_states``,
        ``final_states_out`` and ``activation`` are always None.
        """
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        dfinal_states = args[0] if ctx.return_final_states else None
        if dout.stride(2) != 1 and dout.stride(1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx will be allocated in the C++ code.
        dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        return (
            dx,
            dweight,
            dbias if bias is not None else None,
            None,
            dinitial_states if initial_states is not None else None,
            None,
            None,
            None,
        )
|
98 |
+
|
99 |
+
|
100 |
+
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Differentiable causal depthwise conv1d (thin wrapper over ``CausalConv1dFn``).

    Shapes:
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        seq_idx: (batch, seqlen)
        initial_states: (batch, dim, width - 1)
        final_states_out: (batch, dim, width - 1), to be written to
        activation: either None or "silu" or "swish"

    Returns:
        out: (batch, dim, seqlen); ``(out, final_states_out)`` is what the
        underlying function returns when ``return_final_states`` is true.
    """
    # autograd.Function.apply is called positionally, in forward()'s parameter order.
    forward_args = (
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        return_final_states,
        final_states_out,
        activation,
    )
    return CausalConv1dFn.apply(*forward_args)
|
131 |
+
|
132 |
+
|
133 |
+
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Pure-PyTorch reference for the causal depthwise conv1d.

    Shapes:
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        initial_states: (batch, dim, width - 1)
        final_states_out: (batch, dim, width - 1)

    Returns:
        out: (batch, dim, seqlen), in the dtype of the incoming ``x``; or
        ``(out, final_states_out)`` when ``return_final_states`` is true.

    Raises:
        NotImplementedError: if ``activation`` is not None, "silu" or "swish".
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    input_dtype = x.dtype
    signal = x.to(weight.dtype)  # compute in the weight's dtype
    num_steps = signal.shape[-1]
    channels, kernel_width = weight.shape

    if initial_states is not None:
        # The supplied history plays the role of the left padding.
        signal = torch.cat([initial_states, signal], dim=-1)
        padding = 0
    else:
        padding = kernel_width - 1
    conv = F.conv1d(signal, weight.unsqueeze(1), bias, padding=padding, groups=channels)
    conv = conv[..., :num_steps]

    if activation is not None:
        conv = F.silu(conv)
    result = conv.to(dtype=input_dtype)

    if not return_final_states:
        return result

    # A negative left pad crops, a positive one zero-fills, so the result always
    # has width - 1 columns: (batch, dim, width - 1).
    tail = F.pad(signal, (kernel_width - 1 - signal.shape[-1], 0)).to(input_dtype)
    if final_states_out is None:
        final_states_out = tail
    else:
        final_states_out.copy_(tail)
    return result, final_states_out
|
173 |
+
|
174 |
+
|
175 |
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """Run the ``causal_conv1d_cuda`` update kernel on ``x`` and ``conv_state``.

    Shapes:
        x: (batch, dim) or (batch, dim, seqlen)
        conv_state: (batch, dim, state_len), where state_len >= width - 1
        weight: (dim, width)
        bias: (dim,)
        cache_seqlens: (batch,), dtype int32.
            If not None, the conv_state is treated as a circular buffer.
            The conv_state will be updated by copying x to the conv_state starting at the index
            @cache_seqlens % state_len.

    Returns:
        out: (batch, dim) or (batch, dim, seqlen), matching the rank of ``x``.

    Raises:
        NotImplementedError: if ``activation`` is not None, "silu" or "swish".
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    # The kernel takes a boolean flag rather than the activation name.
    apply_silu = activation in ("silu", "swish")

    # A 2-D input gets a trailing seqlen axis of size 1 for the kernel call.
    was_2d = x.dim() == 2
    kernel_input = x.unsqueeze(-1) if was_2d else x
    result = causal_conv1d_cuda.causal_conv1d_update(
        kernel_input, conv_state, weight, bias, apply_silu, cache_seqlens
    )
    return result.squeeze(-1) if was_2d else result
|
200 |
+
|
201 |
+
|
202 |
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """Pure-PyTorch reference for ``causal_conv1d_update``; mutates ``conv_state``.

    Shapes:
        x: (batch, dim) or (batch, dim, seqlen)
        conv_state: (batch, dim, state_len), where state_len >= width - 1
        weight: (dim, width)
        bias: (dim,)
        cache_seqlens: (batch,), dtype int32.
            If not None, the conv_state is treated as a circular buffer.
            The conv_state will be updated by copying x to the conv_state starting at the index
            @cache_seqlens % state_len before performing the convolution.

    Returns:
        out: (batch, dim) or (batch, dim, seqlen), in the dtype of ``x``.

    Raises:
        NotImplementedError: if ``activation`` is not None, "silu" or "swish".
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    input_dtype = x.dtype
    squeeze_output = x.dim() == 2
    if squeeze_output:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)

    if cache_seqlens is not None:

        def ring_positions(offsets):
            # Per-sample positions (cache_seqlens + offsets) wrapped into the
            # buffer and broadcast over channels: (batch, dim, len(offsets)).
            absolute = offsets.unsqueeze(0) + cache_seqlens.unsqueeze(1)
            return torch.remainder(absolute, state_len).unsqueeze(1).expand(-1, dim, -1)

        # Read the width - 1 entries that precede the write position ...
        history_idx = ring_positions(
            torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device)
        )
        window = torch.cat([conv_state.gather(2, history_idx), x], dim=-1).to(weight.dtype)
        # ... then write the new tokens into the buffer.
        write_idx = ring_positions(torch.arange(seqlen, dtype=torch.long, device=x.device))
        conv_state.scatter_(2, write_idx, x)
    else:
        # Shift-register mode: append x and keep the last state_len columns.
        window = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(window[:, :, -state_len:])

    out = F.conv1d(window, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
    if squeeze_output:
        out = out.squeeze(-1)
    if activation is not None:
        out = F.silu(out)
    return out.to(dtype=input_dtype)
|
causal-conv1d/build/lib/causal_conv1d/causal_conv1d_varlen.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import Tensor
|
3 |
+
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
|
7 |
+
|
8 |
+
# Triton kernel: for each sequence in a packed batch, copy up to `state_len` of its
# last rows of X into STATES, right-aligned; positions not covered are stored as 0.
@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # Program axes: 0 -> block of columns (dim), 1 -> block of rows, 2 -> batch entry.
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    # This sequence ends at CU_SEQLENS[batch_idx + 1] (exclusive).
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # First row to copy: never before the sequence start, and at most state_len rows back.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Row blocks are counted backwards from end_idx: block 0 holds the last BLOCK_M rows.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Rows before start_idx and columns >= dim are not read; they load as 0.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # Destination rows use the same backward block layout, measured from state_len.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    # Negative destination rows (BLOCK_M does not divide state_len) are skipped.
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
|
33 |
+
|
34 |
+
|
35 |
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """Gather the trailing ``state_len`` tokens of every packed sequence (Triton).

    Forward pass only, does not support backward pass.

    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.

    Return:
        states: (batch, dim, state_len)
    """
    dim = x.shape[1]
    if x.dim() != 2:
        # Keep the unpacking error of `_, dim = x.shape` for non-2-D input.
        _, dim = x.shape
    num_seqs = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()

    # Stored as (batch, state_len, dim) and returned as a (batch, dim, state_len) view.
    storage = torch.empty(num_seqs, state_len, dim, dtype=x.dtype, device=x.device)
    states = storage.transpose(1, 2)

    block_rows = min(triton.next_power_of_2(state_len), 16)
    block_cols = min(triton.next_power_of_2(dim), 256)
    launch_grid = (
        triton.cdiv(dim, block_cols),
        triton.cdiv(state_len, block_rows),
        num_seqs,
    )

    x_row_stride, x_col_stride = x.stride(0), x.stride(1)
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[launch_grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x_row_stride, x_col_stride,
            # The kernel expects (batch, seqlen, dim) strides; `states` is the transposed view.
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=block_rows, BLOCK_N=block_cols
        )
    return states
|
65 |
+
|
66 |
+
|
67 |
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Reference (pure PyTorch) implementation of causal_conv1d_varlen_states.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    # Zero-initialised so positions not covered by a (short) sequence stay zero.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    # Read all boundaries in one host transfer instead of once per indexing op.
    boundaries = cu_seqlens.tolist()
    for i, (seq_start, end_idx) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        start_idx = max(seq_start, end_idx - state_len)
        num_tokens = end_idx - start_idx
        if num_tokens <= 0:
            # Empty sequence: keep the zeros. A `-0:` slice would select the whole
            # state and make the assignment fail with a shape mismatch.
            continue
        # Right-align the copied tokens inside the state window.
        states[i, :, state_len - num_tokens:] = x[start_idx:end_idx].T
    return states
|
causal-conv1d/causal_conv1d.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: causal-conv1d
|
3 |
+
Version: 1.4.0
|
4 |
+
Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
|
5 |
+
Home-page: https://github.com/Dao-AILab/causal-conv1d
|
6 |
+
Author: Tri Dao
|
7 |
+
Author-email: [email protected]
|
8 |
+
License: UNKNOWN
|
9 |
+
Platform: UNKNOWN
|
10 |
+
Classifier: Programming Language :: Python :: 3
|
11 |
+
Classifier: License :: OSI Approved :: BSD License
|
12 |
+
Classifier: Operating System :: Unix
|
13 |
+
Requires-Python: >=3.8
|
14 |
+
Description-Content-Type: text/markdown
|
15 |
+
License-File: LICENSE
|
16 |
+
License-File: AUTHORS
|
17 |
+
|
18 |
+
# Causal depthwise conv1d in CUDA with a PyTorch interface
|
19 |
+
|
20 |
+
Features:
|
21 |
+
- Support fp32, fp16, bf16.
|
22 |
+
- Kernel size 2, 3, 4.
|
23 |
+
|
24 |
+
## How to use
|
25 |
+
|
26 |
+
```
|
27 |
+
from causal_conv1d import causal_conv1d_fn
|
28 |
+
```
|
29 |
+
|
30 |
+
```
|
31 |
+
def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
32 |
+
"""
|
33 |
+
x: (batch, dim, seqlen)
|
34 |
+
weight: (dim, width)
|
35 |
+
bias: (dim,)
|
36 |
+
activation: either None or "silu" or "swish"
|
37 |
+
|
38 |
+
out: (batch, dim, seqlen)
|
39 |
+
"""
|
40 |
+
```
|
41 |
+
|
42 |
+
Equivalent to:
|
43 |
+
```
|
44 |
+
import torch.nn.functional as F
|
45 |
+
|
46 |
+
F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
|
47 |
+
```
|
48 |
+
|
49 |
+
## Additional Prerequisites for AMD cards
|
50 |
+
|
51 |
+
### Patching ROCm
|
52 |
+
|
53 |
+
If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
|
54 |
+
|
55 |
+
1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
|
56 |
+
|
57 |
+
2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
|
58 |
+
```bash
|
59 |
+
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
|
60 |
+
```
|
61 |
+
|
62 |
+
|
causal-conv1d/causal_conv1d.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AUTHORS
|
2 |
+
LICENSE
|
3 |
+
README.md
|
4 |
+
setup.py
|
5 |
+
causal_conv1d/__init__.py
|
6 |
+
causal_conv1d/causal_conv1d_interface.py
|
7 |
+
causal_conv1d/causal_conv1d_varlen.py
|
8 |
+
causal_conv1d.egg-info/PKG-INFO
|
9 |
+
causal_conv1d.egg-info/SOURCES.txt
|
10 |
+
causal_conv1d.egg-info/dependency_links.txt
|
11 |
+
causal_conv1d.egg-info/requires.txt
|
12 |
+
causal_conv1d.egg-info/top_level.txt
|
causal-conv1d/causal_conv1d.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
causal-conv1d/causal_conv1d.egg-info/requires.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
packaging
|
3 |
+
ninja
|
causal-conv1d/causal_conv1d.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
causal_conv1d
|
causal-conv1d/causal_conv1d/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "1.4.0"
|
2 |
+
|
3 |
+
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
|
causal-conv1d/causal_conv1d/causal_conv1d_interface.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
import causal_conv1d_cuda
|
8 |
+
|
9 |
+
|
10 |
+
class CausalConv1dFn(torch.autograd.Function):
    """Autograd function that dispatches to the `causal_conv1d_cuda` forward/backward kernels."""

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        """
        Run the forward kernel.

        Returns `out`, or `(out, final_states_out)` when `return_final_states` is set.
        """
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # Only copy when neither the channel nor the sequence axis has unit stride.
        if 1 not in (x.stride(2), x.stride(1)):
            x = x.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if seq_idx is not None:
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
            seq_idx = seq_idx.contiguous()
        if initial_states is not None and 1 not in (
            initial_states.stride(2),
            initial_states.stride(1),
        ):
            initial_states = initial_states.contiguous()
        if not return_final_states:
            final_states_out = None
        else:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is None:
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                # Allocate (batch, width - 1, dim) and transpose so dim has stride 1.
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
            else:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
        # The kernels take a boolean flag rather than the activation name.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        if return_final_states:
            return out, final_states_out
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Run the backward kernel; returns one gradient slot per `forward` input."""
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        # The gradient w.r.t. final_states_out is only passed when forward returned it.
        dfinal_states = args[0] if ctx.return_final_states else None
        if 1 not in (dout.stride(2), dout.stride(1)):
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx will be allocated in the C++ code.
        dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        if bias is None:
            dbias = None
        if initial_states is None:
            dinitial_states = None
        # Slots: x, weight, bias, seq_idx, initial_states, return_final_states,
        # final_states_out, activation.
        return dx, dweight, dbias, None, dinitial_states, None, None, None
|
98 |
+
|
99 |
+
|
100 |
+
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    Functional entry point for CausalConv1dFn.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    # Forwarded positionally, in the parameter order of CausalConv1dFn.forward.
    forward_args = (
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        return_final_states,
        final_states_out,
        activation,
    )
    return CausalConv1dFn.apply(*forward_args)
|
131 |
+
|
132 |
+
|
133 |
+
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    Reference causal depthwise conv1d built on F.conv1d.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    seqlen = x.shape[-1]
    dim, width = weight.shape
    # Compute in the weight dtype; results are cast back to dtype_in at the end.
    x = x.to(weight.dtype)
    depthwise_kernel = weight.unsqueeze(1)
    if initial_states is not None:
        # The provided history takes the place of the zero padding.
        x = torch.cat([initial_states, x], dim=-1)
        padding = 0
    else:
        padding = width - 1
    out = F.conv1d(x, depthwise_kernel, bias, padding=padding, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        # A negative left pad crops x to its last (width - 1) positions; a positive
        # one zero-fills on the left when x is shorter than that.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in
        )  # (batch, dim, width - 1)
        if final_states_out is None:
            final_states_out = final_states
        else:
            final_states_out.copy_(final_states)
    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)
    if return_final_states:
        return out, final_states_out
    return out
|
173 |
+
|
174 |
+
|
175 |
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """
    Conv update step, delegated to the `causal_conv1d_cuda` extension.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # The extension takes a boolean flag rather than the activation name.
    use_silu = activation in ["silu", "swish"]
    # A 2-D x gets a trailing length-1 axis for the call, removed again from the output.
    is_2d = x.dim() == 2
    x_3d = x.unsqueeze(-1) if is_2d else x
    out = causal_conv1d_cuda.causal_conv1d_update(
        x_3d, conv_state, weight, bias, use_silu, cache_seqlens
    )
    return out.squeeze(-1) if is_2d else out
|
200 |
+
|
201 |
+
|
202 |
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """
    Reference (pure PyTorch) implementation of causal_conv1d_update.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    is_2d = x.dim() == 2
    if is_2d:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is not None:
        write_pos = cache_seqlens.unsqueeze(1)
        # Circular-buffer slots holding the (width - 1) entries just before the write position.
        history_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + write_pos
        history_idx = torch.remainder(history_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        x_new = torch.cat([conv_state.gather(2, history_idx), x], dim=-1).to(weight.dtype)
        # Slots that receive the new tokens, wrapping around state_len.
        dest_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + write_pos
        dest_idx = torch.remainder(dest_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, dest_idx, x)
    else:
        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        # Non-circular mode: the state becomes the last state_len positions.
        conv_state.copy_(x_new[:, :, -state_len:])
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
    if is_2d:
        out = out.squeeze(-1)
    if activation is not None:
        out = F.silu(out)
    return out.to(dtype=dtype_in)
|
causal-conv1d/causal_conv1d/causal_conv1d_varlen.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import Tensor
|
3 |
+
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
|
7 |
+
|
8 |
+
@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # Each program copies one (BLOCK_M tokens) x (BLOCK_N channels) tile of one
    # sequence's tail from X into STATES.
    # Grid axes: 0 -> channel block, 1 -> token block (counted back from the end
    # of the sequence), 2 -> sequence index within the batch.
    pid_n = tl.program_id(0)
    pid_m = tl.program_id(1)
    batch_idx = tl.program_id(2)
    states_base = STATES + batch_idx * stride_states_batch
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # Never read before this sequence's first token or further back than state_len tokens.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    tile_rows = tl.arange(0, BLOCK_M)
    rows = end_idx - (pid_m + 1) * BLOCK_M + tile_rows
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    col_in_range = cols[None, :] < dim
    # Rows before start_idx are masked and read as 0, which zero-fills the state.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & col_in_range,
                other=0)
    # Same tile position, expressed relative to the end of the state window.
    rows_states = state_len - (pid_m + 1) * BLOCK_M + tile_rows
    tl.store(states_base + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & col_in_range)
|
33 |
+
|
34 |
+
|
35 |
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Copy the last `state_len` tokens of every packed sequence into a conv state tensor.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocated as (batch, state_len, dim) and then transposed, so the returned
    # (batch, dim, state_len) view has stride 1 along dim.
    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    if states.numel() == 0:
        # Empty batch, state_len == 0 or dim == 0: there is nothing to fill, and
        # launching would need a zero block size / zero-sized grid. Skip the kernel.
        return states
    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
    BLOCK_N = min(triton.next_power_of_2(dim), 256)
    # Grid axes: (blocks over dim, blocks over state_len, batch).
    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            # The kernel expects (batch, seqlen, dim) stride order; `states` is the
            # transposed view, hence stride(2) before stride(1).
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )
    return states
|
65 |
+
|
66 |
+
|
67 |
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Reference (pure PyTorch) implementation of causal_conv1d_varlen_states.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    # Zero-initialised so positions not covered by a (short) sequence stay zero.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    # Read all boundaries in one host transfer instead of once per indexing op.
    boundaries = cu_seqlens.tolist()
    for i, (seq_start, end_idx) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        start_idx = max(seq_start, end_idx - state_len)
        num_tokens = end_idx - start_idx
        if num_tokens <= 0:
            # Empty sequence: keep the zeros. A `-0:` slice would select the whole
            # state and make the assignment fail with a shape mismatch.
            continue
        # Right-align the copied tokens inside the state window.
        states[i, :, state_len - num_tokens:] = x[start_idx:end_idx].T
    return states
|
causal-conv1d/csrc/causal_conv1d.cpp
ADDED
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#include <ATen/cuda/CUDAContext.h>
|
6 |
+
#include <c10/cuda/CUDAGuard.h>
|
7 |
+
#include <torch/extension.h>
|
8 |
+
#include <vector>
|
9 |
+
|
10 |
+
#include "causal_conv1d.h"
|
11 |
+
|
12 |
+
// Fails with "<tensor> must have shape (<dims>)" when x.sizes() differs from the given dims.
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

// Invokes the callable passed as the trailing argument with `input_t` aliased to the
// C++ type matching ITYPE (float, at::Half or at::BFloat16); errors on any other type.
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                    \
    if (ITYPE == at::ScalarType::Float) {                                           \
        using input_t = float;                                                      \
        __VA_ARGS__();                                                              \
    } else if (ITYPE == at::ScalarType::Half) {                                     \
        using input_t = at::Half;                                                   \
        __VA_ARGS__();                                                              \
    } else if (ITYPE == at::ScalarType::BFloat16) {                                 \
        using input_t = at::BFloat16;                                               \
        __VA_ARGS__();                                                              \
    } else {                                                                        \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }

// Same as above for the weight dtype; the callable sees `weight_t`.
#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...)                     \
    if (WTYPE == at::ScalarType::Float) {                                            \
        using weight_t = float;                                                      \
        __VA_ARGS__();                                                               \
    } else if (WTYPE == at::ScalarType::Half) {                                      \
        using weight_t = at::Half;                                                   \
        __VA_ARGS__();                                                               \
    } else if (WTYPE == at::ScalarType::BFloat16) {                                  \
        using weight_t = at::BFloat16;                                               \
        __VA_ARGS__();                                                               \
    } else {                                                                         \
        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
    }
|
41 |
+
|
42 |
+
template<typename input_t, typename weight_t>
|
43 |
+
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
44 |
+
template <typename input_t, typename weight_t>
|
45 |
+
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
46 |
+
|
47 |
+
template<typename input_t, typename weight_t>
|
48 |
+
void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream);
|
49 |
+
template<typename input_t, typename weight_t>
|
50 |
+
void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream);
|
51 |
+
|
52 |
+
template<typename input_t, typename weight_t>
|
53 |
+
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
54 |
+
|
55 |
+
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
56 |
+
// sizes
|
57 |
+
const size_t batch,
|
58 |
+
const size_t dim,
|
59 |
+
const size_t seqlen,
|
60 |
+
const size_t width,
|
61 |
+
// device pointers
|
62 |
+
const at::Tensor x,
|
63 |
+
const at::Tensor weight,
|
64 |
+
const at::Tensor out,
|
65 |
+
void* bias_ptr,
|
66 |
+
bool silu_activation) {
|
67 |
+
|
68 |
+
// Reset the parameters
|
69 |
+
memset(¶ms, 0, sizeof(params));
|
70 |
+
|
71 |
+
params.batch = batch;
|
72 |
+
params.dim = dim;
|
73 |
+
params.seqlen = seqlen;
|
74 |
+
params.width = width;
|
75 |
+
|
76 |
+
params.silu_activation = silu_activation;
|
77 |
+
|
78 |
+
// Set the pointers and strides.
|
79 |
+
params.x_ptr = x.data_ptr();
|
80 |
+
params.weight_ptr = weight.data_ptr();
|
81 |
+
params.bias_ptr = bias_ptr;
|
82 |
+
params.out_ptr = out.data_ptr();
|
83 |
+
// All stride are in elements, not bytes.
|
84 |
+
params.x_batch_stride = x.stride(0);
|
85 |
+
params.x_c_stride = x.stride(1);
|
86 |
+
params.x_l_stride = x.stride(-1);
|
87 |
+
params.weight_c_stride = weight.stride(0);
|
88 |
+
params.weight_width_stride = weight.stride(1);
|
89 |
+
params.out_batch_stride = out.stride(0);
|
90 |
+
params.out_c_stride = out.stride(1);
|
91 |
+
params.out_l_stride = out.stride(-1);
|
92 |
+
}
|
93 |
+
|
94 |
+
|
95 |
+
void set_conv_params_bwd(ConvParamsBwd ¶ms,
|
96 |
+
// sizes
|
97 |
+
const size_t batch,
|
98 |
+
const size_t dim,
|
99 |
+
const size_t seqlen,
|
100 |
+
const size_t width,
|
101 |
+
// device pointers
|
102 |
+
const at::Tensor x,
|
103 |
+
const at::Tensor weight,
|
104 |
+
void* bias_ptr,
|
105 |
+
const at::Tensor dout,
|
106 |
+
const at::Tensor dx,
|
107 |
+
const at::Tensor dweight,
|
108 |
+
void* dbias_ptr,
|
109 |
+
bool silu_activation) {
|
110 |
+
// Pass in "dout" instead of "out", we're not gonna use "out" at all.
|
111 |
+
set_conv_params_fwd(params, batch, dim, seqlen, width,
|
112 |
+
x, weight, dout, bias_ptr, silu_activation);
|
113 |
+
|
114 |
+
// Set the pointers and strides.
|
115 |
+
params.dout_ptr = dout.data_ptr();
|
116 |
+
params.dx_ptr = dx.data_ptr();
|
117 |
+
params.dweight_ptr = dweight.data_ptr();
|
118 |
+
params.dbias_ptr = dbias_ptr;
|
119 |
+
// All stride are in elements, not bytes.
|
120 |
+
params.dout_batch_stride = dout.stride(0);
|
121 |
+
params.dout_c_stride = dout.stride(1);
|
122 |
+
params.dout_l_stride = dout.stride(2);
|
123 |
+
params.dweight_c_stride = dweight.stride(0);
|
124 |
+
params.dweight_width_stride = dweight.stride(1);
|
125 |
+
params.dx_batch_stride = dx.stride(0);
|
126 |
+
params.dx_c_stride = dx.stride(1);
|
127 |
+
params.dx_l_stride = dx.stride(2);
|
128 |
+
}
|
129 |
+
|
130 |
+
// Validates the inputs, fills a ConvParamsBase and launches the forward kernel that
// matches x's layout (channel-last or not). Returns a new tensor allocated like x.
at::Tensor
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
                  const c10::optional<at::Tensor> &bias_,
                  const c10::optional<at::Tensor> &seq_idx_,
                  const c10::optional<at::Tensor> &initial_states_,
                  c10::optional<at::Tensor> &final_states_out_,
                  bool silu_activation) {
    const auto input_type = x.scalar_type();
    const auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int width = weight.size(-1);

    CHECK_SHAPE(x, batch_size, dim, seqlen);
    CHECK_SHAPE(weight, dim, width);

    // Either the sequence axis or the channel axis must have unit stride.
    TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
    const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;

    if (is_channel_last) {
        TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
        TORCH_CHECK(x.stride(2) % 8 == 0 && x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
    }
    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    // Local names below are kept as-is: CHECK_SHAPE stringifies them into its message.
    if (bias_.has_value()) {
        const auto &bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    if (seq_idx_.has_value()) {
        TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
        const auto &seq_idx = seq_idx_.value();
        TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
        TORCH_CHECK(seq_idx.is_cuda());
        TORCH_CHECK(seq_idx.is_contiguous());
        CHECK_SHAPE(seq_idx, batch_size, seqlen);
    }

    at::Tensor out = torch::empty_like(x);

    ConvParamsBase params;
    void *bias_ptr = bias_.has_value() ? bias_.value().data_ptr() : nullptr;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_ptr, silu_activation);

    params.seq_idx_ptr = seq_idx_.has_value() ? seq_idx_.value().data_ptr() : nullptr;

    if (!initial_states_.has_value()) {
        params.initial_states_ptr = nullptr;
    } else {
        TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
        const auto &initial_states = initial_states_.value();
        TORCH_CHECK(initial_states.scalar_type() == input_type);
        TORCH_CHECK(initial_states.is_cuda());
        CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
        TORCH_CHECK(initial_states.stride(1) == 1);
        params.initial_states_ptr = initial_states.data_ptr();
        params.initial_states_batch_stride = initial_states.stride(0);
        params.initial_states_c_stride = initial_states.stride(1);
        params.initial_states_l_stride = initial_states.stride(2);
    }

    if (!final_states_out_.has_value()) {
        params.final_states_ptr = nullptr;
    } else {
        TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
        const auto &final_states = final_states_out_.value();
        TORCH_CHECK(final_states.scalar_type() == input_type);
        TORCH_CHECK(final_states.is_cuda());
        CHECK_SHAPE(final_states, batch_size, dim, width - 1);
        TORCH_CHECK(final_states.stride(1) == 1);
        params.final_states_ptr = final_states.data_ptr();
        params.final_states_batch_stride = final_states.stride(0);
        params.final_states_c_stride = final_states.stride(1);
        params.final_states_l_stride = final_states.stride(2);
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
            if (is_channel_last) {
                causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
            } else {
                causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
            }
        });
    });
    return out;
}
|
238 |
+
|
239 |
+
std::vector<at::Tensor>
|
240 |
+
causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight,
|
241 |
+
const c10::optional<at::Tensor> &bias_,
|
242 |
+
at::Tensor &dout,
|
243 |
+
const c10::optional<at::Tensor> &seq_idx_,
|
244 |
+
const c10::optional<at::Tensor> &initial_states_,
|
245 |
+
const c10::optional<at::Tensor> &dfinal_states_,
|
246 |
+
c10::optional<at::Tensor> &dx_,
|
247 |
+
bool return_dinitial_states,
|
248 |
+
bool silu_activation) {
|
249 |
+
auto input_type = x.scalar_type();
|
250 |
+
auto weight_type = weight.scalar_type();
|
251 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
252 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
253 |
+
|
254 |
+
TORCH_CHECK(x.is_cuda());
|
255 |
+
TORCH_CHECK(weight.is_cuda());
|
256 |
+
TORCH_CHECK(dout.is_cuda());
|
257 |
+
|
258 |
+
const auto sizes = x.sizes();
|
259 |
+
const int batch_size = sizes[0];
|
260 |
+
const int dim = sizes[1];
|
261 |
+
const int seqlen = sizes[2];
|
262 |
+
const int width = weight.size(-1);
|
263 |
+
|
264 |
+
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
265 |
+
|
266 |
+
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
267 |
+
CHECK_SHAPE(weight, dim, width);
|
268 |
+
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
269 |
+
|
270 |
+
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
271 |
+
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
272 |
+
if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
|
273 |
+
if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
|
274 |
+
|
275 |
+
if (is_channel_last) {
|
276 |
+
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
277 |
+
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
|
278 |
+
TORCH_CHECK(dout.stride(2) % 8 == 0 and dout.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (dout.stride(0) and dout.stride(2)) to be multiples of 8");
|
279 |
+
}
|
280 |
+
|
281 |
+
if (bias_.has_value()) {
|
282 |
+
auto bias = bias_.value();
|
283 |
+
TORCH_CHECK(bias.scalar_type() == weight_type);
|
284 |
+
TORCH_CHECK(bias.is_cuda());
|
285 |
+
TORCH_CHECK(bias.stride(-1) == 1);
|
286 |
+
CHECK_SHAPE(bias, dim);
|
287 |
+
}
|
288 |
+
|
289 |
+
if (seq_idx_.has_value()) {
|
290 |
+
TORCH_CHECK(is_channel_last, "seq_idx only supported for channel last layout");
|
291 |
+
auto seq_idx = seq_idx_.value();
|
292 |
+
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
|
293 |
+
TORCH_CHECK(seq_idx.is_cuda());
|
294 |
+
TORCH_CHECK(seq_idx.is_contiguous());
|
295 |
+
CHECK_SHAPE(seq_idx, batch_size, seqlen);
|
296 |
+
}
|
297 |
+
|
298 |
+
at::Tensor dx;
|
299 |
+
if (dx_.has_value()) {
|
300 |
+
dx = dx_.value();
|
301 |
+
TORCH_CHECK(dx.scalar_type() == input_type);
|
302 |
+
TORCH_CHECK(dx.is_cuda());
|
303 |
+
CHECK_SHAPE(dx, batch_size, dim, seqlen);
|
304 |
+
if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
|
305 |
+
if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
|
306 |
+
} else {
|
307 |
+
dx = torch::empty_like(x);
|
308 |
+
}
|
309 |
+
|
310 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
311 |
+
// Cast to char to avoid compiler warning about narrowing
|
312 |
+
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
313 |
+
|
314 |
+
at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat));
|
315 |
+
at::Tensor dbias;
|
316 |
+
if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); }
|
317 |
+
|
318 |
+
ConvParamsBwd params;
|
319 |
+
set_conv_params_bwd(params, batch_size, dim, seqlen, width,
|
320 |
+
x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
321 |
+
dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr,
|
322 |
+
silu_activation);
|
323 |
+
|
324 |
+
if (seq_idx_.has_value()) {
|
325 |
+
params.seq_idx_ptr = seq_idx_.value().data_ptr();
|
326 |
+
} else {
|
327 |
+
params.seq_idx_ptr = nullptr;
|
328 |
+
}
|
329 |
+
|
330 |
+
if (initial_states_.has_value()) {
|
331 |
+
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
|
332 |
+
auto initial_states = initial_states_.value();
|
333 |
+
TORCH_CHECK(initial_states.scalar_type() == input_type);
|
334 |
+
TORCH_CHECK(initial_states.is_cuda());
|
335 |
+
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
|
336 |
+
TORCH_CHECK(initial_states.stride(1) == 1);
|
337 |
+
params.initial_states_ptr = initial_states.data_ptr();
|
338 |
+
params.initial_states_batch_stride = initial_states.stride(0);
|
339 |
+
params.initial_states_c_stride = initial_states.stride(1);
|
340 |
+
params.initial_states_l_stride = initial_states.stride(2);
|
341 |
+
} else {
|
342 |
+
params.initial_states_ptr = nullptr;
|
343 |
+
}
|
344 |
+
|
345 |
+
if (dfinal_states_.has_value()) {
|
346 |
+
TORCH_CHECK(is_channel_last, "dfinal_states is only supported for channel last layout");
|
347 |
+
auto dfinal_states = dfinal_states_.value();
|
348 |
+
TORCH_CHECK(dfinal_states.scalar_type() == input_type);
|
349 |
+
TORCH_CHECK(dfinal_states.is_cuda());
|
350 |
+
CHECK_SHAPE(dfinal_states, batch_size, dim, width - 1);
|
351 |
+
params.dfinal_states_ptr = dfinal_states.data_ptr();
|
352 |
+
params.dfinal_states_batch_stride = dfinal_states.stride(0);
|
353 |
+
params.dfinal_states_c_stride = dfinal_states.stride(1);
|
354 |
+
params.dfinal_states_l_stride = dfinal_states.stride(2);
|
355 |
+
} else {
|
356 |
+
params.dfinal_states_ptr = nullptr;
|
357 |
+
}
|
358 |
+
|
359 |
+
at::Tensor dinitial_states;
|
360 |
+
if (return_dinitial_states) {
|
361 |
+
dinitial_states = torch::empty({batch_size, width - 1, dim}, x.options()).transpose(1, 2);
|
362 |
+
TORCH_CHECK(dinitial_states.stride(1) == 1);
|
363 |
+
params.dinitial_states_ptr = dinitial_states.data_ptr();
|
364 |
+
params.dinitial_states_batch_stride = dinitial_states.stride(0);
|
365 |
+
params.dinitial_states_c_stride = dinitial_states.stride(1);
|
366 |
+
params.dinitial_states_l_stride = dinitial_states.stride(2);
|
367 |
+
} else {
|
368 |
+
params.dinitial_states_ptr = nullptr;
|
369 |
+
}
|
370 |
+
|
371 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
372 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
|
373 |
+
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
|
374 |
+
if (!is_channel_last) {
|
375 |
+
causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
|
376 |
+
} else {
|
377 |
+
causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
|
378 |
+
}
|
379 |
+
});
|
380 |
+
});
|
381 |
+
return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias, dinitial_states};
|
382 |
+
}
|
383 |
+
|
384 |
+
// Host-side entry point for the single-step ("update") causal conv1d.
//
// Validates the inputs, fills a ConvParamsBase that additionally describes the
// rolling conv_state buffer, and launches causal_conv1d_update_cuda.
//
// Arguments:
//   x               input of shape (batch, dim, seqlen).
//   conv_state      state buffer of shape (batch, dim, state_len) with
//                   state_len >= width - 1 and the same dtype as x.
//   weight          filter of shape (dim, width), width in [2, 4].
//   bias_           optional bias of shape (dim), same dtype as weight.
//   silu_activation forwarded to the kernel through the params struct.
//   cache_seqlens_  optional int32 tensor of shape (batch).
//
// Returns a freshly allocated tensor with the same shape/dtype as x.
// NOTE(review): conv_state is passed to the kernel through a raw pointer and is
// presumably updated in place; confirm against causal_conv1d_update.cu.
at::Tensor
causal_conv1d_update(const at::Tensor &x,
                     const at::Tensor &conv_state,
                     const at::Tensor &weight,
                     const c10::optional<at::Tensor> &bias_,
                     bool silu_activation,
                     const c10::optional<at::Tensor> &cache_seqlens_
                     ) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
    TORCH_CHECK(conv_state.scalar_type() == input_type);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(conv_state.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    // sizes()[i] is an unchecked access, so validate the ranks before indexing.
    TORCH_CHECK(x.dim() == 3, "x must be a 3D tensor of shape (batch, dim, seqlen)");
    TORCH_CHECK(conv_state.dim() == 3, "conv_state must be a 3D tensor of shape (batch, dim, state_len)");

    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int width = weight.size(-1);
    const int conv_state_len = conv_state.size(2);
    TORCH_CHECK(conv_state_len >= width - 1);

    CHECK_SHAPE(x, batch_size, dim, seqlen);
    CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    at::Tensor out = torch::empty_like(x);

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
                        silu_activation);
    params.conv_state_ptr = conv_state.data_ptr();
    params.conv_state_len = conv_state_len;
    // All strides are in elements, not bytes.
    params.conv_state_batch_stride = conv_state.stride(0);
    params.conv_state_c_stride = conv_state.stride(1);
    params.conv_state_l_stride = conv_state.stride(2);

    if (cache_seqlens_.has_value()) {
        auto cache_seqlens = cache_seqlens_.value();
        TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
        TORCH_CHECK(cache_seqlens.is_cuda());
        TORCH_CHECK(cache_seqlens.stride(-1) == 1);
        CHECK_SHAPE(cache_seqlens, batch_size);
        params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
    } else {
        params.cache_seqlens = nullptr;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
            causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
        });
    });
    return out;
}
|
459 |
+
|
460 |
+
// Python bindings: expose the three host entry points of this extension under
// the module name chosen by the build (TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Full-sequence forward pass.
    m.def("causal_conv1d_fwd",
          &causal_conv1d_fwd,
          "Causal conv1d forward");

    // Full-sequence backward pass.
    m.def("causal_conv1d_bwd",
          &causal_conv1d_bwd,
          "Causal conv1d backward");

    // Step update against a conv_state buffer.
    m.def("causal_conv1d_update",
          &causal_conv1d_update,
          "Causal conv1d update");
}
|
causal-conv1d/csrc/causal_conv1d.h
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
8 |
+
|
9 |
+
// Plain parameter bundle handed by value to the forward / update kernels.
// Member order is kept exactly as before so the struct layout is unchanged.
// The host code documents all strides as being in elements, not bytes.
struct ConvParamsBase {
    using index_t = uint32_t;

    // Problem sizes.
    int batch, dim, seqlen, width;
    // Whether SiLU is applied to the convolution output.
    bool silu_activation;

    // Strides of x (batch, channel, length).
    index_t x_batch_stride, x_c_stride, x_l_stride;
    // Strides of weight (channel, filter tap).
    index_t weight_c_stride, weight_width_stride;
    // Strides of out (batch, channel, length).
    index_t out_batch_stride, out_c_stride, out_l_stride;

    // Length and strides of the conv_state buffer used by the update path.
    int conv_state_len;
    index_t conv_state_batch_stride, conv_state_c_stride, conv_state_l_stride;

    // Common data pointers.
    void *__restrict__ x_ptr;
    void *__restrict__ weight_ptr;
    void *__restrict__ bias_ptr;        // nullptr when there is no bias
    void *__restrict__ out_ptr;

    void *__restrict__ conv_state_ptr;
    int32_t *__restrict__ cache_seqlens;  // nullptr when not provided

    void *__restrict__ seq_idx_ptr;       // nullptr when not provided

    // No __restrict__ since initial_states could be the same as final_states.
    void * initial_states_ptr;
    index_t initial_states_batch_stride, initial_states_l_stride, initial_states_c_stride;

    void * final_states_ptr;
    index_t final_states_batch_stride, final_states_l_stride, final_states_c_stride;
};
|
51 |
+
|
52 |
+
struct ConvParamsBwd: public ConvParamsBase {
|
53 |
+
index_t dx_batch_stride;
|
54 |
+
index_t dx_c_stride;
|
55 |
+
index_t dx_l_stride;
|
56 |
+
index_t dweight_c_stride;
|
57 |
+
index_t dweight_width_stride;
|
58 |
+
index_t dout_batch_stride;
|
59 |
+
index_t dout_c_stride;
|
60 |
+
index_t dout_l_stride;
|
61 |
+
|
62 |
+
// Common data pointers.
|
63 |
+
void *__restrict__ dx_ptr;
|
64 |
+
void *__restrict__ dweight_ptr;
|
65 |
+
void *__restrict__ dbias_ptr;
|
66 |
+
void *__restrict__ dout_ptr;
|
67 |
+
|
68 |
+
void * dinitial_states_ptr;
|
69 |
+
index_t dinitial_states_batch_stride;
|
70 |
+
index_t dinitial_states_l_stride;
|
71 |
+
index_t dinitial_states_c_stride;
|
72 |
+
|
73 |
+
void * dfinal_states_ptr;
|
74 |
+
index_t dfinal_states_batch_stride;
|
75 |
+
index_t dfinal_states_l_stride;
|
76 |
+
index_t dfinal_states_c_stride;
|
77 |
+
};
|
causal-conv1d/csrc/causal_conv1d_bwd.cu
ADDED
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#include <c10/util/BFloat16.h>
|
6 |
+
#include <c10/util/Half.h>
|
7 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
8 |
+
|
9 |
+
#ifndef USE_ROCM
|
10 |
+
#include <cub/block/block_load.cuh>
|
11 |
+
#include <cub/block/block_store.cuh>
|
12 |
+
#include <cub/block/block_reduce.cuh>
|
13 |
+
#else
|
14 |
+
#include <hipcub/hipcub.hpp>
|
15 |
+
namespace cub = hipcub;
|
16 |
+
#endif
|
17 |
+
|
18 |
+
#include "causal_conv1d.h"
|
19 |
+
#include "causal_conv1d_common.h"
|
20 |
+
#include "static_switch.h"
|
21 |
+
|
22 |
+
// Compile-time configuration of the channel-first backward kernel: thread
// count, filter width, activation flag, vectorisation mode, the CUB
// load/store/reduce helpers and the dynamic shared-memory budget.
template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_bwd_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;

    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr bool kSiluAct = kSiluAct_;
    static constexpr bool kIsVecLoad = kIsVecLoad_;

    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    // Each thread handles one 16-byte vector: 4 elements for 4-byte inputs,
    // 8 elements for 2-byte inputs.
    static constexpr int kNElts = 16 / kNBytes;
    static_assert(kWidth <= kNElts);

    // With 16-bit inputs a thread holds 8 float values, while one exchange
    // round moves a single 16-byte vector (4 floats), so two rounds are needed.
    static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);

    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
    using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;

    // Scratch space for the element-wise load/store path; no scratch is
    // reserved on the vectorised path.
    static constexpr int kSmemIOSize = kIsVecLoad
        ? 0
        : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
    // Exchange buffer: one vector per thread without SiLU; with SiLU,
    // kNExchangeRounds vectors for dout plus one more for x.
    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
    // The exchange buffer and the reduction scratch share the same region, which
    // sits after the I/O scratch (kSmemIOSize is already 0 when kIsVecLoad).
    static constexpr int kSmemSize =
        custom_max({kSmemExchangeSize, int(sizeof(typename BlockReduceFloatT::TempStorage))})
        + kSmemIOSize;
};
|
50 |
+
|
51 |
+
// Backward kernel for the channel-first layout.
//
// Each thread block handles one (batch, channel) pair (blockIdx.x, blockIdx.y)
// and walks the sequence from the last chunk of kNThreads * kNElts elements to
// the first. Per chunk it loads x and dout, optionally recomputes the forward
// output to apply the SiLU derivative, writes dx, and accumulates per-thread
// partial sums for dweight and dbias. After the loop those partial sums are
// block-reduced and added atomically into the fp32 dweight / dbias buffers.
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr bool kSiluAct = Ktraits::kSiluAct;
    static constexpr int kNElts = Ktraits::kNElts;
    constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
    static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    // Shared memory.
    // The load/store temp storages all alias the start of smem_; the exchange
    // buffers and the reduction scratch alias the region starting at kSmemIOSize.
    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
    // The x exchange buffer follows the kNExchangeRounds dout exchange slots.
    vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
    auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    // Offset every pointer to the row owned by this (batch, channel) block.
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + dim_id * params.x_c_stride;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
    input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
        + dim_id * params.dout_c_stride;
    input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
        + dim_id * params.dx_c_stride;
    float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);

    // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
    if (tidx == 0) {
        if constexpr (!kSiluAct) {
            input_t zeros[kNElts] = {0};
            smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
        } else {
            float zeros[kNElts] = {0};
            #pragma unroll
            for (int r = 0; r < kNExchangeRounds; ++r) {
                smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
            }
        }
    }

    // Filter taps for this channel, converted to float.
    float weight_vals[kWidth];
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }

    // Per-thread partial sums, reduced across the block after the chunk loop.
    float dweight_vals[kWidth] = {0};
    float dbias_val = 0;

    constexpr int kChunkSize = kNThreads * kNElts;
    const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
    // Start at the last chunk and walk backwards through the sequence.
    x += (n_chunks - 1) * kChunkSize;
    dout += (n_chunks - 1) * kChunkSize;
    dx += (n_chunks - 1) * kChunkSize;
    for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
        // Double-width staging buffers: x is loaded into the upper half (the
        // lower half receives the preceding elements), dout into the lower half
        // (the upper half receives the following elements).
        input_t x_vals_load[2 * kNElts] = {0};
        input_t dout_vals_load[2 * kNElts] = {0};
        if constexpr(kIsVecLoad) {
            typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
            typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            // Barriers separate successive uses of the shared load scratch.
            __syncthreads();
            typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
            __syncthreads();
            typename Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
        }
        float dout_vals[2 * kNElts], x_vals[2 * kNElts];
        if constexpr (!kSiluAct) {
            __syncthreads();
            // Thread 0 doesn't write yet, so that thread kNThreads - 1 can read
            // the first elements of the next chunk.
            if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
            __syncthreads();
            // Each thread fetches its right-hand neighbour's dout vector; the last
            // thread reads slot 0, which still holds the following chunk's head.
            reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
            __syncthreads();
            // Now thread 0 can write the first elements of the current chunk.
            if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
            #pragma unroll
            for (int i = 0; i < 2 * kNElts; ++i) {
                dout_vals[i] = float(dout_vals_load[i]);
                x_vals[i] = float(x_vals_load[i]);
            }
        } else {
            // Thread 0 reads the kNElts inputs preceding this chunk directly from
            // global memory (there are none before chunk 0, so zeros are kept).
            if (tidx == 0 && chunk > 0) {
                if constexpr(kIsVecLoad) {
                    reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
                } else {
                    #pragma unroll
                    for (int i = 0; i < kNElts; ++i) {
                        if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
                    }
                }
            }
            __syncthreads();
            // Every other thread obtains its preceding inputs from its left neighbour.
            smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
            __syncthreads();
            if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
            #pragma unroll
            for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
            // Recompute the output
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) {
                float out_val = bias_val;
                #pragma unroll
                for (int w = 0; w < kWidth; ++w) {
                    out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
                }
                // Scale dout by d/dz [z * sigmoid(z)] = sigmoid(z) * (1 + z * (1 - sigmoid(z))).
                float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
                dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
                               * (1.0f + out_val * (1.0f - out_sigmoid_val));
            }
            // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
            // if input_t is 16 bits (since then we'd have 8 values of float)
            __syncthreads();
            // Thread 0 doesn't write yet, so that thread kNThreads - 1 can read
            // the first elements of the next chunk.
            if (tidx > 0) {
                #pragma unroll
                for (int r = 0; r < kNExchangeRounds; ++r) {
                    smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
                }
            }
            __syncthreads();
            #pragma unroll
            for (int r = 0; r < kNExchangeRounds; ++r) {
                reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
                    = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
            }
            __syncthreads();
            // Now thread 0 can write the first elements of the current chunk.
            if (tidx == 0) {
                #pragma unroll
                for (int r = 0; r < kNExchangeRounds; ++r) {
                    smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
                }
            }
        }
        // Step the input pointers back to the previous chunk for the next iteration.
        dout -= kChunkSize;
        x -= kChunkSize;

        #pragma unroll
        for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }

        // dx[i] = sum_w weight[w] * dout[i + (kWidth - 1 - w)].
        float dx_vals[kNElts] = {0};
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            #pragma unroll
            for (int w = 0; w < kWidth; ++w) {
                dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
            }
        }

        // Convert back to input_t before storing.
        input_t dx_vals_store[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
        if constexpr(kIsVecLoad) {
            typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            typename Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
        }
        dx -= kChunkSize;

        // dweight[w] += sum_i x[i] * dout[i + (kWidth - 1 - w)].
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) {
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) {
                dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
            }
        }
    }

    // Block-reduce the per-thread partial sums. The barrier before each reduction
    // guards reuse of the shared scratch; thread 0 then adds the block's total
    // into the global fp32 buffer with atomicAdd.
    #pragma unroll
    for (int w = 0; w < kWidth; ++w) {
        __syncthreads();
        dweight_vals[w] = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
        if (tidx == 0) {
            atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
        }
    }
    if (params.bias_ptr != nullptr) {
        __syncthreads();
        dbias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
        if (tidx == 0) {
            atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
        }
    }
}
|
246 |
+
|
247 |
+
// Instantiates and launches causal_conv1d_bwd_kernel for a fixed thread count
// and filter width. The vectorised-I/O and SiLU variants are selected at run
// time from params and turned into compile-time flags via BOOL_SWITCH.
// One thread block is launched per (batch, channel) pair.
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    // Vectorised loads/stores are only used when seqlen is a multiple of kNElts.
    BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
        BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
            using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
            constexpr int kSmemSize = Ktraits::kSmemSize;
            dim3 grid(params.batch, params.dim);
            auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;

            // Dynamic shared memory of 48 KB or more must be opted into explicitly.
            if (kSmemSize >= 48 * 1024) {
                #ifndef USE_ROCM
                C10_CUDA_CHECK(cudaFuncSetAttribute(
                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                #else
                // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
                C10_CUDA_CHECK(cudaFuncSetAttribute(
                    (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
                #endif
            }

            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    });
}
|
275 |
+
|
276 |
+
// Dispatches the channel-first backward pass on the run-time filter width,
// always using 128 threads per block. Widths outside [2, 4] launch nothing;
// the host entry point rejects them before this function is reached.
template<typename input_t, typename weight_t>
void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}
|
286 |
+
|
287 |
+
// Compile-time configuration of the channel-last backward kernel.
//
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
// threads). Each load is 16 x 32|64 elements in the L x C dimensions.
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_channellast_bwd_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;

    static constexpr bool kSiluAct = kSiluAct_;
    static constexpr bool kIsVecLoad = kIsVecLoad_;
    static constexpr int kWidth = kWidth_;
    static constexpr int kChunkSizeL = kChunkSizeL_;

    // Threads per block; must be a whole number of 32-thread warps.
    static constexpr int kNThreads = kNThreads_;
    static_assert(kNThreads % 32 == 0);
    static constexpr int kNWarps = kNThreads / 32;

    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    // Elements per 16-byte vector: 4 for 4-byte inputs, 8 for 2-byte inputs.
    static constexpr int kNElts = 16 / kNBytes;
    // Elements in one 128-byte row along the channel dimension.
    static constexpr int kNEltsPerRow = 128 / kNBytes;
    static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;  // Always 8 for now
    static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
    static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow;  // Always 4 for now
    static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
    // Columns (sequence positions) covered by one block-wide load, and the
    // number of such loads needed to cover a whole chunk.
    static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
    static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
    static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);

    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
};
|
320 |
+
|
321 |
+
template<typename Ktraits, bool kHasSeqIdx, bool kHasDfinalStates>
|
322 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
323 |
+
void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
|
324 |
+
constexpr int kWidth = Ktraits::kWidth;
|
325 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
326 |
+
constexpr bool kSiluAct = Ktraits::kSiluAct;
|
327 |
+
constexpr int kNElts = Ktraits::kNElts;
|
328 |
+
constexpr int kNWarp = Ktraits::kNWarps;
|
329 |
+
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
330 |
+
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
331 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
332 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
333 |
+
using input_t = typename Ktraits::input_t;
|
334 |
+
using vec_t = typename Ktraits::vec_t;
|
335 |
+
using weight_t = typename Ktraits::weight_t;
|
336 |
+
|
337 |
+
// Shared memory.
|
338 |
+
__shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
|
339 |
+
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
|
340 |
+
|
341 |
+
const int batch_id = blockIdx.x;
|
342 |
+
const int chunk_l_id = blockIdx.y;
|
343 |
+
const int chunk_c_id = blockIdx.z;
|
344 |
+
const int tid = threadIdx.x;
|
345 |
+
const int l_idx = tid / kNThreadsPerC;
|
346 |
+
const int c_idx = tid % kNThreadsPerC;
|
347 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
348 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
349 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
350 |
+
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
351 |
+
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
352 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
353 |
+
input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
|
354 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
355 |
+
float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
|
356 |
+
+ chunk_c_id * kChunkSizeC * params.dweight_c_stride;
|
357 |
+
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
358 |
+
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
359 |
+
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
360 |
+
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
361 |
+
input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
362 |
+
: reinterpret_cast<input_t *>(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
363 |
+
input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr
|
364 |
+
: reinterpret_cast<input_t *>(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC;
|
365 |
+
|
366 |
+
#pragma unroll
|
367 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
368 |
+
input_t dout_vals_load[kNElts] = {0};
|
369 |
+
input_t x_vals_load[kNElts] = {0};
|
370 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
371 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
372 |
+
reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
|
373 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
374 |
+
}
|
375 |
+
reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
|
376 |
+
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
377 |
+
}
|
378 |
+
// Load the elements from the previous chunk or next chunk that are needed for convolution.
|
379 |
+
if (l_idx < kWidth - 1) {
|
380 |
+
input_t dout_vals_load[kNElts] = {0};
|
381 |
+
input_t x_vals_load[kNElts] = {0};
|
382 |
+
if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
|
383 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
384 |
+
reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
|
385 |
+
}
|
386 |
+
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
387 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
388 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
389 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
390 |
+
} else if (initial_states != nullptr
|
391 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
392 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
393 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
394 |
+
}
|
395 |
+
reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
|
396 |
+
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
397 |
+
}
|
398 |
+
// Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
|
399 |
+
if constexpr (kSiluAct) {
|
400 |
+
if (l_idx < kWidth - 1) {
|
401 |
+
input_t x_vals_load[kNElts] = {0};
|
402 |
+
if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
|
403 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
404 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
|
405 |
+
}
|
406 |
+
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
407 |
+
}
|
408 |
+
}
|
409 |
+
|
410 |
+
__syncthreads();
|
411 |
+
|
412 |
+
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
413 |
+
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
414 |
+
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
415 |
+
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
416 |
+
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
417 |
+
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
418 |
+
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
419 |
+
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
420 |
+
static_assert(kNThreadsPerRow <= 32);
|
421 |
+
|
422 |
+
const int row_idx = tid / kNThreadsPerRow;
|
423 |
+
const int col_idx = tid % kNThreadsPerRow;
|
424 |
+
|
425 |
+
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
426 |
+
float weight_vals[kWidth] = {0};
|
427 |
+
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
428 |
+
#pragma unroll
|
429 |
+
for (int w = 0; w < kWidth; ++w) {
|
430 |
+
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
431 |
+
}
|
432 |
+
}
|
433 |
+
float dout_vals[kLPerThread + kWidth - 1];
|
434 |
+
float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
|
435 |
+
#pragma unroll
|
436 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
437 |
+
dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
|
438 |
+
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
439 |
+
}
|
440 |
+
|
441 |
+
int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1];
|
442 |
+
if constexpr (kHasSeqIdx) {
|
443 |
+
#pragma unroll
|
444 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
|
445 |
+
const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1);
|
446 |
+
seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
447 |
+
}
|
448 |
+
}
|
449 |
+
|
450 |
+
if constexpr (kSiluAct) { // Recompute the output
|
451 |
+
#pragma unroll
|
452 |
+
for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
|
453 |
+
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
454 |
+
}
|
455 |
+
#pragma unroll
|
456 |
+
for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
|
457 |
+
float out_val = bias_val;
|
458 |
+
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
459 |
+
#pragma unroll
|
460 |
+
for (int w = 0; w < kWidth; ++w) {
|
461 |
+
if constexpr (!kHasSeqIdx) {
|
462 |
+
out_val += weight_vals[w] * x_vals[i + w];
|
463 |
+
} else {
|
464 |
+
out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
465 |
+
}
|
466 |
+
}
|
467 |
+
float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
|
468 |
+
dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
|
469 |
+
}
|
470 |
+
}
|
471 |
+
|
472 |
+
float dweight_vals[kWidth] = {0};
|
473 |
+
SumOp<float> sum_op;
|
474 |
+
#pragma unroll
|
475 |
+
for (int w = 0; w < kWidth; ++w) {
|
476 |
+
#pragma unroll
|
477 |
+
for (int i = 0; i < kLPerThread; ++i) {
|
478 |
+
if constexpr (!kHasSeqIdx) {
|
479 |
+
dweight_vals[w] += x_vals[i + w] * dout_vals[i];
|
480 |
+
} else {
|
481 |
+
dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f;
|
482 |
+
}
|
483 |
+
}
|
484 |
+
dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
|
485 |
+
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
486 |
+
atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
|
487 |
+
}
|
488 |
+
}
|
489 |
+
|
490 |
+
if (params.bias_ptr != nullptr) {
|
491 |
+
float dbias_val = 0.f;
|
492 |
+
for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
|
493 |
+
dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
|
494 |
+
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
495 |
+
atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
|
496 |
+
}
|
497 |
+
}
|
498 |
+
|
499 |
+
float dx_vals[kLPerThread] = {0};
|
500 |
+
#pragma unroll
|
501 |
+
for (int i = 0; i < kLPerThread; ++i) {
|
502 |
+
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
503 |
+
#pragma unroll
|
504 |
+
for (int w = 0; w < kWidth; ++w) {
|
505 |
+
if constexpr (!kHasSeqIdx) {
|
506 |
+
dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w];
|
507 |
+
} else {
|
508 |
+
dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f;
|
509 |
+
}
|
510 |
+
}
|
511 |
+
// if (dfinal_states != nullptr) {
|
512 |
+
if constexpr (kHasDfinalStates) {
|
513 |
+
if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1
|
514 |
+
&& chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen
|
515 |
+
&& chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
516 |
+
dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
|
517 |
+
}
|
518 |
+
}
|
519 |
+
}
|
520 |
+
|
521 |
+
float dxinit_vals[kWidth - 1] = {0};
|
522 |
+
static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states
|
523 |
+
if (dinitial_states != nullptr && col_idx == 0) {
|
524 |
+
#pragma unroll
|
525 |
+
for (int i = 0; i < kWidth - 1; ++i) {
|
526 |
+
#pragma unroll
|
527 |
+
for (int w = 0; w < kWidth; ++w) {
|
528 |
+
dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f;
|
529 |
+
}
|
530 |
+
// chunk_l_id must be 0 because dinitial_states != nullptr
|
531 |
+
// if (dfinal_states != nullptr) {
|
532 |
+
if constexpr (kHasDfinalStates) {
|
533 |
+
if (i >= params.seqlen) {
|
534 |
+
dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
|
535 |
+
}
|
536 |
+
}
|
537 |
+
}
|
538 |
+
}
|
539 |
+
|
540 |
+
__syncthreads();
|
541 |
+
#pragma unroll
|
542 |
+
for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
|
543 |
+
if (dinitial_states != nullptr && col_idx == 0) {
|
544 |
+
#pragma unroll
|
545 |
+
for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; }
|
546 |
+
}
|
547 |
+
__syncthreads();
|
548 |
+
|
549 |
+
#pragma unroll
|
550 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
551 |
+
input_t dx_vals_store[kNElts];
|
552 |
+
reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx];
|
553 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
554 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
555 |
+
*reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
|
556 |
+
}
|
557 |
+
}
|
558 |
+
if (dinitial_states != nullptr
|
559 |
+
&& l_idx < kWidth - 1
|
560 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
561 |
+
input_t dxinit_vals_store[kNElts];
|
562 |
+
reinterpret_cast<vec_t *>(dxinit_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx];
|
563 |
+
*reinterpret_cast<vec_t *>(dinitial_states) = reinterpret_cast<vec_t *>(dxinit_vals_store)[0];
|
564 |
+
}
|
565 |
+
|
566 |
+
}
|
567 |
+
|
568 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
569 |
+
void causal_conv1d_channellast_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
570 |
+
BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
|
571 |
+
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
572 |
+
BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] {
|
573 |
+
BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] {
|
574 |
+
// kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger
|
575 |
+
static constexpr int kChunk = kChunkSizeL64 ? 64 : 128;
|
576 |
+
using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, kChunk, kSiluAct, true, input_t, weight_t>;
|
577 |
+
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
578 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
579 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
580 |
+
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
581 |
+
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
582 |
+
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
583 |
+
dim3 block(Ktraits::kNThreads);
|
584 |
+
auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits, kHasSeqIdx, kHasDfinalStates>;
|
585 |
+
// if (kSmemSize >= 48 * 1024) {
|
586 |
+
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
587 |
+
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
588 |
+
// }
|
589 |
+
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
590 |
+
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
591 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
592 |
+
});
|
593 |
+
});
|
594 |
+
});
|
595 |
+
});
|
596 |
+
}
|
597 |
+
|
598 |
+
template<typename input_t, typename weight_t>
|
599 |
+
void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
600 |
+
if (params.width == 2) {
|
601 |
+
causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
|
602 |
+
} else if (params.width == 3) {
|
603 |
+
causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
|
604 |
+
} else if (params.width == 4) {
|
605 |
+
causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
|
606 |
+
}
|
607 |
+
}
|
608 |
+
|
609 |
+
template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
610 |
+
template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
611 |
+
template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
612 |
+
template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
613 |
+
template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
614 |
+
template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
615 |
+
template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
616 |
+
template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
617 |
+
template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
618 |
+
|
619 |
+
template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
620 |
+
template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
621 |
+
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
622 |
+
template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
623 |
+
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
624 |
+
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
625 |
+
template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
626 |
+
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
627 |
+
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
causal-conv1d/csrc/causal_conv1d_common.h
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#ifndef USE_ROCM
|
8 |
+
#include <cuda_bf16.h>
|
9 |
+
|
10 |
+
template<typename T>
|
11 |
+
__device__ inline T shuffle_xor(T val, int offset) {
|
12 |
+
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
13 |
+
}
|
14 |
+
|
15 |
+
// Compile-time maximum of a non-empty list of sizes (CUDA build).
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
    size_t largest = *ilist.begin();
    for (size_t candidate : ilist) {
        if (largest < candidate) { largest = candidate; }
    }
    return largest;
}
|
19 |
+
|
20 |
+
// Compile-time minimum of two values (CUDA build). As with std::min, the
// first argument is returned when the two compare equal.
template<typename T>
constexpr T constexpr_min(T a, T b) {
    return b < a ? b : a;
}
|
24 |
+
|
25 |
+
#else
|
26 |
+
#include <hip/hip_bf16.h>
|
27 |
+
|
28 |
+
template<typename T>
|
29 |
+
__device__ inline T shuffle_xor(T val, int offset) {
|
30 |
+
return __shfl_xor(val, offset);
|
31 |
+
}
|
32 |
+
// Compile-time maximum of a non-empty list of sizes (ROCm build).
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
    const size_t *best = ilist.begin();
    for (const size_t *it = ilist.begin(); it != ilist.end(); ++it) {
        if (*best < *it) { best = it; }
    }
    return *best;
}
|
36 |
+
|
37 |
+
// Compile-time minimum of two values (ROCm build); when neither value is
// strictly smaller, the second argument is returned.
template<typename T>
constexpr T constexpr_min(T a, T b) {
    if (a < b) {
        return a;
    }
    return b;
}
|
41 |
+
#endif
|
42 |
+
#include <cuda_fp16.h>
|
43 |
+
|
44 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
45 |
+
|
46 |
+
template<int BYTES> struct BytesToType {};
|
47 |
+
|
48 |
+
template<> struct BytesToType<16> {
|
49 |
+
using Type = uint4;
|
50 |
+
static_assert(sizeof(Type) == 16);
|
51 |
+
};
|
52 |
+
|
53 |
+
template<> struct BytesToType<8> {
|
54 |
+
using Type = uint64_t;
|
55 |
+
static_assert(sizeof(Type) == 8);
|
56 |
+
};
|
57 |
+
|
58 |
+
template<> struct BytesToType<4> {
|
59 |
+
using Type = uint32_t;
|
60 |
+
static_assert(sizeof(Type) == 4);
|
61 |
+
};
|
62 |
+
|
63 |
+
template<> struct BytesToType<2> {
|
64 |
+
using Type = uint16_t;
|
65 |
+
static_assert(sizeof(Type) == 2);
|
66 |
+
};
|
67 |
+
|
68 |
+
template<> struct BytesToType<1> {
|
69 |
+
using Type = uint8_t;
|
70 |
+
static_assert(sizeof(Type) == 1);
|
71 |
+
};
|
72 |
+
|
73 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
74 |
+
|
75 |
+
template<typename T>
|
76 |
+
struct SumOp {
|
77 |
+
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
78 |
+
};
|
79 |
+
|
80 |
+
template<int THREADS>
|
81 |
+
struct Allreduce {
|
82 |
+
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
83 |
+
template<typename T, typename Operator>
|
84 |
+
static __device__ inline T run(T x, Operator &op) {
|
85 |
+
constexpr int OFFSET = THREADS / 2;
|
86 |
+
x = op(x, shuffle_xor(x, OFFSET));
|
87 |
+
return Allreduce<OFFSET>::run(x, op);
|
88 |
+
}
|
89 |
+
};
|
90 |
+
|
91 |
+
template<>
|
92 |
+
struct Allreduce<2> {
|
93 |
+
template<typename T, typename Operator>
|
94 |
+
static __device__ inline T run(T x, Operator &op) {
|
95 |
+
x = op(x, shuffle_xor(x, 1));
|
96 |
+
return x;
|
97 |
+
}
|
98 |
+
};
|
causal-conv1d/csrc/causal_conv1d_fwd.cu
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#include <c10/util/BFloat16.h>
|
6 |
+
#include <c10/util/Half.h>
|
7 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
8 |
+
|
9 |
+
#ifndef USE_ROCM
|
10 |
+
#include <cub/block/block_load.cuh>
|
11 |
+
#include <cub/block/block_store.cuh>
|
12 |
+
#else
|
13 |
+
#include <hipcub/hipcub.hpp>
|
14 |
+
namespace cub = hipcub;
|
15 |
+
#endif
|
16 |
+
|
17 |
+
#include "causal_conv1d.h"
|
18 |
+
#include "causal_conv1d_common.h"
|
19 |
+
#include "static_switch.h"
|
20 |
+
|
21 |
+
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
22 |
+
struct Causal_conv1d_fwd_kernel_traits {
|
23 |
+
using input_t = input_t_;
|
24 |
+
using weight_t = weight_t_;
|
25 |
+
static constexpr int kNThreads = kNThreads_;
|
26 |
+
static constexpr int kWidth = kWidth_;
|
27 |
+
static constexpr int kNBytes = sizeof(input_t);
|
28 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
29 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
30 |
+
static_assert(kWidth <= kNElts);
|
31 |
+
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
32 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
33 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
34 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
35 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
36 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
37 |
+
static constexpr int kSmemIOSize = kIsVecLoad
|
38 |
+
? 0
|
39 |
+
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
40 |
+
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
41 |
+
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
42 |
+
};
|
43 |
+
|
44 |
+
template<typename Ktraits>
|
45 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
46 |
+
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
47 |
+
constexpr int kWidth = Ktraits::kWidth;
|
48 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
49 |
+
constexpr int kNElts = Ktraits::kNElts;
|
50 |
+
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
51 |
+
using input_t = typename Ktraits::input_t;
|
52 |
+
using vec_t = typename Ktraits::vec_t;
|
53 |
+
using weight_t = typename Ktraits::weight_t;
|
54 |
+
|
55 |
+
// Shared memory.
|
56 |
+
extern __shared__ char smem_[];
|
57 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
58 |
+
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
59 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
60 |
+
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
61 |
+
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
62 |
+
|
63 |
+
const int tidx = threadIdx.x;
|
64 |
+
const int batch_id = blockIdx.x;
|
65 |
+
const int channel_id = blockIdx.y;
|
66 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
67 |
+
+ channel_id * params.x_c_stride;
|
68 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
69 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
70 |
+
+ channel_id * params.out_c_stride;
|
71 |
+
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
72 |
+
|
73 |
+
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
74 |
+
if (tidx == 0) {
|
75 |
+
input_t zeros[kNElts] = {0};
|
76 |
+
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
|
77 |
+
}
|
78 |
+
|
79 |
+
float weight_vals[kWidth];
|
80 |
+
#pragma unroll
|
81 |
+
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
82 |
+
|
83 |
+
constexpr int kChunkSize = kNThreads * kNElts;
|
84 |
+
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
85 |
+
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
86 |
+
input_t x_vals_load[2 * kNElts] = {0};
|
87 |
+
if constexpr(kIsVecLoad) {
|
88 |
+
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
89 |
+
} else {
|
90 |
+
__syncthreads();
|
91 |
+
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
92 |
+
}
|
93 |
+
x += kChunkSize;
|
94 |
+
__syncthreads();
|
95 |
+
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
96 |
+
// the last elements of the previous chunk.
|
97 |
+
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
98 |
+
__syncthreads();
|
99 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
100 |
+
__syncthreads();
|
101 |
+
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
102 |
+
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
103 |
+
|
104 |
+
float x_vals[2 * kNElts];
|
105 |
+
#pragma unroll
|
106 |
+
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
107 |
+
|
108 |
+
float out_vals[kNElts];
|
109 |
+
#pragma unroll
|
110 |
+
for (int i = 0; i < kNElts; ++i) {
|
111 |
+
out_vals[i] = bias_val;
|
112 |
+
#pragma unroll
|
113 |
+
for (int w = 0; w < kWidth; ++w) {
|
114 |
+
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
115 |
+
}
|
116 |
+
}
|
117 |
+
|
118 |
+
if (params.silu_activation) {
|
119 |
+
#pragma unroll
|
120 |
+
for (int i = 0; i < kNElts; ++i) {
|
121 |
+
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
122 |
+
}
|
123 |
+
}
|
124 |
+
|
125 |
+
input_t out_vals_store[kNElts];
|
126 |
+
#pragma unroll
|
127 |
+
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
128 |
+
if constexpr(kIsVecLoad) {
|
129 |
+
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
130 |
+
} else {
|
131 |
+
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
|
132 |
+
}
|
133 |
+
out += kChunkSize;
|
134 |
+
}
|
135 |
+
}
|
136 |
+
|
137 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
138 |
+
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
139 |
+
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
140 |
+
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
141 |
+
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
142 |
+
constexpr int kSmemSize = Ktraits::kSmemSize;
|
143 |
+
dim3 grid(params.batch, params.dim);
|
144 |
+
|
145 |
+
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
146 |
+
|
147 |
+
if (kSmemSize >= 48 * 1024) {
|
148 |
+
#ifndef USE_ROCM
|
149 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
150 |
+
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
151 |
+
#else
|
152 |
+
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
|
153 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
154 |
+
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
155 |
+
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
156 |
+
#endif
|
157 |
+
}
|
158 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
159 |
+
|
160 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
161 |
+
});
|
162 |
+
}
|
163 |
+
|
164 |
+
template<typename input_t, typename weight_t>
|
165 |
+
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
166 |
+
if (params.width == 2) {
|
167 |
+
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
168 |
+
} else if (params.width == 3) {
|
169 |
+
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
170 |
+
} else if (params.width == 4) {
|
171 |
+
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
172 |
+
}
|
173 |
+
}
|
174 |
+
|
175 |
+
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
176 |
+
struct Causal_conv1d_channellast_fwd_kernel_traits {
|
177 |
+
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
178 |
+
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
179 |
+
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
180 |
+
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
|
181 |
+
using input_t = input_t_;
|
182 |
+
using weight_t = weight_t_;
|
183 |
+
static constexpr int kNThreads = kNThreads_;
|
184 |
+
static_assert(kNThreads % 32 == 0);
|
185 |
+
static constexpr int kNWarps = kNThreads / 32;
|
186 |
+
static constexpr int kWidth = kWidth_;
|
187 |
+
static constexpr int kChunkSizeL = kChunkSizeL_;
|
188 |
+
static constexpr int kNBytes = sizeof(input_t);
|
189 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
190 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
191 |
+
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
192 |
+
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
193 |
+
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
194 |
+
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
195 |
+
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
196 |
+
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
197 |
+
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
198 |
+
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
199 |
+
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
200 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
201 |
+
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
202 |
+
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
203 |
+
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
204 |
+
// sizeof(typename BlockStoreT::TempStorage)});
|
205 |
+
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
206 |
+
};
|
207 |
+
|
208 |
+
template<typename Ktraits, bool kHasSeqIdx>
|
209 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
210 |
+
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
|
211 |
+
constexpr int kWidth = Ktraits::kWidth;
|
212 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
213 |
+
constexpr int kNElts = Ktraits::kNElts;
|
214 |
+
constexpr int kNWarp = Ktraits::kNWarps;
|
215 |
+
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
216 |
+
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
217 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
218 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
219 |
+
using input_t = typename Ktraits::input_t;
|
220 |
+
using vec_t = typename Ktraits::vec_t;
|
221 |
+
using weight_t = typename Ktraits::weight_t;
|
222 |
+
|
223 |
+
// Shared memory.
|
224 |
+
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
|
225 |
+
|
226 |
+
const int batch_id = blockIdx.x;
|
227 |
+
const int chunk_l_id = blockIdx.y;
|
228 |
+
const int chunk_c_id = blockIdx.z;
|
229 |
+
const int tid = threadIdx.x;
|
230 |
+
const int l_idx = tid / kNThreadsPerC;
|
231 |
+
const int c_idx = tid % kNThreadsPerC;
|
232 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
233 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
234 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
235 |
+
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
236 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
237 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
238 |
+
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
239 |
+
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
240 |
+
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
241 |
+
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
242 |
+
// The last L-chunk will also have enough info to write to final states, since it also contain a few x values
|
243 |
+
// from the previous L-chunk.
|
244 |
+
input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
|
245 |
+
: reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
246 |
+
|
247 |
+
#pragma unroll
|
248 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
249 |
+
input_t x_vals_load[kNElts] = {0};
|
250 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
251 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
252 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
253 |
+
}
|
254 |
+
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
255 |
+
}
|
256 |
+
// Load the elements from the previous chunk that are needed for convolution.
|
257 |
+
if (l_idx < kWidth - 1) {
|
258 |
+
input_t x_vals_load[kNElts] = {0};
|
259 |
+
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
260 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
261 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
262 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
263 |
+
} else if (initial_states != nullptr
|
264 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
265 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
266 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
267 |
+
}
|
268 |
+
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
269 |
+
}
|
270 |
+
|
271 |
+
__syncthreads();
|
272 |
+
|
273 |
+
if (final_states != nullptr
|
274 |
+
&& l_idx < kWidth - 1
|
275 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
276 |
+
// x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
|
277 |
+
// So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
|
278 |
+
*reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
|
279 |
+
}
|
280 |
+
|
281 |
+
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
282 |
+
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
283 |
+
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
284 |
+
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
285 |
+
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
286 |
+
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
287 |
+
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
288 |
+
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
289 |
+
static_assert(kNThreadsPerRow <= 32);
|
290 |
+
|
291 |
+
const int row_idx = tid / kNThreadsPerRow;
|
292 |
+
const int col_idx = tid % kNThreadsPerRow;
|
293 |
+
|
294 |
+
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
295 |
+
float weight_vals[kWidth] = {0};
|
296 |
+
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
297 |
+
#pragma unroll
|
298 |
+
for (int w = 0; w < kWidth; ++w) {
|
299 |
+
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
300 |
+
}
|
301 |
+
}
|
302 |
+
float x_vals[kWidth - 1 + kLPerThread];
|
303 |
+
#pragma unroll
|
304 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
305 |
+
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
306 |
+
}
|
307 |
+
int seq_idx_thread[kWidth - 1 + kLPerThread];
|
308 |
+
if constexpr (kHasSeqIdx) {
|
309 |
+
#pragma unroll
|
310 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
311 |
+
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
312 |
+
}
|
313 |
+
}
|
314 |
+
|
315 |
+
float out_vals[kLPerThread];
|
316 |
+
#pragma unroll
|
317 |
+
for (int i = 0; i < kLPerThread; ++i) {
|
318 |
+
out_vals[i] = bias_val;
|
319 |
+
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
320 |
+
#pragma unroll
|
321 |
+
for (int w = 0; w < kWidth; ++w) {
|
322 |
+
if constexpr (!kHasSeqIdx) {
|
323 |
+
out_vals[i] += weight_vals[w] * x_vals[i + w];
|
324 |
+
} else {
|
325 |
+
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
326 |
+
}
|
327 |
+
}
|
328 |
+
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
|
329 |
+
}
|
330 |
+
|
331 |
+
__syncthreads();
|
332 |
+
#pragma unroll
|
333 |
+
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
|
334 |
+
__syncthreads();
|
335 |
+
|
336 |
+
#pragma unroll
|
337 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
338 |
+
input_t out_vals_store[kNElts];
|
339 |
+
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
|
340 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
341 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
342 |
+
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
|
343 |
+
}
|
344 |
+
}
|
345 |
+
|
346 |
+
}
|
347 |
+
|
348 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
349 |
+
void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
350 |
+
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
351 |
+
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
|
352 |
+
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
353 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
354 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
355 |
+
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
356 |
+
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
357 |
+
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
358 |
+
dim3 block(Ktraits::kNThreads);
|
359 |
+
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
|
360 |
+
// if (kSmemSize >= 48 * 1024) {
|
361 |
+
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
362 |
+
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
363 |
+
// }
|
364 |
+
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
365 |
+
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
366 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
367 |
+
});
|
368 |
+
}
|
369 |
+
|
370 |
+
template<typename input_t, typename weight_t>
|
371 |
+
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
372 |
+
if (params.width == 2) {
|
373 |
+
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
374 |
+
} else if (params.width == 3) {
|
375 |
+
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
376 |
+
} else if (params.width == 4) {
|
377 |
+
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
378 |
+
}
|
379 |
+
}
|
380 |
+
|
381 |
+
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
382 |
+
template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
383 |
+
template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
384 |
+
template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
385 |
+
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
386 |
+
template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
387 |
+
template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
388 |
+
template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
389 |
+
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
390 |
+
|
391 |
+
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
392 |
+
template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
393 |
+
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
394 |
+
template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
395 |
+
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
396 |
+
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
397 |
+
template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
398 |
+
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
399 |
+
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
causal-conv1d/csrc/causal_conv1d_update.cu
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#include <c10/util/BFloat16.h>
|
6 |
+
#include <c10/util/Half.h>
|
7 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
8 |
+
|
9 |
+
#include "causal_conv1d.h"
|
10 |
+
#include "causal_conv1d_common.h"
|
11 |
+
#include "static_switch.h"
|
12 |
+
|
13 |
+
// Compile-time configuration consumed by causal_conv1d_update_kernel.
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    // Element type of the activations / conv state, and of the weights / bias.
    typedef input_t_ input_t;
    typedef weight_t_ weight_t;

    // Threads per block and convolution width, forwarded from the template arguments.
    static constexpr int kWidth = kWidth_;
    static constexpr int kNThreads = kNThreads_;

    // Size in bytes of one input element; only 2- and 4-byte types are accepted.
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 4 || kNBytes == 2);
};
|
22 |
+
|
23 |
+
template<typename Ktraits, bool kIsCircularBuffer>
|
24 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
25 |
+
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
26 |
+
constexpr int kWidth = Ktraits::kWidth;
|
27 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
28 |
+
using input_t = typename Ktraits::input_t;
|
29 |
+
using weight_t = typename Ktraits::weight_t;
|
30 |
+
|
31 |
+
const int tidx = threadIdx.x;
|
32 |
+
const int batch_id = blockIdx.x;
|
33 |
+
const int channel_id = blockIdx.y * kNThreads + tidx;
|
34 |
+
if (channel_id >= params.dim) return;
|
35 |
+
|
36 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
37 |
+
+ channel_id * params.x_c_stride;
|
38 |
+
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
|
39 |
+
+ channel_id * params.conv_state_c_stride;
|
40 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
41 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
42 |
+
+ channel_id * params.out_c_stride;
|
43 |
+
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
44 |
+
|
45 |
+
int state_len = params.conv_state_len;
|
46 |
+
int advance_len = params.seqlen;
|
47 |
+
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
48 |
+
int update_idx = cache_seqlen - (kWidth - 1);
|
49 |
+
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
50 |
+
|
51 |
+
float weight_vals[kWidth] = {0};
|
52 |
+
#pragma unroll
|
53 |
+
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
54 |
+
|
55 |
+
float x_vals[kWidth] = {0};
|
56 |
+
if constexpr (!kIsCircularBuffer) {
|
57 |
+
#pragma unroll 2
|
58 |
+
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
59 |
+
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
60 |
+
}
|
61 |
+
#pragma unroll
|
62 |
+
for (int i = 0; i < kWidth - 1; ++i) {
|
63 |
+
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
64 |
+
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
65 |
+
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
66 |
+
}
|
67 |
+
x_vals[i] = float(state_val);
|
68 |
+
}
|
69 |
+
} else {
|
70 |
+
#pragma unroll
|
71 |
+
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
72 |
+
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
73 |
+
x_vals[i] = float(state_val);
|
74 |
+
}
|
75 |
+
}
|
76 |
+
#pragma unroll 2
|
77 |
+
for (int i = 0; i < params.seqlen; ++i) {
|
78 |
+
input_t x_val = x[i * params.x_l_stride];
|
79 |
+
if constexpr (!kIsCircularBuffer) {
|
80 |
+
if (i < advance_len && state_len - advance_len + i >= 0) {
|
81 |
+
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
82 |
+
}
|
83 |
+
} else {
|
84 |
+
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
85 |
+
++update_idx;
|
86 |
+
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
87 |
+
}
|
88 |
+
x_vals[kWidth - 1] = float(x_val);
|
89 |
+
float out_val = bias_val;
|
90 |
+
#pragma unroll
|
91 |
+
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
92 |
+
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
93 |
+
out[i * params.out_l_stride] = input_t(out_val);
|
94 |
+
// Shift the input buffer by 1
|
95 |
+
#pragma unroll
|
96 |
+
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
97 |
+
}
|
98 |
+
}
|
99 |
+
|
100 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
101 |
+
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
102 |
+
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
103 |
+
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
104 |
+
auto kernel = params.cache_seqlens == nullptr
|
105 |
+
? &causal_conv1d_update_kernel<Ktraits, false>
|
106 |
+
: &causal_conv1d_update_kernel<Ktraits, true>;
|
107 |
+
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
108 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
109 |
+
}
|
110 |
+
|
111 |
+
template<typename input_t, typename weight_t>
|
112 |
+
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
113 |
+
if (params.width == 2) {
|
114 |
+
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
115 |
+
} else if (params.width == 3) {
|
116 |
+
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
117 |
+
} else if (params.width == 4) {
|
118 |
+
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
119 |
+
}
|
120 |
+
}
|
121 |
+
|
122 |
+
// Explicit instantiations of the update entry point for every supported
// <input_t, weight_t> combination.
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
|
causal-conv1d/csrc/static_switch.h
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
2 |
+
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
3 |
+
|
4 |
+
#pragma once
|
5 |
+
|
6 |
+
/// Turns a run-time boolean into a compile-time constant by emitting the
/// supplied code once for each value and selecting the copy to run.
///
/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ...        - code to execute for true and false (a callable
///                     invoked with no arguments)
///
/// The macro expands to an immediately-invoked lambda, so it is an expression
/// whose value is whatever the supplied callable returns.
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            static constexpr bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        }                                             \
        static constexpr bool CONST_NAME = false;     \
        return __VA_ARGS__();                         \
    }()
|
causal-conv1d/dist/causal_conv1d-1.4.0-py3.9.egg
ADDED
Binary file (10 kB). View file
|
|
causal-conv1d/rocm_patch/rocm6_0.patch
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--- /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h 2023-12-12 20:11:48.000000000 +0000
|
2 |
+
+++ rocm_update_files/amd_hip_bf16.h 2024-05-20 17:40:26.983349079 +0000
|
3 |
+
@@ -137,7 +137,7 @@
|
4 |
+
* \ingroup HIP_INTRINSIC_BFLOAT16_CONV
|
5 |
+
* \brief Converts float to bfloat16
|
6 |
+
*/
|
7 |
+
-__HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) {
|
8 |
+
+__HOST_DEVICE__ static inline __hip_bfloat16 __float2bfloat16(float f) {
|
9 |
+
__hip_bfloat16 ret;
|
10 |
+
union {
|
11 |
+
float fp32;
|
12 |
+
@@ -181,7 +181,7 @@
|
13 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
14 |
+
* \brief Converts and moves bfloat162 to float2
|
15 |
+
*/
|
16 |
+
-__HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
|
17 |
+
+__HOST_DEVICE__ static inline float2 __bfloat1622float2(const __hip_bfloat162 a) {
|
18 |
+
return float2{__bfloat162float(a.x), __bfloat162float(a.y)};
|
19 |
+
}
|
20 |
+
|
21 |
+
@@ -209,7 +209,7 @@
|
22 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
23 |
+
* \brief Convert double to __hip_bfloat16
|
24 |
+
*/
|
25 |
+
-__HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) {
|
26 |
+
+__HOST_DEVICE__ static inline __hip_bfloat16 __double2bfloat16(const double a) {
|
27 |
+
return __float2bfloat16((float)a);
|
28 |
+
}
|
29 |
+
|
30 |
+
@@ -217,7 +217,7 @@
|
31 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
32 |
+
* \brief Convert float2 to __hip_bfloat162
|
33 |
+
*/
|
34 |
+
-__HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
|
35 |
+
+__HOST_DEVICE__ static inline __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
|
36 |
+
return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)};
|
37 |
+
}
|
38 |
+
|
39 |
+
@@ -247,7 +247,7 @@
|
40 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
41 |
+
* \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result
|
42 |
+
*/
|
43 |
+
-__HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
|
44 |
+
+__HOST_DEVICE__ static inline float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
|
45 |
+
|
46 |
+
/**
|
47 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
48 |
+
@@ -275,7 +275,7 @@
|
49 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
50 |
+
* \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result
|
51 |
+
*/
|
52 |
+
-__HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
|
53 |
+
+__HOST_DEVICE__ static inline float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
|
54 |
+
|
55 |
+
/**
|
56 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
causal-conv1d/setup.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
|
3 |
+
import sys
|
4 |
+
import warnings
|
5 |
+
import os
|
6 |
+
import re
|
7 |
+
import shutil
|
8 |
+
import ast
|
9 |
+
from pathlib import Path
|
10 |
+
from packaging.version import parse, Version
|
11 |
+
import platform
|
12 |
+
|
13 |
+
from setuptools import setup, find_packages
|
14 |
+
import subprocess
|
15 |
+
|
16 |
+
import urllib.request
|
17 |
+
import urllib.error
|
18 |
+
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, HIP_HOME
|
22 |
+
|
23 |
+
|
24 |
+
# Load the full text of README.md into `long_description`.
long_description = Path("README.md").read_text(encoding="utf-8")
|
26 |
+
|
27 |
+
|
28 |
+
# Absolute directory of this setup script (ninja builds need absolute include_dirs).
this_dir = os.path.dirname(os.path.abspath(__file__))

PACKAGE_NAME = "causal_conv1d"

# URL template for a release wheel; filled in with `tag_name` and `wheel_name`.
BASE_WHEEL_URL = (
    "https://github.com/Dao-AILab/causal-conv1d/releases/download/"
    "{tag_name}/{wheel_name}"
)

# Build switches, each enabled only when its environment variable is exactly "TRUE".
# FORCE_BUILD: force a fresh local build instead of looking for a prebuilt wheel.
FORCE_BUILD = os.environ.get("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
# SKIP_CUDA_BUILD: lets CI run a plain `python setup.py sdist` that only copies
# the raw files, without any cuda compilation.
SKIP_CUDA_BUILD = os.environ.get("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# FORCE_CXX11_ABI: CI option to build with the C++11 ABI, since the nvcr images use it.
FORCE_CXX11_ABI = os.environ.get("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
|
41 |
+
|
42 |
+
|
43 |
+
def get_platform():
    """Return the platform tag used in wheel filenames.

    The tag is derived from ``sys.platform``; the Linux and macOS tags always
    carry an ``x86_64`` architecture suffix.

    Raises:
        ValueError: If ``sys.platform`` is not Linux, macOS or Windows.
    """
    current = sys.platform
    if current.startswith("linux"):
        return "linux_x86_64"
    if current == "darwin":
        # Keep only the "major.minor" part of the macOS version string.
        major_minor = platform.mac_ver()[0].split(".")[:2]
        return "macosx_{}_x86_64".format(".".join(major_minor))
    if current == "win32":
        return "win_amd64"
    raise ValueError(f"Unsupported platform: {current}")
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
def get_hip_version(rocm_dir):
    """Query ``hipcc --version`` and extract the HIP version.

    Args:
        rocm_dir: ROCm installation directory; ``hipcc`` is taken from its
            ``bin`` subdirectory. When ``None``, ``hipcc`` is resolved via PATH.

    Returns:
        A ``(line, version)`` tuple, where ``line`` is the first output line
        containing "HIP version" and ``version`` is its last whitespace-separated
        token parsed as a version. ``(None, None)`` if running ``hipcc`` fails
        or no such line is present.
    """
    if rocm_dir is None:
        hipcc_bin = "hipcc"
    else:
        hipcc_bin = os.path.join(rocm_dir, "bin", "hipcc")

    try:
        raw_output = subprocess.check_output(
            [hipcc_bin, "--version"], universal_newlines=True
        )
    except Exception as e:
        # Best effort: report the failure and let the caller handle "unknown".
        print(
            f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
        )
        return None, None

    version_line = next(
        (candidate for candidate in raw_output.split("\n") if "HIP version" in candidate),
        None,
    )
    if version_line is None:
        return None, None

    # A "-suffix" is rewritten to "+suffix" because the local version part is
    # not parsed correctly otherwise.
    rocm_version = parse(version_line.split()[-1].replace("-", "+"))
    return version_line, rocm_version
|
79 |
+
|
80 |
+
|
81 |
+
def get_torch_hip_version():
    """
    Return the HIP version recorded in ``torch.version.hip`` as a parsed
    version, or None when that attribute is empty/None.
    """
    hip_build_string = torch.version.hip
    if not hip_build_string:
        return None
    # Same "-" -> "+" rewrite as get_hip_version so the suffix parses as a local version.
    return parse(hip_build_string.split()[-1].replace("-", "+"))
|
86 |
+
|
87 |
+
|
88 |
+
def check_if_hip_home_none(global_option: str) -> None:
    """
    Emit a warning naming *global_option* when the module-level HIP_HOME is None.

    Does nothing when HIP_HOME is set.
    """
    if HIP_HOME is None:
        # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
        # in that case.
        warnings.warn(
            f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
        )
|
97 |
+
|
98 |
+
|
99 |
+
def check_if_cuda_home_none(global_option: str) -> None:
    """
    Emit a warning naming *global_option* when the module-level CUDA_HOME is None.

    Does nothing when CUDA_HOME is set.
    """
    if CUDA_HOME is None:
        # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
        # in that case.
        warnings.warn(
            f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
|
109 |
+
|
110 |
+
|
111 |
+
def append_nvcc_threads(nvcc_extra_args):
    """Return a new list: *nvcc_extra_args* followed by ``--threads 4``."""
    thread_flags = ["--threads", "4"]
    return nvcc_extra_args + thread_flags
|
113 |
+
|
114 |
+
|
115 |
+
cmdclass = {}
ext_modules = []

# Truthy torch.version.hip selects the HIP code paths below.
HIP_BUILD = bool(torch.version.hip)

if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    _torch_version_fields = torch.__version__.split(".")
    TORCH_MAJOR = int(_torch_version_fields[0])
    TORCH_MINOR = int(_torch_version_fields[1])

    cc_flag = []

    if HIP_BUILD:
        check_if_hip_home_none(PACKAGE_NAME)

        rocm_home = os.getenv("ROCM_PATH")
        _, hip_version = get_hip_version(rocm_home)

        if HIP_HOME is not None:
            # NOTE(review): get_hip_version returns (None, None) when hipcc cannot be
            # run; in that case the comparison below raises TypeError. Confirm whether
            # an explicit error message is wanted here.
            if hip_version < Version("6.0"):
                raise RuntimeError(
                    f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. "
                    "Note: make sure HIP has a supported version by running hipcc --version."
                )
            if hip_version == Version("6.0"):
                warnings.warn(
                    f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. "
                    "Refer to the README.md for detailed instructions.",
                    UserWarning
                )

        cc_flag.append("-DBUILD_PYTHON_PACKAGE")

    else:
        # One "-gencode arch=compute_XX,code=sm_XX" pair per target architecture.
        cc_flag.extend(
            flag
            for arch in ("53", "62", "70", "72", "80", "87")
            for flag in ("-gencode", f"arch=compute_{arch},code=sm_{arch}")
        )

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    # HIP builds additionally pin the C++ standard.
    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"] if HIP_BUILD else ["-O3"],
    }
|
181 |
+
|
182 |
+
|
183 |
+
def get_package_version():
    """
    Read ``__version__`` from ``causal_conv1d/__init__.py`` under ``this_dir``.

    Returns:
        The version as a string. When the CAUSAL_CONV1D_LOCAL_VERSION
        environment variable is non-empty it is appended as ``+<local>``.
    """
    init_source = (Path(this_dir) / "causal_conv1d" / "__init__.py").read_text()
    version_match = re.search(r"^__version__\s*=\s*(.*)$", init_source, re.MULTILINE)
    # The right-hand side is a Python literal, so evaluate it safely.
    public_version = ast.literal_eval(version_match.group(1))

    local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
    if not local_version:
        return str(public_version)
    return f"{public_version}+{local_version}"
|
192 |
+
|
193 |
+
|
194 |
+
def get_wheel_url():
    """
    Build the download URL and filename of the prebuilt wheel matching this
    environment (package version, GPU toolkit, torch, C++ ABI, Python, OS).

    Returns:
        ``(wheel_url, wheel_filename)``.
    """
    # Determine the version numbers that will be used to determine the correct wheel
    torch_version_raw = parse(torch.__version__)

    if HIP_BUILD:
        # We're using the HIP version used to build torch, not the one currently installed
        torch_hip_version = get_torch_hip_version()
        cuda_or_hip = "hip"
        gpu_compute_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
    elif torch.version.cuda is not None:
        # Use the CUDA version torch was built with. Previously this branch read an
        # undefined `cuda_version` and raised NameError on every non-HIP build.
        # NOTE(review): releases may publish wheels for only some CUDA minor versions;
        # a name that does not exist simply fails the download and the caller builds
        # from source.
        torch_cuda_version = parse(torch.version.cuda)
        cuda_or_hip = "cu"
        gpu_compute_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
    else:
        # torch reports neither HIP nor CUDA: use a tag that identifies the situation.
        cuda_or_hip = "cpu"
        gpu_compute_version = ""

    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    causal_conv1d_version = get_package_version()

    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

    wheel_url = BASE_WHEEL_URL.format(
        tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
    )
    return wheel_url, wheel_filename
|
221 |
+
|
222 |
+
|
223 |
+
class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuit the standard full build pipeline.
    """

    def run(self):
        """
        Download a matching prebuilt wheel into ``dist_dir``; build from source
        when FORCE_BUILD is set or the download fails.
        """
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except urllib.error.URLError:
            # URLError is the base class of HTTPError, so this covers both a missing
            # wheel (HTTP 404) and an unreachable host (no network, DNS failure).
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()
        else:
            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            os.makedirs(self.dist_dir, exist_ok=True)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            # Rename the download to the name bdist_wheel itself would have produced.
            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            shutil.move(wheel_filename, wheel_path)
|
256 |
+
|
257 |
+
|
258 |
+
# Directories that find_packages() must not treat as packages.
_EXCLUDED_PACKAGES = (
    "build",
    "csrc",
    "include",
    "tests",
    "dist",
    "docs",
    "benchmarks",
    "causal_conv1d.egg-info",
)

# The cached-wheel command is always registered; build_ext only when there are
# extension modules to compile.
_command_classes = {"bdist_wheel": CachedWheelsCommand}
if ext_modules:
    _command_classes["build_ext"] = BuildExtension

setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(exclude=_EXCLUDED_PACKAGES),
    author="Tri Dao",
    # NOTE(review): this value looks like an e-mail obfuscation placeholder left by
    # the page this file was copied from; restore the real address.
    author_email="[email protected]",
    description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/causal-conv1d",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass=_command_classes,
    python_requires=">=3.8",
    install_requires=[
        "torch",
        "packaging",
        "ninja",
    ],
)
|
causal-conv1d/tests/test_causal_conv1d.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2024, Tri Dao.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
import pytest
|
9 |
+
|
10 |
+
from einops import rearrange
|
11 |
+
|
12 |
+
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
|
13 |
+
from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
|
14 |
+
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states, causal_conv1d_varlen_states_ref
|
15 |
+
|
16 |
+
|
17 |
+
@pytest.mark.parametrize("return_final_states", [False, True])
@pytest.mark.parametrize("has_initial_states", [False, True])
@pytest.mark.parametrize("channel_last", [False, True])
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize(
    "seqlen", [1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
)
@pytest.mark.parametrize('dim', [64, 4096 + 32])
def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states):
    """Compare causal_conv1d_fn with causal_conv1d_ref: outputs, final states and gradients."""
    if (has_initial_states or return_final_states) and not channel_last:
        pytest.skip("Only channel_last support initial_states or return_final_states")

    device = "cuda"
    # (rtol, atol) for activations, keyed by input dtype; float16 uses the fallback.
    rtol, atol = {
        torch.float32: (3e-4, 1e-3),
        torch.bfloat16: (1e-2, 5e-2),
    }.get(itype, (3e-3, 5e-3))
    rtolw, atolw = 1e-3, 1e-3

    torch.random.manual_seed(0)
    batch = 2

    # x is a slice of a wider random tensor, so it is a strided view rather than
    # a freshly allocated contiguous buffer.
    lo, hi = 4096, 4096 + dim
    padded_dim = 4096 + dim + 64
    if channel_last:
        wide = torch.randn(batch, seqlen, padded_dim, device=device, dtype=itype)
        x = rearrange(wide[:, :, lo:hi], "b s d -> b d s").requires_grad_()
    else:
        wide = torch.randn(batch, padded_dim, seqlen, device=device, dtype=itype)
        x = wide[:, lo:hi, :].requires_grad_()

    weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
    bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) if has_bias else None
    if has_initial_states:
        initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_()
    else:
        initial_states = None

    def fresh_leaf(t):
        # Independent copy with its own gradient; None is passed through.
        return None if t is None else t.detach().clone().requires_grad_()

    x_ref = fresh_leaf(x)
    weight_ref = fresh_leaf(weight)
    bias_ref = fresh_leaf(bias)
    initial_states_ref = fresh_leaf(initial_states)

    activation = "silu" if silu_activation else None
    out = causal_conv1d_fn(
        x, weight, bias,
        initial_states=initial_states,
        return_final_states=return_final_states,
        activation=activation,
    )
    out_ref = causal_conv1d_ref(
        x_ref, weight_ref, bias_ref,
        initial_states=initial_states_ref,
        return_final_states=return_final_states,
        activation=activation,
    )

    if return_final_states:
        out, final_states = out
        out_ref, final_states_ref = out_ref
        states_err = (final_states - final_states_ref).abs()
        print(f"Final states max diff: {states_err.max().item()}")
        print(f"Final states mean diff: {states_err.mean().item()}")
        assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)

    out_err = (out - out_ref).abs()
    print(f"Output max diff: {out_err.max().item()}")
    print(f"Output mean diff: {out_err.mean().item()}")
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    if return_final_states:
        # Fold the final states into the output so backward also flows through them.
        out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
        out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)

    g = torch.randn_like(out)
    out.backward(g)
    out_ref.backward(g)

    print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
    print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
    if has_bias:
        print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
    if has_initial_states:
        print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}")

    assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
    assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
    if has_bias:
        assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
    if has_initial_states:
        assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
105 |
+
|
106 |
+
|
107 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("has_cache_seqlens", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
    """Compare causal_conv1d_update with causal_conv1d_update_ref: output and updated conv_state."""
    device = "cuda"
    # (rtol, atol) keyed by input dtype; float16 uses the fallback.
    rtol, atol = {
        torch.float32: (3e-4, 1e-3),
        torch.bfloat16: (1e-2, 5e-2),
    }.get(itype, (3e-3, 5e-3))
    rtolw, atolw = 1e-3, 1e-3

    torch.random.manual_seed(0)
    batch = 64

    x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
    # The state length is drawn at random from [width - 1, width + 10).
    state_len = torch.randint(width - 1, width + 10, (1,)).item()
    conv_state = torch.randn(batch, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
    weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
    bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) if has_bias else None

    # Separate copy for the reference; both states are compared after the calls.
    conv_state_ref = conv_state.detach().clone()
    activation = "silu" if silu_activation else None
    if has_cache_seqlens:
        cache_seqlens = torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
    else:
        cache_seqlens = None

    out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
    out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)

    out_err = (out - out_ref).abs()
    print(f"Output max diff: {out_err.max().item()}")
    print(f"Output mean diff: {out_err.mean().item()}")
    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
151 |
+
|
152 |
+
|
153 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_get_states(dim, itype):
    """causal_conv1d_varlen_states must match causal_conv1d_varlen_states_ref exactly."""
    device = "cuda"
    torch.random.manual_seed(0)

    # 100 sequences with lengths in [1, 32), packed back to back along dim 0.
    lengths = torch.randint(1, 32, (100,), device=device)
    packed = torch.randn(lengths.sum().item(), dim, device=device, dtype=itype)
    # Cumulative lengths with a leading 0.
    cu_seqlens = F.pad(lengths.cumsum(0), (1, 0))
    state_len = 20

    states = causal_conv1d_varlen_states(packed, cu_seqlens, state_len)
    states_ref = causal_conv1d_varlen_states_ref(packed, cu_seqlens, state_len)
    assert torch.equal(states, states_ref)
|
169 |
+
|
170 |
+
|
171 |
+
# @pytest.mark.parametrize("channel_last", [False, True])
|
172 |
+
# Each parameter list is narrowed to a single configuration; widen the lists to sweep more.
@pytest.mark.parametrize('channel_last', [True])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])
@pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('width', [4])
@pytest.mark.parametrize("seqlen", [2048])
def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
    """
    Run the same forward/backward 10000 times and require run-to-run agreement:
    output and dx bit-identical, dweight and dbias equal within a small atol.
    """
    device = "cuda"
    torch.random.manual_seed(0)
    batch = 2
    dim = 4096 + 32  # Try dim not divisible by 64

    # x is a slice of a wider random tensor, so it is a strided view.
    lo, hi = 4096, 4096 + dim
    padded_dim = 4096 + dim + 64
    if channel_last:
        wide = torch.randn(batch, seqlen, padded_dim, device=device, dtype=itype)
        x = rearrange(wide[:, :, lo:hi], "b s d -> b d s").requires_grad_()
    else:
        wide = torch.randn(batch, padded_dim, seqlen, device=device, dtype=itype)
        x = wide[:, lo:hi, :].requires_grad_()

    weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
    bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) if has_bias else None
    activation = "silu" if silu_activation else None

    # autograd.grad rejects None inputs, so bias is differentiated only when present.
    grad_inputs = (x, weight, bias) if has_bias else (x, weight)

    out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
    g = torch.randn_like(out0)
    grads0 = torch.autograd.grad(out0, grad_inputs, g)
    dw_atol = 1e-4
    db_atol = 1e-4

    for _ in range(10000):
        out = causal_conv1d_fn(x, weight, bias, activation=activation)
        grads = torch.autograd.grad(out, grad_inputs, g)
        assert torch.equal(out, out0)
        assert torch.equal(grads[0], grads0[0])
        assert torch.allclose(grads[1], grads0[1], atol=dw_atol)
        if has_bias:
            # Previously this re-asserted the dweight result, so dbias was never checked.
            assert torch.allclose(grads[2], grads0[2], atol=db_atol)
|
228 |
+
|
229 |
+
|
230 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize(
    "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
)
@pytest.mark.parametrize('dim', [64, 4096 + 32])
def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, itype):
    """
    Compare causal_conv1d_fn called with seq_idx against causal_conv1d_ref
    applied to each sub-sequence separately (outputs and gradients).
    """
    device = "cuda"
    # (rtol, atol) keyed by input dtype; float16 uses the fallback.
    rtol, atol = {
        torch.float32: (3e-4, 1e-3),
        torch.bfloat16: (1e-2, 5e-2),
    }.get(itype, (3e-3, 5e-3))
    rtolw, atolw = 1e-3, 1e-3

    torch.random.manual_seed(seqlen + dim + width)
    batch = 3

    # For every batch row, cut [0, seqlen) into sub-sequences at random end positions.
    seqlens = []
    for _ in range(batch):
        nsplits = torch.randint(1, 5, (1,)).item()
        eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
        boundaries = torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
        row_lengths = torch.diff(boundaries).tolist()
        seqlens.append(row_lengths)
        assert sum(row_lengths) == seqlen
        assert all(s > 0 for s in row_lengths)

    # Only support channel_last
    wide = torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)
    x = rearrange(wide[:, :, 4096:4096 + dim], "b s d -> b d s").requires_grad_()
    weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
    bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) if has_bias else None

    def row_seq_idx(row_lengths):
        # Each position holds the index of the sub-sequence it belongs to.
        pieces = [torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(row_lengths)]
        return torch.cat(pieces, dim=0)

    seq_idx = torch.stack([row_seq_idx(sl) for sl in seqlens], dim=0)

    x_ref = x.detach().clone().requires_grad_()
    weight_ref = weight.detach().clone().requires_grad_()
    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
    activation = "silu" if silu_activation else None

    out = causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation=activation)

    # Reference: convolve each sub-sequence on its own, then stitch the pieces back.
    ref_rows = []
    for b in range(batch):
        ref_chunks = [
            causal_conv1d_ref(x_s, weight_ref, bias_ref, activation=activation)
            for x_s in torch.split(x_ref[[b]], seqlens[b], dim=2)
        ]
        ref_rows.append(torch.cat(ref_chunks, dim=2))
    out_ref = torch.cat(ref_rows, dim=0)

    out_err = (out - out_ref).abs()
    print(f"Output max diff: {out_err.max().item()}")
    print(f"Output mean diff: {out_err.mean().item()}")
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    g = torch.randn_like(out)
    out_ref.backward(g)
    out.backward(g)

    print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
    print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
    if has_bias:
        print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")

    assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
    assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
    if has_bias:
        assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
|