Spaces:
Sleeping
Sleeping
File size: 2,491 Bytes
8b19012 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#ifndef USE_ROCM
#include <cuda_bf16.h>
// Butterfly exchange within a warp: returns the value of `val` held by the
// lane whose id equals (this lane id XOR offset).
// The participation mask names all 32 lanes, so every lane of the warp must
// execute this call (requires the *_sync shuffle family, CUDA 9+).
template <typename T>
__device__ inline T shuffle_xor(T val, int offset) {
    constexpr unsigned int kFullWarpMask = 0xffffffffu;
    return __shfl_xor_sync(kFullWarpMask, val, offset);
}
// Largest value in `ilist`, usable in constant expressions.
// An empty list yields 0 (the identity for max over unsigned values); the
// previous std::max(ilist) form had undefined behavior for an empty list and
// silently depended on <algorithm>, which this header does not include.
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
    size_t result = 0;
    for (size_t value : ilist) {
        if (value > result) {
            result = value;
        }
    }
    return result;
}
// Smaller of two values, usable in constant expressions.
// Same rule as std::min: `a` is returned unless `b < a`, so ties and
// unordered comparisons (e.g. NaN) yield the first argument.
template <typename T>
constexpr T constexpr_min(T a, T b) {
    return (b < a) ? b : a;
}
#else
#include <hip/hip_bf16.h>
// Butterfly exchange within a wavefront: returns the value of `val` held by
// the lane whose id equals (this lane id XOR offset).
// HIP's __shfl_xor takes no participation mask; the width argument is left
// at its default.
template <typename T>
__device__ inline T shuffle_xor(T val, int offset) {
    T const partner_val = __shfl_xor(val, offset);
    return partner_val;
}
// Largest value in `ilist`, usable in constant expressions.
// An empty list yields 0 (the identity for max over unsigned values); the
// previous *std::max_element(...) form dereferenced end() on an empty list
// (undefined behavior) and std::max_element is only constexpr from C++17.
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
    size_t result = 0;
    for (size_t value : ilist) {
        if (value > result) {
            result = value;
        }
    }
    return result;
}
// Smaller of two values, usable in constant expressions.
// Follows std::min exactly: `a` is returned unless `b < a`, so ties and
// unordered comparisons (e.g. NaN) yield the first argument. This keeps the
// ROCm build consistent with the CUDA branch, which calls std::min; the
// previous `a < b ? a : b` returned the second argument in those cases.
template<typename T>
constexpr T constexpr_min(T a, T b) {
    return (b < a) ? b : a;
}
#endif
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
// Maps a byte count to a type of exactly that many bytes, exposed as
// `BytesToType<N>::Type`. Only N = 1, 2, 4, 8, 16 are specialized; any other
// N selects the empty primary template, which has no `Type` member.
template <int BYTES> struct BytesToType {};

template <> struct BytesToType<1> {
    typedef uint8_t Type;
    static_assert(sizeof(Type) == 1, "BytesToType<1>::Type must be 1 byte");
};

template <> struct BytesToType<2> {
    typedef uint16_t Type;
    static_assert(sizeof(Type) == 2, "BytesToType<2>::Type must be 2 bytes");
};

template <> struct BytesToType<4> {
    typedef uint32_t Type;
    static_assert(sizeof(Type) == 4, "BytesToType<4>::Type must be 4 bytes");
};

template <> struct BytesToType<8> {
    typedef uint64_t Type;
    static_assert(sizeof(Type) == 8, "BytesToType<8>::Type must be 8 bytes");
};

// uint4 is the 16-byte vector type from the CUDA/HIP headers.
template <> struct BytesToType<16> {
    typedef uint4 Type;
    static_assert(sizeof(Type) == 16, "BytesToType<16>::Type must be 16 bytes");
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Binary addition functor, e.g. for use as the Operator argument of
// Allreduce::run. Returns x + y.
// operator() is const-qualified so the functor can also be invoked through a
// const object or const reference (the original non-const operator() could not).
template<typename T>
struct SumOp {
    __device__ inline T operator()(T const & x, T const & y) const { return x + y; }
};
// Xor-butterfly all-reduce across groups of THREADS consecutive lanes.
// Exchanges at offsets THREADS/2, THREADS/4, ..., 1, so afterwards every lane
// holds `op` folded over the THREADS values of its aligned group (lanes whose
// ids differ only in the low log2(THREADS) bits). Operator is expected to be
// associative and commutative; otherwise lanes may end with different results.
// On CUDA, shuffle_xor uses a full-warp mask, so all lanes of the warp must call
// this together.
template<int THREADS>
struct Allreduce {
    static_assert(THREADS >= 4 && THREADS <= 32 && (THREADS & (THREADS - 1)) == 0,
                  "Allreduce<THREADS>: THREADS must be 4, 8, 16 or 32");

    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        #pragma unroll
        for (int offset = THREADS / 2; offset >= 1; offset /= 2) {
            x = op(x, shuffle_xor(x, offset));
        }
        return x;
    }
};
// Two-lane all-reduce: each lane combines its value with that of the lane
// whose id differs in bit 0, so both lanes of each pair end with op(x, partner)
// (identical on both lanes when op is commutative).
template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        T const partner = shuffle_xor(x, 1);
        return op(x, partner);
    }
};
|