/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

#include "absl/strings/str_format.h"
#include "glog/logging.h"  // CHECK_EQ
#include "sparse_matmul/vector/cache_aligned_vector.h"
namespace csrblocksparse {
// MaskedSparseMatrix serves two purposes:
// 1) It is a reference implementation of SpMV, used for correctness checking
//    the much more complicated implementations in CSRBlockSparseMatrix.
// 2) It is the format in which sparse matrices are represented after pruning
//    in TF. This class provides a bridge for getting those parameters into
//    a compressed form suitable for computation and serialization.
//
// MaskedSparseMatrix<float> matrix(rows, cols, mask_from_tf, values_from_tf);
// CSRBlockSparseMatrix<float, bfloat16, int16_t> csr_matrix(matrix);
// csr_matrix.Multiply(rhs, bias, &out);
template <typename T>
class MaskedSparseMatrix {
public:
MaskedSparseMatrix() = default;
// Construct a MaskedSparseMatrix of the given size, sparsity and block size.
// This is mainly useful for testing.
MaskedSparseMatrix(int rows, int cols, float sparsity, int block_height = 1,
int block_width = 1, float constant = 1.f,
bool random = true)
: rows_(rows), cols_(cols), sparsity_(sparsity) {
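    // Matrix dimensions must be exact multiples of the block dimensions.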
CHECK_EQ(rows % block_height, 0);
CHECK_EQ(cols % block_width, 0);
init(sparsity, block_height, block_width, constant, random);
}
// Construct from an existing mask and values (most likely from a TF model).
template <typename MaskType>
MaskedSparseMatrix(int rows, int cols, const MaskType* mask, const T* values)
: rows_(rows), cols_(cols) {
mask_.resize(rows * cols);
values_.resize(rows * cols);
std::copy_n(mask, rows * cols, mask_.begin());
std::copy_n(values, rows * cols, values_.begin());
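    // Sparsity is the fraction of zeros in the mask: 1 - (nonzeros / total).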
sparsity_ =
1.f - std::accumulate(mask_.begin(), mask_.end(), 0.f) / mask_.size();
}
const std::vector<int>& mask() const { return mask_; }
const std::vector<T>& values() const { return values_; }
T* data() { return values_.data(); }
const T* data() const { return values_.data(); }
int rows() const { return rows_; }
int cols() const { return cols_; }
float sparsity() const { return sparsity_; }
void Print() const {
absl::PrintF("-------Values---------\n");
for (int r = 0; r < rows_; ++r) {
for (int c = 0; c < cols_; ++c) {
absl::PrintF("%+6.3f ", static_cast<float>(values_[r * cols_ + c]));
}
absl::PrintF("\n");
}
absl::PrintF("-------Mask---------\n");
for (int r = 0; r < rows_; ++r) {
for (int c = 0; c < cols_; ++c) {
absl::PrintF("%2d ", mask_[r * cols_ + c]);
}
absl::PrintF("\n");
}
}
// This routine is useful for rounding the possibly higher precision values
// stored in this class to a lower precision, so that correctness checks
// between this class and CSRBlockSparseMatrix can have a tighter tolerance.
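// Example (illustrative): matrix.CastWeights<bfloat16>();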
template <typename U>
void CastWeights() {
for (size_t i = 0; i < values_.size(); ++i) {
values_[i] = static_cast<T>(U(values_[i]));
}
}
// Only meant for correctness checking.
// RhsClassType is meant to be either CacheAlignedVector OR
// FatCacheAlignedVector.
// The weight matrix is ROW MAJOR and RhsClassType is COLUMN MAJOR.
// |bias| is broadcast if |rhs| has more than one column.
template <typename RhsClassType, typename BiasType, typename OutClassType,
typename RhsType = typename RhsClassType::value_type,
typename OutType = typename OutClassType::value_type>
void SpMM_bias(const RhsClassType& rhs,
const CacheAlignedVector<BiasType>& bias, OutClassType* out,
bool relu = false) {
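    // Dense reference: for each output column n and row r,
    //   out[r, n] = sum_c mask[r, c] * values[r, c] * rhs[c, n] + bias[r],
    // with an optional ReLU applied to the result.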
for (int r = 0; r < rows_; ++r) {
for (int n = 0; n < rhs.cols(); ++n) {
float sum = 0.f;
const RhsType* rhs_ptr = rhs.data() + n * rhs.rows();
OutType* out_ptr = out->data() + n * out->rows();
const int* mask_ptr = mask_.data() + r * cols_;
const T* value_ptr = values_.data() + r * cols_;
for (int c = 0; c < cols_; ++c) {
sum += mask_ptr[c] * static_cast<float>(value_ptr[c]) *
static_cast<float>(rhs_ptr[c]);
}
out_ptr[r] = static_cast<OutType>(
relu ? std::max(sum + static_cast<float>(bias[r]), 0.f)
: sum + static_cast<float>(bias[r]));
}
}
}
private:
// Generate a random matrix with the specified sparsity.
// Useful for testing.
void init(float sparsity, int block_height, int block_width, float constant,
bool random = true) {
int reduced_rows = rows_ / block_height;
int reduced_cols = cols_ / block_width;
mask_.resize(rows_ * cols_, 0);
// Fill with non-zero value to make sure masking works.
values_.resize(rows_ * cols_, static_cast<T>(2.f));
std::mt19937 generator(0);
std::uniform_real_distribution<float> dist_sparsity;
std::uniform_real_distribution<float> dist_value(-1.f, 1.f);
int nnz = 0;
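    // Retry until at least one block is kept; a draw of all zeros would
    // produce an empty matrix and make correctness checks vacuous.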
while (nnz == 0) {
for (int r = 0; r < reduced_rows; ++r) {
for (int c = 0; c < reduced_cols; ++c) {
if (dist_sparsity(generator) > sparsity) {
nnz++;
for (int i = 0; i < block_height; ++i) {
for (int j = 0; j < block_width; ++j) {
mask_[(r * block_height + i) * cols_ + block_width * c + j] = 1;
values_[(r * block_height + i) * cols_ + block_width * c + j] =
static_cast<T>(random ? dist_value(generator) : constant);
}
}
}
}
}
}
}
std::vector<int> mask_;
std::vector<T> values_;
int rows_;
int cols_;
float sparsity_;
};
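// Bundles a MaskedSparseMatrix with its bias vector so a whole layer can be
// stored and applied as one object. Like MaskedSparseMatrix::SpMM_bias, it is
// intended for correctness checking. Sketch of intended use (names are
// illustrative):
//   MaskedLinearLayer<float> layer(std::move(weights), std::move(bias));
//   layer.SpMM_bias(rhs, &out, /*relu=*/true);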
template <typename T>
class MaskedLinearLayer {
public:
MaskedLinearLayer(MaskedSparseMatrix<T>&& weights,
CacheAlignedVector<T>&& bias)
: weights_(std::move(weights)), bias_(std::move(bias)) {}
MaskedLinearLayer() = default;
template <typename U>
void CastWeights() {
weights_.template CastWeights<U>();
}
// Computes Ax + b where A is a masked sparse ROW MAJOR matrix and
// x is a COLUMN MAJOR dense vector or matrix. |bias| is a vector that is
// broadcast if |rhs| has more than one column.
template <typename FatVector>
void SpMM_bias(const FatVector& rhs, FatVector* out, bool relu = false) {
static_assert(std::is_same<typename FatVector::value_type, T>::value,
"FatVector value_type must match masked_linear_layer type");
weights_.SpMM_bias(rhs, bias_, out, relu);
}
private:
MaskedSparseMatrix<T> weights_;
CacheAlignedVector<T> bias_;
};
} // namespace csrblocksparse
#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_