#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <vector>

#include "glog/logging.h"

#include "sparse_matmul/compute/kernels_generic.h"
#include "sparse_matmul/compute/matmul.h"
#include "sparse_matmul/compute/thread_bounds.h"
#include "sparse_matmul/layers/masked_sparse_matrix.h"
#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/float16_types.h"
#include "sparse_matmul/os/coop_threads.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"

#include "absl/memory/memory.h"
|
namespace csrblocksparse { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
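// CsrBlockSparseMatrix stores a weight matrix in a block compressed sparse
// row (CSR) format:
// - |weights_|: the non-zero blocks, stored block by block in row-major order.
// - |col_deltas_|: for each block, the delta in *bytes* from the previous
//   block's starting column in the right-hand side to this block's column.
// - |nnz_per_row_|: the number of non-zero blocks in each block row.
// WeightType is the storage type of the weights (e.g. bfloat16 or a fixed
// point type), RhsType is the element type of the right-hand side, and
// DeltaType is the integer type used for the column deltas. Supported block
// sizes are 1x1 and 4x4 (8x4 after DoubleBlockHeight), and work can be split
// across threads by block row via PrepareForThreads.
//
// Rough usage sketch (types come from this library; shapes and the
// rhs/bias/out vectors are illustrative only):
//   MaskedSparseMatrix<float> masked(rows, cols, mask.data(), values.data());
//   CsrBlockSparseMatrix<csrblocksparse::bfloat16, float> matrix(masked);
//   matrix.SpMM_bias(rhs, bias, &out, /*relu=*/true);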
template <typename WeightType, typename RhsType, typename DeltaType = int16_t> |
|
class CsrBlockSparseMatrix { |
|
public: |
|
CsrBlockSparseMatrix() {} |
|
|
|
|
|
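  // Reconstructs a matrix from a flat buffer previously produced by
  // WriteToFlatBuffer, then rebuilds the rhs indices from the column deltas.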
CsrBlockSparseMatrix(const uint8_t* const& buffer, const std::size_t& len) { |
|
ReadFromFlatBuffer(buffer, len); |
|
ComputeRHSIndices(); |
|
} |
|
|
|
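  // Builds the block-CSR representation from a dense MaskedSparseMatrix:
  // chooses the block size, pads rows of 1x1 matrices with explicit zeros so
  // that each row's non-zero count is a multiple of |col_multiple_|, converts
  // the mask and weights to CSR arrays, and encodes column positions as byte
  // deltas. The result is prepared for single-threaded use; call
  // PrepareForThreads to split the work across more threads.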
template <typename InputType> |
|
CsrBlockSparseMatrix(const MaskedSparseMatrix<InputType>& masked_matrix) { |
|
sparsity_ = masked_matrix.sparsity(); |
|
rows_ = masked_matrix.rows(); |
|
cols_ = masked_matrix.cols(); |
|
|
|
DetermineBlockSize(masked_matrix); |
|
|
|
if (block_width_ == 1 && block_height_ == 1) |
|
col_multiple_ = 8; |
|
else |
|
col_multiple_ = 1; |
|
|
|
std::vector<InputType> weights(masked_matrix.values().begin(), |
|
masked_matrix.values().end()); |
|
|
|
reduced_rows_ = (rows_ + block_height_ - 1) / block_height_; |
|
rows_ = reduced_rows_ * block_height_; |
|
reduced_cols_ = cols_ / block_width_; |
|
|
|
|
|
std::vector<int> reduced_mask(reduced_rows_ * reduced_cols_); |
|
std::vector<int> row_offsets = {0}; |
|
int nnz = 0; |
|
const auto& mask = masked_matrix.mask(); |
|
for (int r = 0; r < reduced_rows_; ++r) { |
|
for (int c = 0; c < reduced_cols_; ++c) { |
|
int mask_val = mask[r * block_height_ * cols_ + c * block_width_]; |
|
reduced_mask[r * reduced_cols_ + c] = mask_val; |
|
nnz += mask_val; |
|
} |
|
row_offsets.push_back(nnz); |
|
} |
|
|
|
|
|
MakeColumnsMultiple(row_offsets, &reduced_mask, &weights); |
|
|
|
std::vector<int> col_indices; |
|
std::vector<WeightType> weights_csr; |
|
std::vector<int> nnz_per_row; |
|
MaskAndWeightsToCsr(reduced_mask, weights, &nnz_per_row, &col_indices, |
|
&weights_csr); |
|
|
|
|
|
std::vector<DeltaType> col_deltas; |
|
for (int i = 0; i < col_indices.size(); ++i) { |
|
|
|
int64_t diff = sizeof(RhsType); |
|
if (i == 0) |
|
diff *= block_width_ * (col_indices[i]); |
|
else |
|
diff *= block_width_ * (col_indices[i] - col_indices[i - 1]); |
|
|
|
CHECK(diff < std::numeric_limits<DeltaType>::max()) |
|
<< "delta between column indices in bytes " << diff |
|
<< " exceeded the maximum size of the DeltaType " |
|
<< std::numeric_limits<DeltaType>::max(); |
|
col_deltas.push_back(static_cast<DeltaType>(diff)); |
|
} |
|
|
|
|
|
col_deltas.insert(col_deltas.end(), std::max(2, col_multiple_ + 1), 0); |
|
nnz_per_row.insert(nnz_per_row.end(), 2, nnz_per_row.back()); |
|
|
|
weights_ = CacheAlignedVector<WeightType>(weights_csr); |
|
col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas); |
|
nnz_per_row_ = CacheAlignedVector<int>(nnz_per_row); |
|
ComputeRHSIndices(); |
|
|
|
num_threads_ = 0; |
|
PrepareForThreads(1); |
|
} |
|
|
|
|
|
|
|
|
|
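  // Constructs a matrix that reuses the block configuration and name of
  // |src_matrix| but takes new weights, column deltas and per-row non-zero
  // counts (used by SplitByColumn and SplitByRow).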
CsrBlockSparseMatrix( |
|
const CsrBlockSparseMatrix<WeightType, RhsType, DeltaType>& src_matrix, |
|
const std::vector<WeightType>& new_weights, |
|
const std::vector<DeltaType>& new_deltas, const std::vector<int>& new_nnz, |
|
int cols) { |
|
num_threads_ = 0; |
|
col_multiple_ = src_matrix.col_multiple_; |
|
block_width_ = src_matrix.block_width_; |
|
block_height_ = src_matrix.block_height_; |
|
reduced_rows_ = new_nnz.size(); |
|
rows_ = reduced_rows_ * block_height_; |
|
cols_ = cols; |
|
reduced_cols_ = cols_ / block_width_; |
|
weights_ = CacheAlignedVector<WeightType>(new_weights); |
|
col_deltas_ = CacheAlignedVector<DeltaType>(new_deltas); |
|
nnz_per_row_ = CacheAlignedVector<int>(new_nnz); |
|
sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_); |
|
ComputeRHSIndices(); |
|
name_ = src_matrix.name_; |
|
PrepareForThreads(1); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
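  // Returns a matrix containing only the columns in [start_col, end_col).
  // If start_col > end_col the range wraps around, i.e. columns >= start_col
  // and columns < end_col are kept. If |keep_rhs_size| is true the new matrix
  // keeps the original number of columns, so it can still be multiplied by a
  // full-sized right-hand side; the deltas are then relative to column 0
  // instead of |start_col|.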
CsrBlockSparseMatrix SplitByColumn(int start_col, int end_col, |
|
bool keep_rhs_size = false) const { |
|
int weight_index = 0; |
|
int delta_index = 0; |
|
std::vector<DeltaType> new_deltas; |
|
std::vector<WeightType> new_weights; |
|
std::vector<int> new_nnz(reduced_rows_); |
|
int col = 0; |
|
int prev_col = keep_rhs_size ? 0 : start_col; |
|
for (int r = 0; r < reduced_rows_; ++r) { |
|
int reduced_col_count = nnz_per_row_[r]; |
|
for (int c = 0; c < reduced_col_count; ++c, ++delta_index) { |
|
col += col_deltas_[delta_index] / sizeof(RhsType); |
|
if ((start_col < end_col && start_col <= col && col < end_col) || |
|
(start_col > end_col && (col < end_col || col >= start_col))) { |
|
++new_nnz[r]; |
|
new_deltas.push_back((col - prev_col) * sizeof(RhsType)); |
|
prev_col = col; |
|
for (int i = 0; i < block_width_ * block_height_; |
|
++i, ++weight_index) { |
|
new_weights.push_back(weights_[weight_index]); |
|
} |
|
} else { |
|
weight_index += block_width_ * block_height_; |
|
} |
|
} |
|
} |
|
int new_cols = keep_rhs_size ? cols_ : end_col - start_col; |
|
return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, |
|
new_cols); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
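  // Returns a matrix containing only the rows in [start_row, end_row).
  // The bounds are converted to block rows by integer division, so they are
  // expected to be multiples of the block height. The first column delta is
  // adjusted so the new matrix starts from the correct rhs position.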
CsrBlockSparseMatrix SplitByRow(int start_row, int end_row) const { |
|
int start_reduced = start_row / block_height_; |
|
int end_reduced = end_row / block_height_; |
|
std::vector<int> new_nnz(nnz_per_row_.data() + start_reduced, |
|
nnz_per_row_.data() + end_reduced); |
|
int weight_start = 0; |
|
for (int r = 0; r < start_reduced; ++r) { |
|
weight_start += nnz_per_row_[r]; |
|
} |
|
int weight_end = weight_start; |
|
for (int r = start_reduced; r < end_reduced; ++r) { |
|
weight_end += nnz_per_row_[r]; |
|
} |
|
int delta_start = 0; |
|
for (int i = 0; i < weight_start; ++i) { |
|
delta_start += col_deltas_[i]; |
|
} |
|
std::vector<DeltaType> new_deltas(col_deltas_.data() + weight_start, |
|
col_deltas_.data() + weight_end); |
|
new_deltas[0] += delta_start; |
|
int block_size = block_height_ * block_width_; |
|
std::vector<WeightType> new_weights( |
|
weights_.data() + weight_start * block_size, |
|
weights_.data() + weight_end * block_size); |
|
return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, cols_); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
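  // Doubles the block height by merging each pair of adjacent block rows into
  // one. Where a column is present in only one of the two source rows, the
  // other half of the merged block is filled with zero weights. Rebuilds
  // |col_deltas_| from the new rhs indices and re-runs thread preparation if
  // threads were already set up. Used e.g. to obtain the 8x4 layout that
  // MatVec supports from a 4x4 matrix.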
void DoubleBlockHeight() { |
|
int new_rows = reduced_rows_ / 2; |
|
std::vector<int> new_nnz(new_rows); |
|
std::vector<DeltaType> new_rhs_indices; |
|
std::vector<WeightType> new_weights; |
|
int rhs_index1 = 0; |
|
int rhs_index2 = 0; |
|
int block_size = block_height_ * block_width_; |
|
for (int r = 0; r < new_rows; ++r) { |
|
int start_nnz = new_rhs_indices.size(); |
|
rhs_index2 += nnz_per_row_[r * 2]; |
|
int end1 = rhs_index1 + nnz_per_row_[r * 2]; |
|
int end2 = rhs_index2 + nnz_per_row_[r * 2 + 1]; |
|
|
|
|
|
while (rhs_index1 < end1 || rhs_index2 < end2) { |
|
int col1 = rhs_index1 < end1 ? rhs_indices_[rhs_index1] : reduced_cols_; |
|
int col2 = rhs_index2 < end2 ? rhs_indices_[rhs_index2] : reduced_cols_; |
|
if (col1 < col2) { |
|
|
|
new_rhs_indices.push_back(col1); |
|
new_weights.insert(new_weights.end(), |
|
weights_.data() + rhs_index1 * block_size, |
|
weights_.data() + (rhs_index1 + 1) * block_size); |
|
new_weights.insert(new_weights.end(), block_size, |
|
static_cast<WeightType>(0.0f)); |
|
++rhs_index1; |
|
} else if (col1 > col2) { |
|
|
|
new_rhs_indices.push_back(col2); |
|
new_weights.insert(new_weights.end(), block_size, |
|
static_cast<WeightType>(0.0f)); |
|
new_weights.insert(new_weights.end(), |
|
weights_.data() + rhs_index2 * block_size, |
|
weights_.data() + (rhs_index2 + 1) * block_size); |
|
++rhs_index2; |
|
} else { |
|
|
|
new_rhs_indices.push_back(col1); |
|
new_weights.insert(new_weights.end(), |
|
weights_.data() + rhs_index1 * block_size, |
|
weights_.data() + (rhs_index1 + 1) * block_size); |
|
new_weights.insert(new_weights.end(), |
|
weights_.data() + rhs_index2 * block_size, |
|
weights_.data() + (rhs_index2 + 1) * block_size); |
|
++rhs_index1; |
|
++rhs_index2; |
|
} |
|
} |
|
rhs_index1 = rhs_index2; |
|
new_nnz[r] = new_rhs_indices.size() - start_nnz; |
|
} |
|
block_height_ *= 2; |
|
reduced_rows_ /= 2; |
|
weights_ = CacheAlignedVector<WeightType>(new_weights); |
|
rhs_indices_ = CacheAlignedVector<DeltaType>(new_rhs_indices); |
|
nnz_per_row_ = CacheAlignedVector<int>(new_nnz); |
|
sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_); |
|
ComputeColDeltas(); |
|
if (num_threads_ > 0) { |
|
int num_threads = num_threads_; |
|
num_threads_ = 0; |
|
PrepareForThreads(num_threads); |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
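  // Serializes the matrix into |csr_flatbuffer|: a fixed-size header (see
  // FixedParameterSize) followed by the raw weights, column deltas and
  // per-row non-zero counts. Returns the number of bytes written. Note that
  // |rhs_indices_| are not stored; they are recomputed on load.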
std::size_t WriteToFlatBuffer(std::string* csr_flatbuffer) { |
|
std::size_t bytes = 0; |
|
bytes += FixedParameterSize(); |
|
bytes += weights_.size() * sizeof(WeightType); |
|
bytes += col_deltas_.size() * sizeof(DeltaType); |
|
bytes += nnz_per_row_.size() * sizeof(int); |
|
|
|
uint8_t* bytes_ptr_ptr = |
|
reinterpret_cast<uint8_t*>(CHECK_NOTNULL(malloc(bytes))); |
|
|
|
int* int_bytes_ptr = reinterpret_cast<int*>(bytes_ptr_ptr); |
|
|
|
*int_bytes_ptr++ = rows_; |
|
*int_bytes_ptr++ = cols_; |
|
*int_bytes_ptr++ = reduced_rows_; |
|
*int_bytes_ptr++ = reduced_cols_; |
|
*int_bytes_ptr++ = block_width_; |
|
*int_bytes_ptr++ = block_height_; |
|
*int_bytes_ptr++ = col_multiple_; |
|
*int_bytes_ptr++ = num_threads_; |
|
*int_bytes_ptr++ = weights_.size(); |
|
*int_bytes_ptr++ = col_deltas_.size(); |
|
*int_bytes_ptr++ = nnz_per_row_.size(); |
|
|
|
float* float_bytes_ptr = reinterpret_cast<float*>(int_bytes_ptr); |
|
*float_bytes_ptr++ = sparsity_; |
|
|
|
uint8_t* bytes_ptr = reinterpret_cast<uint8_t*>(float_bytes_ptr); |
|
|
|
memcpy(bytes_ptr, weights_.data(), weights_.size() * sizeof(WeightType)); |
|
bytes_ptr += weights_.size() * sizeof(WeightType); |
|
|
|
memcpy(bytes_ptr, col_deltas_.data(), |
|
col_deltas_.size() * sizeof(DeltaType)); |
|
bytes_ptr += col_deltas_.size() * sizeof(DeltaType); |
|
|
|
memcpy(bytes_ptr, nnz_per_row_.data(), nnz_per_row_.size() * sizeof(int)); |
|
bytes_ptr += nnz_per_row_.size() * sizeof(int); |
|
|
|
csr_flatbuffer->resize(bytes); |
|
csr_flatbuffer->assign(reinterpret_cast<char*>(bytes_ptr_ptr), bytes); |
|
free(bytes_ptr_ptr); |
|
|
|
return bytes; |
|
} |
|
|
|
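  // Deserializes a matrix written by WriteToFlatBuffer. Checks that |len|
  // matches the sizes recorded in the header, restores the weight,
  // column-delta and nnz-per-row arrays, and re-prepares the recorded number
  // of threads. |rhs_indices_| must be recomputed by the caller (the
  // flat-buffer constructor does this).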
void ReadFromFlatBuffer(const uint8_t* const& bytes, const std::size_t& len) { |
|
CHECK_GE(len, FixedParameterSize()); |
|
|
|
const int* int_bytes_ptr = reinterpret_cast<const int*>(bytes); |
|
rows_ = *int_bytes_ptr++; |
|
cols_ = *int_bytes_ptr++; |
|
reduced_rows_ = *int_bytes_ptr++; |
|
reduced_cols_ = *int_bytes_ptr++; |
|
block_width_ = *int_bytes_ptr++; |
|
block_height_ = *int_bytes_ptr++; |
|
col_multiple_ = *int_bytes_ptr++; |
|
int num_threads = *int_bytes_ptr++; |
|
int32_t weights_size = *int_bytes_ptr++; |
|
int32_t col_deltas_size = *int_bytes_ptr++; |
|
int32_t nnz_per_row_size = *int_bytes_ptr++; |
|
|
|
|
|
weights_size = std::max(0, weights_size); |
|
col_deltas_size = std::max(0, col_deltas_size); |
|
nnz_per_row_size = std::max(0, nnz_per_row_size); |
|
|
|
const float* float_bytes_ptr = |
|
reinterpret_cast<const float*>(int_bytes_ptr); |
|
sparsity_ = *float_bytes_ptr++; |
|
|
|
std::size_t total_bytes = |
|
FixedParameterSize() + weights_size * sizeof(WeightType) + |
|
col_deltas_size * sizeof(DeltaType) + nnz_per_row_size * sizeof(int); |
|
|
|
CHECK_EQ(total_bytes, len) |
|
<< "total bytes: " << total_bytes << ", actual len given: " << len; |
|
|
|
const uint8_t* bytes_ptr = |
|
reinterpret_cast<const uint8_t*>(float_bytes_ptr); |
|
std::vector<WeightType> weights_raw(weights_size); |
|
memcpy(weights_raw.data(), bytes_ptr, weights_size * sizeof(WeightType)); |
|
weights_ = CacheAlignedVector<WeightType>(weights_raw); |
|
bytes_ptr += weights_size * sizeof(WeightType); |
|
|
|
std::vector<DeltaType> deltas_raw(col_deltas_size); |
|
memcpy(deltas_raw.data(), bytes_ptr, col_deltas_size * sizeof(DeltaType)); |
|
col_deltas_ = CacheAlignedVector<DeltaType>(deltas_raw); |
|
bytes_ptr += col_deltas_size * sizeof(DeltaType); |
|
|
|
std::vector<int> nnz_raw(nnz_per_row_size); |
|
memcpy(nnz_raw.data(), bytes_ptr, nnz_per_row_size * sizeof(int)); |
|
nnz_per_row_ = CacheAlignedVector<int>(nnz_raw); |
|
num_threads_ = 0; |
|
PrepareForThreads(num_threads); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
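  // Multi-threaded sparse matrix * dense matrix multiply with bias and
  // optional ReLU: out = *this * rhs + bias for the block rows assigned to
  // thread |tid| by PrepareForThreads. Columns of |rhs| are processed five at
  // a time with the SpMM5 kernels, falling back to the single-column SpMV
  // kernels for the remainder; if |barrier| is non-null, all threads
  // synchronize after each group of columns.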
template <typename RhsClass, typename BiasClass, typename OutClass, |
|
typename BiasType = typename BiasClass::value_type, |
|
typename OutType = typename OutClass::value_type> |
|
void SpMM_bias(const RhsClass& rhs, const BiasClass& bias, OutClass* out, |
|
bool relu = false, int tid = 0, |
|
SpinBarrier* barrier = nullptr) const { |
|
static_assert(std::is_same<typename RhsClass::value_type, RhsType>::value, |
|
"Rhs types must match"); |
|
CHECK_LT(tid, num_threads_); |
|
CHECK_EQ(rhs.cols(), out->cols()); |
|
CHECK_EQ(rhs.rows(), cols_); |
|
CHECK_GE(out->rows(), rows_); |
|
int cols_to_go = out->cols(); |
|
int rhs_index = *thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid); |
|
const RhsType* rhs_ptr = rhs.data() + rhs_index * block_height_; |
|
OutType* out_ptr = thread_bounds_.OffsetOutput(out->data(), tid); |
|
const WeightType* weights_ptr = |
|
thread_bounds_.OffsetWeights(weights_.data(), tid); |
|
const DeltaType* delta_ptr = |
|
thread_bounds_.OffsetRhsIndices(col_deltas_.data(), tid); |
|
int offset = *delta_ptr / sizeof(RhsType); |
|
rhs_ptr -= offset; |
|
const int* nnz_ptr = nnz_per_row_.data() + thread_bounds_.StartRow(tid); |
|
int assigned_rows = |
|
thread_bounds_.StartRow(tid + 1) - thread_bounds_.StartRow(tid); |
|
const BiasType* bias_ptr = thread_bounds_.OffsetBias(bias.data(), tid); |
|
|
|
while (cols_to_go > 0) { |
|
if (block_width_ == 4 && block_height_ == 4) { |
|
if (cols_to_go >= 5) { |
|
detail::SpMM5_4x4<WeightType, RhsType, OutType>( |
|
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, |
|
assigned_rows, out->col_stride(), rhs.col_stride(), relu); |
|
} else { |
|
detail::SpMV_4x4<WeightType, RhsType, OutType>( |
|
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, |
|
assigned_rows, out->col_stride(), rhs.col_stride(), relu); |
|
} |
|
} else { |
|
if (cols_to_go >= 5) { |
|
detail::SpMM5_1x1<WeightType, RhsType, OutType>( |
|
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, |
|
assigned_rows, out->col_stride(), rhs.col_stride(), relu); |
|
} else { |
|
detail::SpMV_1x1<WeightType, RhsType, OutType>( |
|
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, |
|
assigned_rows, out->col_stride(), rhs.col_stride(), relu); |
|
} |
|
} |
|
|
|
if (cols_to_go >= 5) { |
|
cols_to_go -= 5; |
|
rhs_ptr += rhs.col_stride() * 5; |
|
out_ptr += out->col_stride() * 5; |
|
} else { |
|
cols_to_go--; |
|
rhs_ptr += rhs.col_stride(); |
|
out_ptr += out->col_stride(); |
|
} |
|
if (barrier) barrier->barrier(); |
|
} |
|
} |
|
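  // Matrix * vector multiply for the row range assigned to |tid|, using the
  // fixed-point Matmul kernels. Requires a block width of 4 and a block
  // height of 4 or 8. |replicas| and |output_stride| are forwarded to the
  // kernel, which (as the names suggest) writes that many copies of the
  // output, |output_stride| elements apart.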
template <typename MVRhsType, typename MVBiasType, typename OutType> |
|
void MatVec(const MVRhsType* rhs, const MVBiasType* bias, bool relu, int tid, |
|
int replicas, int output_stride, OutType* output) { |
|
CHECK_LT(tid, num_threads_); |
|
CHECK_EQ(block_width_, 4) << "Block width must be 4!"; |
|
if (block_height_ == 8) { |
|
matmul_.MatVec8x4( |
|
thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs, |
|
thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(), |
|
thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid), |
|
thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu, |
|
replicas, output_stride, thread_bounds_.OffsetOutput(output, tid)); |
|
} else { |
|
CHECK_EQ(block_height_, 4) << "Block height must be 4 or 8!"; |
|
matmul_.MatVec4x4( |
|
thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs, |
|
thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(), |
|
thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid), |
|
thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu, |
|
replicas, output_stride, thread_bounds_.OffsetOutput(output, tid)); |
|
} |
|
} |
|
|
|
int rows() const { return rows_; } |
|
int cols() const { return cols_; } |
|
int block_height() const { return block_height_; } |
|
int block_width() const { return block_width_; } |
|
float sparsity() const { return sparsity_; } |
|
int num_threads() const { return num_threads_; } |
|
const ThreadBounds& thread_bounds() const { return thread_bounds_; } |
|
const CacheAlignedVector<DeltaType>& rhs_indices() const { |
|
return rhs_indices_; |
|
} |
|
const std::string& name() const { return name_; } |
|
void set_name(const std::string& name) { name_ = name; } |
|
const std::vector<int>& split_points() const { |
|
return thread_bounds_.row_starts(); |
|
} |
|
|
|
std::size_t bytes() const { |
|
return weights_.size() * sizeof(WeightType) + |
|
col_deltas_.size() * sizeof(DeltaType) + |
|
nnz_per_row_.size() * sizeof(int); |
|
} |
|
|
|
|
|
|
|
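  // Fuses SpMM_bias with sampling from the resulting distribution. The
  // generic overload simply runs SpMM_bias and then Sample() on the output;
  // the fixed32 overload calls ReducingSample on the output, passing |tid|
  // and |barrier| so the threads can cooperate on the reduction.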
template <typename RhsClass, typename BiasClass, typename OutClass, |
|
typename BiasType = typename BiasClass::value_type, |
|
typename OutType = typename OutClass::value_type> |
|
typename std::enable_if<!IsFixed32Type<OutType>::value, int>::type |
|
SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out, |
|
float temperature, int tid, SpinBarrier* barrier, |
|
std::minstd_rand* gen, |
|
CacheAlignedVector<float>* scratch) const { |
|
SpMM_bias(rhs, bias, out, false, tid, barrier); |
|
return out->Sample(temperature, gen, scratch); |
|
} |
|
|
|
template <typename RhsClass, typename BiasClass, typename OutClass, |
|
typename BiasType = typename BiasClass::value_type, |
|
typename OutType = typename OutClass::value_type> |
|
typename std::enable_if<IsFixed32Type<OutType>::value, int>::type |
|
SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out, |
|
float temperature, int tid, SpinBarrier* barrier, |
|
std::minstd_rand* gen, |
|
CacheAlignedVector<float>* scratch) const { |
|
|
|
SpMM_bias(rhs, bias, out, false, tid); |
|
return out->ReducingSample(gen, scratch, tid, temperature, barrier); |
|
} |
|
|
|
void Print() const { |
|
std::cout << "Weights\n"; |
|
weights_.Print(); |
|
std::cout << std::endl; |
|
std::cout << "Deltas\n"; |
|
col_deltas_.Print(); |
|
std::cout << std::endl; |
|
std::cout << "nnz\n"; |
|
nnz_per_row_.Print(); |
|
std::cout << std::endl; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
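  // Splits the block rows across |num_threads| threads via ThreadBounds,
  // keeping the split points aligned to whole cache lines of OutType (see
  // ReducedRowsPerCacheLine) so each thread writes disjoint cache lines.
  // Must be called before the threaded SpMM/MatVec entry points; returns the
  // number of threads actually configured.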
template <typename OutType = int32_t> |
|
int PrepareForThreads(int num_threads, int cache_line_size = -1) { |
|
CHECK_GT(num_threads, 0); |
|
|
|
if (num_threads == num_threads_) return num_threads_; |
|
|
|
num_threads_ = num_threads; |
|
thread_bounds_.PrepareForThreads( |
|
block_width_, block_height_, num_threads_, |
|
ReducedRowsPerCacheLine<OutType>(cache_line_size), reduced_rows_, |
|
nnz_per_row_.data()); |
|
return num_threads_; |
|
} |
|
|
|
|
|
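  // Converts the byte deltas in |col_deltas_| into absolute block-column
  // indices into the right-hand side, stored in |rhs_indices_| (used by
  // MatVec).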
void ComputeRHSIndices() { |
|
std::vector<int> cumulative_deltas = CumulativeColDeltas(); |
|
std::vector<DeltaType> rhs_indices(cumulative_deltas.size() + |
|
reduced_rows_); |
|
int total_indices = 0; |
|
int delta_index = 0; |
|
for (int r = 0; r < reduced_rows_; ++r) { |
|
for (int n = 0; n < nnz_per_row_[r]; ++n, ++delta_index) { |
|
rhs_indices[total_indices++] = |
|
cumulative_deltas[delta_index] / block_width_; |
|
} |
|
} |
|
rhs_indices_ = CacheAlignedVector<DeltaType>(rhs_indices); |
|
} |
|
|
|
|
|
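  // Inverse of ComputeRHSIndices: rebuilds the byte deltas in |col_deltas_|
  // from the absolute block-column indices in |rhs_indices_|.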
void ComputeColDeltas() { |
|
std::vector<int> col_deltas(rhs_indices_.size()); |
|
int prev_index = 0; |
|
for (int i = 0; i < rhs_indices_.size(); ++i) { |
|
int offset = rhs_indices_[i] - prev_index; |
|
prev_index = rhs_indices_[i]; |
|
col_deltas[i] = offset * block_width_ * sizeof(RhsType); |
|
} |
|
col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas); |
|
} |
|
|
|
|
|
|
|
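  // Returns the cumulative column deltas converted from bytes to elements of
  // RhsType, i.e. the absolute starting column of each non-zero block.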
std::vector<int> CumulativeColDeltas() const { |
|
std::vector<int> cum_col_deltas(col_deltas_.size()); |
|
for (int i = 0; i < col_deltas_.size(); ++i) { |
|
cum_col_deltas[i] = col_deltas_[i] / sizeof(RhsType); |
|
if (i > 0) cum_col_deltas[i] += cum_col_deltas[i - 1]; |
|
} |
|
return cum_col_deltas; |
|
} |
|
|
|
private: |
|
  // Size of the fixed-size header written by WriteToFlatBuffer: 11 int fields
  // (rows, cols, reduced rows/cols, block width/height, col_multiple,
  // num_threads, and the sizes of the three arrays) plus the float sparsity.
  constexpr std::size_t FixedParameterSize() const {
    return 11 * sizeof(int) + sizeof(float);
  }
|
|
|
|
|
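  // Picks the block size: the first entry of the preferred list (currently
  // only 4x4) whose width divides the column count evenly and for which the
  // mask value is constant within every block; otherwise falls back to 1x1.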
template <typename InputType> |
|
void DetermineBlockSize(const MaskedSparseMatrix<InputType>& masked_matrix) { |
|
const std::vector<std::pair<int, int>> kPreferredOrder = {{4, 4}}; |
|
int rows = masked_matrix.rows(); |
|
int cols = masked_matrix.cols(); |
|
|
|
for (const auto& block_size : kPreferredOrder) { |
|
int block_height, block_width; |
|
std::tie(block_height, block_width) = block_size; |
|
if (cols % block_width != 0) continue; |
|
|
|
int reduced_rows = (rows + block_height - 1) / block_height; |
|
int reduced_cols = cols / block_width; |
|
|
|
|
|
bool all_same = true; |
|
const auto& mask = masked_matrix.mask(); |
|
for (int r = 0; r < reduced_rows; ++r) { |
|
for (int c = 0; c < reduced_cols; ++c) { |
|
int val = mask[r * block_height * cols + c * block_width]; |
|
for (int i = 0; i < block_height; ++i) { |
|
for (int j = 0; j < block_width; ++j) { |
|
int index = (r * block_height + i) * cols + c * block_width + j; |
|
if (index < masked_matrix.mask().size()) { |
|
all_same &= (masked_matrix.mask()[index] == val); |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
|
|
if (all_same) { |
|
block_height_ = block_height; |
|
block_width_ = block_width; |
|
return; |
|
} |
|
} |
|
|
|
|
|
block_height_ = 1; |
|
block_width_ = 1; |
|
} |
|
|
|
|
|
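  // For each row whose non-zero count is not a multiple of |col_multiple_|,
  // promotes enough masked-out columns to explicit zero weights to round the
  // count up, so the kernels can always process columns in fixed-size groups.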
template <typename InputType> |
|
void MakeColumnsMultiple(const std::vector<int>& row_offsets, |
|
std::vector<int>* reduced_mask, |
|
std::vector<InputType>* weights) { |
|
if (col_multiple_ > 0) { |
|
|
|
|
|
for (int r = 1; r < row_offsets.size(); ++r) { |
|
int num_row = row_offsets[r] - row_offsets[r - 1]; |
|
int num_needed = col_multiple_ - num_row % col_multiple_; |
|
if (num_needed < col_multiple_) { |
|
|
|
int num_added = 0; |
|
for (int c = 0; c < reduced_cols_; ++c) { |
|
if ((*reduced_mask)[(r - 1) * reduced_cols_ + c] == 0) { |
|
(*reduced_mask)[(r - 1) * reduced_cols_ + c] = 1; |
|
|
|
|
|
for (int i = 0; i < block_height_; ++i) { |
|
for (int j = 0; j < block_width_; ++j) { |
|
(*weights)[((r - 1) * block_height_ + i) * cols_ + |
|
block_width_ * c + j] = InputType(0.f); |
|
} |
|
} |
|
num_added++; |
|
} |
|
|
|
if (num_added == num_needed) break; |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
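  // Converts the (block-reduced) mask and the dense weights into CSR form:
  // per-row non-zero counts, block-column indices, and weights stored one
  // block at a time. Supports 1x1 and 4x4 blocks; rows added by rounding the
  // row count up to the block height are filled with zero weights.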
template <typename InputType> |
|
void MaskAndWeightsToCsr(const std::vector<int>& mask, |
|
const std::vector<InputType>& weights, |
|
std::vector<int>* nnz_per_row, |
|
std::vector<int>* col_indices, |
|
std::vector<WeightType>* weights_csr) { |
|
std::vector<int> row_offsets = {0}; |
|
int nnz = 0; |
|
|
|
if (block_width_ == 1 && block_height_ == 1) { |
|
for (int r = 0; r < rows_; ++r) { |
|
for (int c = 0; c < cols_; ++c) { |
|
if (mask[r * cols_ + c] == 1) { |
|
nnz++; |
|
col_indices->push_back(c); |
|
weights_csr->push_back(WeightType(weights[r * cols_ + c])); |
|
} |
|
} |
|
row_offsets.push_back(nnz); |
|
} |
|
} else if (block_width_ == 4 && block_height_ == 4) { |
|
|
|
for (int r = 0; r < reduced_rows_; ++r) { |
|
for (int c = 0; c < reduced_cols_; ++c) { |
|
if (mask[r * reduced_cols_ + c] == 1) { |
|
col_indices->push_back(c); |
|
nnz++; |
|
for (int i = 0; i < block_height_; ++i) { |
|
for (int j = 0; j < block_width_; ++j) { |
|
int row_index = (block_height_ * r + i) * cols_; |
|
int w_index = row_index + block_width_ * c + j; |
|
WeightType weight = w_index < weights.size() |
|
? WeightType(weights[w_index]) |
|
: WeightType(0.0f); |
|
weights_csr->push_back(weight); |
|
} |
|
} |
|
} |
|
} |
|
row_offsets.push_back(nnz); |
|
} |
|
} |
|
for (int i = 1; i < row_offsets.size(); ++i) |
|
nnz_per_row->push_back(row_offsets[i] - row_offsets[i - 1]); |
|
} |
|
|
|
|
|
|
|
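  // Number of block rows whose output fits in one cache line of OutType,
  // used to keep the per-thread row boundaries cache-line aligned.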
template <typename OutType> |
|
int ReducedRowsPerCacheLine(int override_cache_line_size = -1) const { |
|
int line_size = kCacheLineSize; |
|
if (override_cache_line_size >= 1) line_size = override_cache_line_size; |
|
return std::max<int>(line_size / (block_height_ * sizeof(OutType)), 1); |
|
} |
|
|
|
int col_multiple_; |
|
int rows_; |
|
int cols_; |
|
int reduced_rows_; |
|
int reduced_cols_; |
|
float sparsity_; |
|
int block_width_; |
|
int block_height_; |
|
int num_threads_; |
|
std::string name_; |
|
|
|
CacheAlignedVector<WeightType> weights_; |
|
CacheAlignedVector<DeltaType> col_deltas_; |
|
CacheAlignedVector<int> nnz_per_row_; |
|
|
|
|
|
CacheAlignedVector<DeltaType> rhs_indices_; |
|
Matmul<WeightType, RhsType> matmul_; |
|
ThreadBounds thread_bounds_; |
|
static constexpr int kCacheLineSize = 64; |
|
}; |
|
|
|
|
|
|
|
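// Converts a dense mask + weights into the serialized block-sparse
// representation, using bfloat16 weights, float right-hand sides and int16
// column deltas.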
template <typename MaskType> |
|
std::string ConvertDenseToSparseRepresentation_Int16Deltas( |
|
const std::vector<MaskType>& mask, const std::vector<float>& weights, |
|
const int rows, const int cols) { |
|
MaskedSparseMatrix<float> masked_weights(rows, cols, mask.data(), |
|
weights.data()); |
|
CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t> |
|
sparse_masked_weights(masked_weights); |
|
std::string buffer; |
|
sparse_masked_weights.WriteToFlatBuffer(&buffer); |
|
return buffer; |
|
} |
|
|
|
}  // namespace csrblocksparse
|
#endif  // LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
|
|