/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#ifndef USE_ROCM
    #include <cuda_bf16.h>
#else
    #include <hip/hip_bf16.h>
#endif
#include <cuda_fp16.h>
#include <c10/util/complex.h>  // For scalar_value_type

#ifndef USE_ROCM

    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return std::max(ilist);
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return std::min(a, b);
    }

#else
    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return *std::max_element(ilist.begin(), ilist.end());
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return a < b ? a : b;
    }
#endif

#define MAX_DSTATE 256

using complex_t = c10::complex<float>;

inline __device__ float2 operator+(const float2 & a, const float2 & b){
    return {a.x + b.x, a.y + b.y};
}

inline __device__ float3 operator+(const float3 &a, const float3 &b) {
    return {a.x + b.x, a.y + b.y, a.z + b.z};
}

inline __device__ float4 operator+(const float4 & a, const float4 & b){
    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename scalar_t, int N>
struct Converter{
    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
        #pragma unroll
        for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
    }
};

template<int N>
struct Converter<at::Half, N>{
    static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
    }
};

#if __CUDA_ARCH__ >= 800
template<int N>
struct Converter<at::BFloat16, N>{
    static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
    }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
__device__ __forceinline__ complex_t cexp2f(complex_t z) {
    float t = exp2f(z.real_);
    float c, s;
    sincosf(z.imag_, &s, &c);
    return complex_t(c * t, s * t);
}

__device__ __forceinline__ complex_t cexpf(complex_t z) {
    float t = expf(z.real_);
    float c, s;
    sincosf(z.imag_, &s, &c);
    return complex_t(c * t, s * t);
}

template<typename scalar_t> struct SSMScanOp;

template<>
struct SSMScanOp<float> {
    __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
    }
};

template<>
struct SSMScanOp<complex_t> {
    __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
        complex_t a0 = complex_t(ab0.x, ab0.y);
        complex_t b0 = complex_t(ab0.z, ab0.w);
        complex_t a1 = complex_t(ab1.x, ab1.y);
        complex_t b1 = complex_t(ab1.z, ab1.w);
        complex_t out_a = a1 * a0;
        complex_t out_b = a1 * b0 + b1;
        return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
    }
};

// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
    scan_t running_prefix;
    // Constructor
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __device__ scan_t operator()(scan_t block_aggregate) {
        scan_t old_prefix = running_prefix;
        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
        return old_prefix;
    }
};
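
// Why the scan operator above is associative: pushing a state x through (a0, b0) and then
// (a1, b1) gives a1 * (a0 * x + b0) + b1 = (a1 * a0) * x + (a1 * b0 + b1), which is exactly
// the pair returned by SSMScanOp. A block-wide inclusive scan over (a_t, b_t) pairs therefore
// evaluates the recurrence x_t = a_t * x_{t-1} + b_t in parallel, and SSMScanPrefixCallbackOp
// carries the running state from one sequence chunk to the next.
//
// Minimal usage sketch (illustrative only -- kNThreads, kNItems, n_chunks, A_vals, Bu_vals and
// x_prev are hypothetical names, not part of this header; requires <cub/block/block_scan.cuh>):
//
//     using scan_t = float2;
//     using BlockScanT = cub::BlockScan<scan_t, kNThreads>;
//     __shared__ typename BlockScanT::TempStorage smem_scan;
//
//     // running_prefix = (1, x_prev): the multiplicative identity plus the state carried in
//     // from the previous chunk (0 for the first chunk).
//     SSMScanPrefixCallbackOp<float> prefix_op(make_float2(1.f, x_prev));
//     for (int chunk = 0; chunk < n_chunks; ++chunk) {
//         scan_t thread_data[kNItems];
//         #pragma unroll
//         for (int i = 0; i < kNItems; ++i) {
//             // Each element is the pair (a_t, b_t) of x_t = a_t * x_{t-1} + b_t.
//             thread_data[i] = make_float2(A_vals[i], Bu_vals[i]);
//         }
//         // After the inclusive scan, thread_data[i].y holds x_t at that position.
//         BlockScanT(smem_scan).InclusiveScan(thread_data, thread_data, SSMScanOp<float>(), prefix_op);
//         __syncthreads();
//         // prefix_op.running_prefix now folds in this chunk's aggregate and seeds the next chunk.
//     }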
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t *u,
                                  typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadT::TempStorage &smem_load,
                                  int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
        using vec_t = typename Ktraits::vec_t;
        typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
            reinterpret_cast<vec_t*>(u),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
        #ifdef USE_ROCM
            , Ktraits::kNThreads * Ktraits::kNLoads
        #endif
        );
    } else {
        typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
}

template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
                                   typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
                                   typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
                                   int seqlen) {
    constexpr int kNItems = Ktraits::kNItems;
    if constexpr (!Ktraits::kIsComplex) {
        typename Ktraits::input_t B_vals_load[kNItems];
        if constexpr (Ktraits::kIsEvenLen) {
            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
            using vec_t = typename Ktraits::vec_t;
            typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
                reinterpret_cast<vec_t*>(Bvar),
                reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
            );
        } else {
            typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
        }
        // #pragma unroll
        // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
        Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
    } else {
        typename Ktraits::input_t B_vals_load[kNItems * 2];
        if constexpr (Ktraits::kIsEvenLen) {
            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
            using vec_t = typename Ktraits::vec_t;
            typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
                reinterpret_cast<vec_t*>(Bvar),
                reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
            );
        } else {
            typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
        }
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
    }
}
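
// load_input and load_weight above (and store_output below) are parameterized by a kernel-traits
// class that is defined next to the kernels, not in this header. The sketch below is illustrative
// only: it enumerates the members these helpers rely on, filled in with CUB types (hipCUB under
// ROCm) and a plausible vectorization scheme; the real traits may differ.
//
//     template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsComplex_,
//              typename input_t_, typename weight_t_>
//     struct ExampleKtraits {                              // hypothetical name
//         using input_t  = input_t_;                       // element type in global memory
//         using weight_t = weight_t_;                      // float or complex_t
//         static constexpr int  kNThreads  = kNThreads_;
//         static constexpr int  kNItems    = kNItems_;     // elements per thread
//         static constexpr bool kIsEvenLen = kIsEvenLen_;  // seqlen is a multiple of the block tile
//         static constexpr bool kIsComplex = kIsComplex_;
//         // Vectorized path: pack kNElts elements into one vec_t, giving kNLoads vec_t per thread.
//         static constexpr int kNBytes = sizeof(input_t);
//         static constexpr int kNElts  = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
//         static constexpr int kNLoads = kNItems / kNElts;
//         using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
//         using BlockLoadT          = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
//         using BlockLoadVecT       = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_DIRECT>;
//         using BlockLoadWeightT    = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
//         using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_DIRECT>;
//         using BlockStoreT         = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
//         using BlockStoreVecT      = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_DIRECT>;
//     };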
template<typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t *out,
                                    const float (&out_vals)[Ktraits::kNItems],
                                    typename Ktraits::BlockStoreT::TempStorage &smem_store,
                                    int seqlen) {
    typename Ktraits::input_t write_vals[Ktraits::kNItems];
    #pragma unroll
    for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
        using vec_t = typename Ktraits::vec_t;
        typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
            reinterpret_cast<vec_t*>(out),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
        );
    } else {
        typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
    }
}
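
// The reinterpret_casts of TempStorage inside load_input / load_weight / store_output assume that
// the vectorized and element-wise staging buffers alias the same shared memory. Callers typically
// carve every TempStorage out of one dynamic shared-memory buffer and synchronize between uses.
// Minimal sketch (illustrative only; example_copy_kernel and the placeholder math are hypothetical):
//
//     template<typename Ktraits>
//     __global__ void example_copy_kernel(typename Ktraits::input_t *u,
//                                         typename Ktraits::input_t *out,
//                                         int seqlen) {
//         // One dynamic shared-memory buffer, aliased for each TempStorage the helpers need.
//         extern __shared__ char smem_[];
//         auto& smem_load  = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
//         auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
//
//         typename Ktraits::input_t u_vals[Ktraits::kNItems];
//         load_input<Ktraits>(u, u_vals, smem_load, seqlen);
//         __syncthreads();  // the store reuses the same shared memory as the load
//
//         float out_vals[Ktraits::kNItems];
//         #pragma unroll
//         for (int i = 0; i < Ktraits::kNItems; ++i) { out_vals[i] = float(u_vals[i]); }  // placeholder math
//
//         store_output<Ktraits>(out, out_vals, smem_store, seqlen);
//     }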