|
|
|
|
|
|
|
|
|
|
|
#include <pybind11/numpy.h> |
|
#include <pybind11/pybind11.h> |
|
#include <pybind11/stl.h> |
|
|
|
#include <iostream> |
|
#include <random> |
|
#include <vector> |
|
|
|
#include "sparse_matmul/sparse_matmul.h" |
|
namespace py = pybind11; |
|
using namespace std; |
|
|
|
using fvec = std::vector<float>; |
|
using ivec = std::vector<int>; |
|
using fndarray = py::array_t<float>; |
|
using indarray = py::array_t<int>; |
|
using mat = csrblocksparse::CsrBlockSparseMatrix<float, float, int16_t>; |
|
using vec = csrblocksparse::CacheAlignedVector<float>; |
|
using masked_mat = csrblocksparse::MaskedSparseMatrix<float>; |
|
|
|
mat create_mat(int h, int w) { |
|
auto m = masked_mat(w, h, 0.90, 4, 4, 0.0, true); |
|
auto a = mat(m); |
|
return a; |
|
} |
|
|
|
struct WaveGRU { |
|
int hidden_dim; |
|
int repeat_factor; |
|
mat m; |
|
vec b; |
|
vec z, r, hh, zrh; |
|
vec fco1, fco2; |
|
vec o1b, o2b; |
|
vec t; |
|
vec h; |
|
vec logits; |
|
mat o1, o2; |
|
std::vector<vec> embed; |
|
|
|
WaveGRU(int hidden_dim, int repeat_factor) |
|
: hidden_dim(hidden_dim), |
|
repeat_factor(repeat_factor), |
|
b(3*hidden_dim), |
|
t(3*hidden_dim), |
|
zrh(3*hidden_dim), |
|
z(hidden_dim), |
|
r(hidden_dim), |
|
hh(hidden_dim), |
|
fco1(hidden_dim), |
|
fco2(256), |
|
h(hidden_dim), |
|
o1b(hidden_dim), |
|
o2b(256), |
|
logits(256) { |
|
m = create_mat(hidden_dim, 3*hidden_dim); |
|
o1 = create_mat(hidden_dim, hidden_dim); |
|
o2 = create_mat(hidden_dim, 256); |
|
embed = std::vector<vec>(); |
|
for (int i = 0; i < 256; i++) { |
|
embed.emplace_back(hidden_dim * 3); |
|
embed[i].FillRandom(); |
|
} |
|
} |
|
|
|
void load_embed(fndarray embed_weights) { |
|
auto a_embed = embed_weights.unchecked<2>(); |
|
for (int i = 0; i < 256; i++) { |
|
for (int j = 0; j < hidden_dim * 3; j++) embed[i][j] = a_embed(i, j); |
|
} |
|
} |
|
|
|
mat load_linear(vec& bias, fndarray w, indarray mask, fndarray b) { |
|
auto w_ptr = static_cast<float*>(w.request().ptr); |
|
auto mask_ptr = static_cast<int*>(mask.request().ptr); |
|
auto rb = b.unchecked<1>(); |
|
|
|
for (int i = 0; i < rb.shape(0); i++) bias[i] = rb(i) / 4; |
|
|
|
masked_mat mm(w.shape(0), w.shape(1), mask_ptr, w_ptr); |
|
mat mmm(mm); |
|
return mmm; |
|
} |
|
|
|
void load_weights(fndarray m, indarray m_mask, fndarray b, |
|
fndarray o1, indarray o1_mask, |
|
fndarray o1b, fndarray o2, |
|
indarray o2_mask, fndarray o2b) { |
|
this->m = load_linear(this->b, m, m_mask, b); |
|
this->o1 = load_linear(this->o1b, o1, o1_mask, o1b); |
|
this->o2 = load_linear(this->o2b, o2, o2_mask, o2b); |
|
} |
|
|
|
std::vector<int> inference(fndarray ft, float temperature) { |
|
auto rft = ft.unchecked<2>(); |
|
int value = 127; |
|
std::vector<int> signal(rft.shape(0) * repeat_factor); |
|
h.FillZero(); |
|
for (int index = 0; index < signal.size(); index++) { |
|
m.SpMM_bias(h, b, &zrh, false); |
|
|
|
for (int i = 0; i < 3 * hidden_dim; i++) t[i] = embed[value][i] + rft(index / repeat_factor, i); |
|
for (int i = 0; i < hidden_dim; i++) { |
|
z[i] = zrh[i] + t[i]; |
|
r[i] = zrh[hidden_dim + i] + t[hidden_dim + i]; |
|
} |
|
|
|
z.Sigmoid(); |
|
r.Sigmoid(); |
|
|
|
for (int i = 0; i < hidden_dim; i++) { |
|
hh[i] = zrh[hidden_dim * 2 + i] * r[i] + t[hidden_dim * 2 + i]; |
|
} |
|
hh.Tanh(); |
|
for (int i = 0; i < hidden_dim; i++) { |
|
h[i] = (1. - z[i]) * h[i] + z[i] * hh[i]; |
|
} |
|
o1.SpMM_bias(h, o1b, &fco1, true); |
|
o2.SpMM_bias(fco1, o2b, &fco2, false); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
value = fco2.Sample(temperature); |
|
signal[index] = value; |
|
} |
|
return signal; |
|
} |
|
}; |
|
|
|
PYBIND11_MODULE(wavegru_mod, m) { |
|
py::class_<WaveGRU>(m, "WaveGRU") |
|
.def(py::init<int, int>()) |
|
.def("load_embed", &WaveGRU::load_embed) |
|
.def("load_weights", &WaveGRU::load_weights) |
|
.def("inference", &WaveGRU::inference); |
|
} |
|
|