Metadata-Version: 2.1
Name: causal-conv1d
Version: 1.4.0
Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
Home-page: https://github.com/Dao-AILab/causal-conv1d
Author: Tri Dao
Author-email: [email protected]
License: UNKNOWN
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: Unix
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
License-File: AUTHORS
# Causal depthwise conv1d in CUDA with a PyTorch interface
Features:
- Supports fp32, fp16, and bf16.
- Kernel sizes 2, 3, and 4.
## How to use
```python
from causal_conv1d import causal_conv1d_fn
```
```python
def causal_conv1d_fn(x, weight, bias=None, activation=None):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    activation: either None or "silu" or "swish"
    out: (batch, dim, seqlen)
    """
```
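Written out elementwise, the shapes above describe the computation below. Here $w$ is `width`. The bias term is taken as 0 when `bias=None`. $\sigma$ is the identity for `activation=None` and SiLU, $\sigma(z) = z / (1 + e^{-z})$, for `"silu"` or `"swish"` (the two names denote the same function). Positions before the start of the sequence are treated as zeros, so each output depends only on the current input and the previous $w - 1$ inputs of the same channel:

$$
\mathrm{out}[b, d, t] = \sigma\left(\mathrm{bias}[d] + \sum_{k=0}^{w-1} \mathrm{weight}[d, k] \, x[b, d, t - (w - 1) + k]\right), \qquad x[b, d, i] = 0 \text{ for } i < 0
$$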
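Equivalent to the following PyTorch code:
```python
import torch.nn.functional as F

# dim, width = weight.shape and seqlen = x.shape[-1]. This matches activation=None;
# for "silu"/"swish", apply F.silu to the result.
F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
```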
## Additional Prerequisites for AMD cards
### Patching ROCm
If you are on ROCm 6.0, follow the steps below to avoid errors during compilation. This is not required for ROCm 6.1 and later.
1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
2. Apply the patch, replacing `/opt/rocm` in the command below with your installation directory if it differs. If you encounter permission issues, run the command with `sudo`.
```bash
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
```