/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_
#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_

// NOTE(review): the header names of the bare #include lines below are missing
// from this copy of the file (the NEON / AVX branches presumably include the
// matching intrinsics headers) — restore them before building.
#include
#if defined __ARM_NEON || defined __aarch64__
#include
#else
#include
#endif
#if defined __AVX__ || defined __AVX2__
#include
#endif
#include

#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/type_utils.h"

namespace csrblocksparse {

// The input to exp is clipped to bounds that prevent overflow/underflow in a
// 32 bit float representation: e^80 ~ 5.5e34 is still well below FLT_MAX
// (~3.4e38) and e^-80 ~ 1.8e-35 is still above FLT_MIN (~1.2e-38).
constexpr float kMaxExpInput = 80.f;
// Integer copy of kMaxExpInput, for clipping fixed-point inputs.
// NOTE(review): the cast's target type appears to be missing in this copy of
// the file (presumably static_cast<int>) — confirm.
constexpr int kMaxExpInputInt = static_cast(kMaxExpInput);
constexpr float kMinExpInput = -80.f;
// tanh(9) ~ 0.99999997, which is within about one float32 ulp of 1, so larger
// inputs gain nothing.
constexpr float kMaxTanhInput = 9.f;
constexpr float kMinTanhInput = -9.f;
// sigmoid(18) ~ 0.999999985, which cannot be resolved from 1 in a float32.
constexpr float kMaxSigmoidInput = 18.f;
constexpr float kMinSigmoidInput = -18.f;
// Bit pattern of the float32 value ~= 2^23 / ln 2 (12102203.0f). The fast exp
// code reinterprets it as a float and multiplies x by it, which scales
// x / ln 2 up into the float exponent field.
constexpr uint32_t kAConstant = 0x4b38aa3b;
// Bit pattern of the float32 value ~= (127 << 23) - 366000 (1064987200.0f):
// the float32 exponent bias shifted into the exponent field, minus an offset
// that centers the approximation error instead of leaving it one-sided.
constexpr uint32_t kBConstant = 0x4e7de9a9;
// Coefficients of the rational approximation to tanh.
// Coefficients of the numerator polynomial (odd).
// tanh(x) ~= p(x) / q(x), where
//   p(x) = x * (a1 + a3 x^2 + a5 x^4 + ... + a13 x^12)   (odd powers), and
constexpr float kTanhAlpha1  =  4.89352455891786e-3;
constexpr float kTanhAlpha3  =  6.37261928875436e-4;
constexpr float kTanhAlpha5  =  1.48572235717979e-5;
constexpr float kTanhAlpha7  =  5.12229709037114e-8;
constexpr float kTanhAlpha9  = -8.60467152213735e-11;
constexpr float kTanhAlpha11 =  2.00018790482477e-13;
constexpr float kTanhAlpha13 = -2.76076847742355e-16;
//   q(x) = b0 + b2 x^2 + b4 x^4 + b6 x^6                 (even powers).
constexpr float kTanhBeta0 = 4.89352518554385e-3;
constexpr float kTanhBeta2 = 2.268434632439e-3;
constexpr float kTanhBeta4 = 1.18534705686654e-4;
constexpr float kTanhBeta6 = 1.19825839466702e-6;

// sigmoid(x) ~= 1/2 + p(x) / q(x), where
//   p(x) = x * (a1 + a3 x^2 + ... + a9 x^8)              (odd powers), and
constexpr float kSigmoidAlpha1 = 2.48287947061529e-1;
constexpr float kSigmoidAlpha3 = 8.51377133304701e-3;
constexpr float kSigmoidAlpha5 = 6.08574864600143e-5;
constexpr float kSigmoidAlpha7 = 1.15627324459942e-7;
constexpr float kSigmoidAlpha9 = 4.37031012579801e-11;
//   q(x) = b0 + b2 x^2 + ... + b10 x^10                  (even powers).
constexpr float kSigmoidBeta0  = 9.9315192102318e-1;
constexpr float kSigmoidBeta2  = 1.16817656904453e-1;
constexpr float kSigmoidBeta4  = 1.70198817374094e-3;
constexpr float kSigmoidBeta6  = 6.2910678501704e-6;
constexpr float kSigmoidBeta8  = 5.76102136993427e-9;
constexpr float kSigmoidBeta10 = 6.10247389755681e-13;

// Linear regions. Near zero tanh(x) = x - x^3/3 + 2x^5/15 - 17x^7/315 + ...,
// so tanh(x) ~ x with only an O(x^3) error; below kTanhLinearRegion that is
// more accurate than the exponent-based approximation. Likewise
// sigmoid(x) = 1/2 + x/4 - x^3/48 + ..., which is close to linear below
// kSigmoidLinearRegion.
constexpr float kTanhLinearRegion = 0.15f;
constexpr float kSigmoidLinearRegion = 0.75f;

// Largest shift applied to 1/log 2 that keeps the scaled constant in an int32.
constexpr int kMaxLog2Shift = 30;
// 1 / ln 2 as a fixed-point integer with kMaxLog2Shift fractional bits, i.e.
// 2^30 / ln 2 (~1.549e9, which still fits in an int32) truncated to an integer.
// NOTE(review): the cast's target type appears to be missing in this copy of
// the file (presumably static_cast<int>) — confirm. These two are
// `static const` rather than constexpr because log() is not constexpr; being
// `static` at namespace scope, every translation unit gets its own copy.
static const int kLogFactor = static_cast((1 << kMaxLog2Shift) / log(2.f));
// 1 / ln 2 as a float: multiplying by it turns e^x into 2^(x / ln 2).
static const float kOneOverLog2 = 1.0f / log(2.f);
// Number of explicitly stored mantissa (fraction) bits in an IEEE float32.
constexpr int kFloatMantissaBits = 23;
// The float32 exponent bias (127) shifted into the exponent field; adding it
// to an unbiased exponent held in bits [23, 31) gives the IEEE encoding.
constexpr int kFloatExponentOffset = 127 << kFloatMantissaBits;
// Selects the 23 mantissa bits (0x007fffff).
constexpr int kFloatMantissaMask = (1 << kFloatMantissaBits) - 1;
// Selects every bit except the mantissa, i.e. the sign AND exponent bits
// (0xff800000).
constexpr int kFloatExponentMask = (-1) ^ kFloatMantissaMask;

// ========== COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK ============
// Summary: Use the exponent-mantissa representation of a floating point number
// to give exponentiation of 2 for free. If we desire f(z) = e^z = 2^(x+n), (for
// some fixed-point z expressed as an integer with an implied binary point
// within it) then we have to compute x+n = z / ln 2 and then, splitting x+n
// into n = int(x+n) and x = fract(x+n) in [0, 1), we can use n and 2^x as the
// exponent and mantissa of a floating point number, and that float is equal to
// e^z. For original reference see:
// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.9.4508&rep=rep1&type=pdf
// Important detail:
// IEEE floats are stored normalized, i.e. 1.bbbbbbb... x 2^exponent. The
// leading 1 bit is not actually stored (as it is always 1), providing an extra
// bit of precision.
// Since 2^0=1 and 2^1=2, we can treat the problem as 2^x = 1 + u and we thus
// need a mapping x in [0, 1) -> u in [0, 1) and the 1 + is provided by the
// representation.
// In the original paper cited above, the mapping is u = x - c, where c is set
// to minimize the average error. The function to compute exp(x) this way is
// incredibly simple and computationally cheap, but not very accurate.
// Fortunately, the problem has been reduced to u = 2^x - 1 over [0, 1) for // which it is far easier to construct accurate approximations with small // polynomials than a full range exp(x), and this is what the cubic and quartic // versions below do. An important feature of these functions is that they // constrain the solution to be exact at 0 and 1 so there is continuity at each // integer boundary where we wrap from 1 to 0 and increment the power of 2. // Coefficients for quartic representation of 2^x - 1 for x on [0,1). // The quartic representation is 2^x - 1 ~ x - x(1-x)(ax^2 + bx + c), hence the // coefficients of a quadratic are all that is required. // Coefficients came from numerical experiments. constexpr float kExpQuarticFactor2 = 0.0135302434f; constexpr float kExpQuarticFactor1 = 0.0656107542f; constexpr float kExpQuarticFactor0 = 0.306963906f; // Coefficients for cubic representation of 2^x - 1 for x on [0,1] // The cubic representation is 2^x - 1 ~ x - x(1-x)(mx + c), hence the // coefficients of a linear function are all that is required. // Coefficients came from numerical experiments. constexpr float kExpCubicFactor1 = 0.0780252018f; constexpr float kExpCubicFactor0 = 0.304684167f; // Coefficients are optimized to minimize the absolute error on // tanh = (e^2x - 1) / (e^2x + 1) instead of on pure e^x. // Enum that determines how a transcendental is computed. enum TranscendentalMode { // Cubic using 16 bit integer arithmetic. TM_ORDER3_16BIT, // Quartic using 16 bit integer arithmetic. TM_ORDER4_16BIT, // Quartic using 32 bit float arithmetic. TM_ORDER4_FLOAT, }; inline int FloatAsInt16(float x) { return static_cast(x * (1 << 15) + 0.5f); } inline int FloatAsInt32(float x) { return static_cast(x * (1 << 30) + 0.5f); } #if defined __ARM_NEON || defined __aarch64__ constexpr int kMaxSigmoidInputInt = static_cast(kMaxSigmoidInput); // Computes and returns 2^(x>>23) ie 2^u where x = u << 23 bits. 
// Uses the quartic floating point exponent trick, see COMMON DOCUMENTATION FOR // THE FLOATING EXPONENT TRICK above for details. // Returns the true value, ie not scaled. inline float32x4_t float32_pow2(float32x4_t x) { // The input is already shifted left by 23 bits, so when we convert to int, // the bottom 23 bits are the fractional part, and the top bits are the // integer part. We want to compute a function of the fractional part, so // we will mask it off and manipulate it. int32x4_t exp_int_x = vcvtq_s32_f32(x); // Mask to allow conversion of just the fractional part of x to fixed16<0>. int32x4_t mantissa_mask16 = vdupq_n_s32(0x7fff00); // Mask to allow conversion of just the fractional part of x to fixed32<1>. int32x4_t mantissa_mask32 = vdupq_n_s32(0x7fffff); // Narrowing shift to convert to fixed16<0>. int16x4_t x_16 = vshrn_n_s32(vandq_s32(mantissa_mask16, exp_int_x), 8); // Shift to convert to fixed32<1>. int32x4_t x_32 = vshlq_n_s32(vandq_s32(mantissa_mask32, exp_int_x), 7); // Compute the polynomial x(x - 1)(ax^2 + bx + c) of the fractional part. // Ordering these lines carefully makes it faster, as some of the multiply // operations can pipeline instead of waiting for the previous result. int32x4_t x_squared = vmull_s16(x_16, x_16); int16x4_t b = vdup_n_s16(FloatAsInt16(kExpQuarticFactor1)); int32x4_t c = vdupq_n_s32(FloatAsInt32(kExpQuarticFactor0)); int32x4_t bx_plus_c = vmlal_s16(c, b, x_16); int16x4_t a = vdup_n_s16(FloatAsInt16(kExpQuarticFactor2)); // Finish the quadratic: result = ax^2 + bx + c. int32x4_t result = vmlal_s16(bx_plus_c, a, vshrn_n_s32(x_squared, 15)); int32x4_t x_squared_minus_x = vsubq_s32(x_squared, x_32); // Multiply by x^2 - x. result = vqrdmulhq_s32(result, x_squared_minus_x); // Shift back to mantissa position. vqrdmulhq_s32 took 2x 30-mantissa bit // inputs, made 60-mantissa bit result, doubled it to 61 bits, then discarded // the bottom 32 making 29, so shift right 6 to get 23. 
result = vshrq_n_s32(result, 6); // Add the constant to normalize the exponent for IEEE format. int32x4_t exp_offset = vdupq_n_s32(kFloatExponentOffset); exp_int_x = vaddq_s32(exp_int_x, exp_offset); exp_int_x = vaddq_s32(exp_int_x, result); // Cast back to float, as we just computed the exponent and mantissa and // assembled them in IEEE format. return vreinterpretq_f32_s32(exp_int_x); } // Scaled float to float exp approximation, using a quartic refinement of // the exponent trick. See COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK // above for details. Input is a fixed32<31 - mantissa_bits> that has been // converted to a float without any further shifting. MUST HAVE ALREADY BEEN // CLIPPED to a suitable range for exp! // Returns a vector of standard unscaled floats. inline float32x4_t fixed32_exp_float_preclipped(const int mantissa_bits, float32x4_t x) { // Divide by log 2 to convert problem to 2^x, and scale to match the // mantissa bits required by IEEE floats. // This is the shift of the FP mantissa relative to the input mantissa. const int kXShift = kFloatMantissaBits - mantissa_bits; const float kLogFactor = static_cast(1 << kXShift); float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2); float32x4_t y = vmulq_f32(x, factor); // Now compute 2^x. return float32_pow2(y); } // uses trick that 2^x can be computed by shifting integer into the // exponent, see the following reference for a derivation using double: // goo.gl/aUVTK3 // Input x is clamped to [-64, 64], even infinity and NaN. // Accurate to within 3% relative across the entire range. // Fully pipelined throughput is about 10 cycles per fast_exp call. inline float32x4_t fast_exp(float32x4_t x) { #if defined FAST_TRANSCENDENTALS && __ARM_ARCH >= 800 // Uses vcvtnq_s32_f32, not available on ARM v7 NEON. // Load A and B, which are defined as integers into float registers. 
float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant)); float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant)); // Make sure x within the allowed range. x = vminq_f32(x, vdupq_n_f32(kMaxExpInput)); x = vmaxq_f32(x, vdupq_n_f32(kMinExpInput)); // res = A * x + B. // This shifts x into the exponent field and adds the bias. res = vmlaq_f32(res, A, x); // Convert back to an integer, this is what uses the floating point // unit to compute 2^x. int32x4_t x_int = vcvtnq_s32_f32(res); return vreinterpretq_f32_s32(x_int); #else float32x4_t return_val = vdupq_n_f32(0.f); float exponent = expf(vgetq_lane_f32(x, 0)); return_val = vld1q_lane_f32(&exponent, return_val, 0); exponent = expf(vgetq_lane_f32(x, 1)); return_val = vld1q_lane_f32(&exponent, return_val, 1); exponent = expf(vgetq_lane_f32(x, 2)); return_val = vld1q_lane_f32(&exponent, return_val, 2); exponent = expf(vgetq_lane_f32(x, 3)); return_val = vld1q_lane_f32(&exponent, return_val, 3); return return_val; #endif // FAST_TRANSCENDENTALS } // This version does a conversion of the input to floating point, then calls // the floating point fast_exp function. There is another version // fast_exp_fixed, that never does a conversion and is less accurate, but much // faster. template inline float32x4_t fast_exp(int32x4_t x) { return fast_exp(vcvtq_n_f32_s32(x, 31 - ExponentBits)); } // Performs an exp estimate without doing any floating point operations. The // result is a floating point number. See scalar version for an explanation. template inline float32x4_t fast_exp_fixed(int32x4_t x) { static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits"); constexpr int kA = 1.4426950408889634 * (1 << (ExponentBits - 8)); constexpr int kB = (127 << 23) - 366000; constexpr int maxInput = 80 << (31 - ExponentBits); constexpr int minInput = -maxInput; int32x4_t A = vdupq_n_s32(kA); int32x4_t res = vdupq_n_s32(kB); // Make sure x within the allowed range. 
x = vminq_s32(x, vdupq_n_s32(maxInput)); x = vmaxq_s32(x, vdupq_n_s32(minInput)); // res = A * x + B. // This shifts x into the exponent field and adds the bias. res = vmlaq_s32(res, A, x); return vreinterpretq_f32_s32(res); } // fast_exp_norange_check uses vcvtnq_s32_f32, not available on ARM v7 NEON. #if __ARM_ARCH >= 800 namespace detail { // tanh can do range check once. // Input x is clamped to [-64, 64], even infinity and NaN. inline float32x4_t fast_exp_norange_check(float32x4_t x) { float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant)); float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant)); res = vmlaq_f32(res, A, x); int32x4_t x_int = vcvtnq_s32_f32(res); return vreinterpretq_f32_s32(x_int); } } // namespace detail #endif // __ARM_ARCH >= 800 // Clips float input to [-kLimit,kLimit]. inline float32x4_t ClipToFloatBounds(const float kLimit, const float32x4_t x) { // Clip to the input bounds for this approximation. float32x4_t clip_limit = vdupq_n_f32(kLimit); float32x4_t clipped_x = vminq_f32(x, clip_limit); clip_limit = vnegq_f32(clip_limit); return vmaxq_f32(clipped_x, clip_limit); } inline float32x4_t float_tanh_float(const float32x4_t& x) { float32x4_t clipped_x = ClipToFloatBounds(kMaxTanhInput, x); // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and // scale to the mantissa bits required by float32_pow2 all in one multiply. // Add one to double the input. const float kLogFactor = static_cast(1 << (kFloatMantissaBits + 1)); float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2); clipped_x = vmulq_f32(clipped_x, factor); // Now compute 2^x. float32x4_t exp_result = float32_pow2(clipped_x); // Now compute tanh using (e^2x - 1) / (e^2x + 1). 
float32x4_t one = vdupq_n_f32(1.0f); float32x4_t numerator = vsubq_f32(exp_result, one); float32x4_t denominator = vaddq_f32(exp_result, one); float32x4_t recp = vrecpeq_f32(denominator); // Newton-Raphson iteration, accuracy is important for audio quality recp = vmulq_f32(recp, vrecpsq_f32(recp, denominator)); recp = vmulq_f32(recp, numerator); // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low // relative error close to 0. float32x4_t third = vdupq_n_f32(1.0f / 3.0f); float32x4_t taylor = vmulq_f32(x, x); taylor = vmulq_f32(taylor, x); taylor = vmulq_f32(taylor, third); taylor = vsubq_f32(x, taylor); // Test |x| <= 1/9, roughly where the errors cross over, without needing yet // another constant. float32x4_t ninth = vmulq_f32(third, third); uint32x4_t cmp_results = vcaleq_f32(x, ninth); return vbslq_f32(cmp_results, taylor, recp); } // Calculates (exp(x) - exp(-x)) / (exp(x) + exp(-x)). // Input x is clamped to [-9, 9], even infinity and NaN. // See test program for bounds. Throughput of FAST is 334 Mega/sec, // throughput of accurate is 232 Mega/sec. inline float32x4_t fast_tanh(float32x4_t x) { #if defined FASTER_TRANSCENDENTALS return float_tanh_float(x); #elif defined ACCURATE_TRANSCENDENTAL_APPROX && defined FAST_TRANSCENDENTALS x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput)); x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput)); // The monomial coefficients of the numerator polynomial (odd). const float32x4_t alpha_1 = vdupq_n_f32(kTanhAlpha1); const float32x4_t alpha_3 = vdupq_n_f32(kTanhAlpha3); const float32x4_t alpha_5 = vdupq_n_f32(kTanhAlpha5); const float32x4_t alpha_7 = vdupq_n_f32(kTanhAlpha7); const float32x4_t alpha_9 = vdupq_n_f32(kTanhAlpha9); const float32x4_t alpha_11 = vdupq_n_f32(kTanhAlpha11); const float32x4_t alpha_13 = vdupq_n_f32(kTanhAlpha13); // The monomial coefficients of the denominator polynomial (even). 
const float32x4_t beta_0 = vdupq_n_f32(kTanhBeta0); const float32x4_t beta_2 = vdupq_n_f32(kTanhBeta2); const float32x4_t beta_4 = vdupq_n_f32(kTanhBeta4); const float32x4_t beta_6 = vdupq_n_f32(kTanhBeta6); // Since the polynomials are odd/even, we need x^2. const float32x4_t x2 = vmulq_f32(x, x); // Evaluate the numerator polynomial |p|. float32x4_t p = vmlaq_f32(alpha_11, x2, alpha_13); p = vmlaq_f32(alpha_9, x2, p); p = vmlaq_f32(alpha_7, x2, p); p = vmlaq_f32(alpha_5, x2, p); p = vmlaq_f32(alpha_3, x2, p); p = vmlaq_f32(alpha_1, x2, p); p = vmulq_f32(x, p); // Evaluate the denominator polynomial p. float32x4_t q = vmlaq_f32(beta_4, x2, beta_6); q = vmlaq_f32(beta_2, x2, q); q = vmlaq_f32(beta_0, x2, q); // Divide the numerator by the denominator. float32x4_t recp = vrecpeq_f32(q); recp = vmulq_f32(recp, vrecpsq_f32(recp, q)); return vmulq_f32(p, recp); #elif defined FAST_TRANSCENDENTALS && __ARM_ARCH >= 800 // Uses vcvtnq_s32_f32, not available on ARM v7 NEON. x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput)); x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput)); float32x4_t exp_est = detail::fast_exp_norange_check(x); float32x4_t neg_exp_est = detail::fast_exp_norange_check(-x); // If we're in the linear region. // caleq = compare absolute <= uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kTanhLinearRegion)); float32x4_t diff = vsubq_f32(exp_est, neg_exp_est); float32x4_t sum = vaddq_f32(exp_est, neg_exp_est); float32x4_t recp = vrecpeq_f32(sum); recp = vmulq_f32(recp, vrecpsq_f32(recp, sum)); float32x4_t tanh_estimate = vmulq_f32(diff, recp); // Based on comparison, possibly copy x through instead of calculated value. // TODO(b/191497441): Is the compiler generating VBIT or VBSL ? VBIT is one // cycle and VBSL is two... documentation suggests it can do either. 
return vbslq_f32(cmp_results, x, tanh_estimate); #else float32x4_t return_val = vdupq_n_f32(0.f); float tanh_value = tanhf(vgetq_lane_f32(x, 0)); return_val = vld1q_lane_f32(&tanh_value, return_val, 0); tanh_value = tanhf(vgetq_lane_f32(x, 1)); return_val = vld1q_lane_f32(&tanh_value, return_val, 1); tanh_value = tanhf(vgetq_lane_f32(x, 2)); return_val = vld1q_lane_f32(&tanh_value, return_val, 2); tanh_value = tanhf(vgetq_lane_f32(x, 3)); return_val = vld1q_lane_f32(&tanh_value, return_val, 3); return return_val; #endif // FAST_TRANSCENDENTALS } // Input x is clamped to [-18, 18], even infinity and NaN. // See tests for error bounds. Using SIGMOID_AS_TANH with // ACCURATE_TRANSCENDENTAL_APPROX is both faster and more accurate. Using // SIGMOID_AS_TANH with just FAST is slower, but more accurate. // SIGMOID_AS_TANH, ACCURATE is 205 Mega/sec // SIGMOID_AS_TANH, FAST is 290 Mega/sec // FAST is 340 Mega/sec inline float32x4_t fast_sigmoid(float32x4_t x) { #ifdef SIGMOID_AS_TANH float32x4_t half = vdupq_n_f32(0.5f); return vmlaq_f32(half, half, fast_tanh(vmulq_f32(half, x))); #else // SIGMOID_AS_TANH #if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX x = vminq_f32(x, vdupq_n_f32(kMaxSigmoidInput)); x = vmaxq_f32(x, vdupq_n_f32(kMinSigmoidInput)); // The monomial coefficients of the numerator polynomial (odd). const float32x4_t alpha_1 = vdupq_n_f32(kSigmoidAlpha1); const float32x4_t alpha_3 = vdupq_n_f32(kSigmoidAlpha3); const float32x4_t alpha_5 = vdupq_n_f32(kSigmoidAlpha5); const float32x4_t alpha_7 = vdupq_n_f32(kSigmoidAlpha7); const float32x4_t alpha_9 = vdupq_n_f32(kSigmoidAlpha9); // The monomial coefficients of the denominator polynomial (even). 
const float32x4_t beta_0 = vdupq_n_f32(kSigmoidBeta0); const float32x4_t beta_2 = vdupq_n_f32(kSigmoidBeta2); const float32x4_t beta_4 = vdupq_n_f32(kSigmoidBeta4); const float32x4_t beta_6 = vdupq_n_f32(kSigmoidBeta6); const float32x4_t beta_8 = vdupq_n_f32(kSigmoidBeta8); const float32x4_t beta_10 = vdupq_n_f32(kSigmoidBeta10); // Since the polynomials are odd/even, we need x^2. const float32x4_t x2 = vmulq_f32(x, x); // Evaluate the numerator polynomial p. float32x4_t p = vmlaq_f32(alpha_7, x2, alpha_9); p = vmlaq_f32(alpha_5, x2, p); p = vmlaq_f32(alpha_3, x2, p); p = vmlaq_f32(alpha_1, x2, p); p = vmulq_f32(x, p); // Evaluate the denominator polynomial p. float32x4_t q = vmlaq_f32(beta_8, x2, beta_10); q = vmlaq_f32(beta_6, x2, q); q = vmlaq_f32(beta_4, x2, q); q = vmlaq_f32(beta_2, x2, q); q = vmlaq_f32(beta_0, x2, q); // Divide the numerator by the denominator. float32x4_t recp = vrecpeq_f32(q); recp = vmulq_f32(recp, vrecpsq_f32(recp, q)); return vmlaq_f32(vdupq_n_f32(0.5f), p, recp); #elif defined FAST_TRANSCENDENTALS float32x4_t denom = vaddq_f32(fast_exp(vnegq_f32(x)), vdupq_n_f32(1.f)); float32x4_t recp = vrecpeq_f32(denom); // Newton-Raphson iteration, accuracy is important for audio quality. 
recp = vmulq_f32(recp, vrecpsq_f32(recp, denom)); float32x4_t half = vdupq_n_f32(0.5f); float32x4_t quarter = vdupq_n_f32(0.245f); float32x4_t linear_approx = vmlaq_f32(half, quarter, x); uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kSigmoidLinearRegion)); return vbslq_f32(cmp_results, linear_approx, recp); #else float32x4_t return_val = vdupq_n_f32(0.f); float result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 0))); return_val = vld1q_lane_f32(&result, return_val, 0); result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 1))); return_val = vld1q_lane_f32(&result, return_val, 1); result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 2))); return_val = vld1q_lane_f32(&result, return_val, 2); result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 3))); return_val = vld1q_lane_f32(&result, return_val, 3); return return_val; #endif // FAST_TRANSCENDENTALS #endif // SIGMOID_AS_TANH } // Scalar implementations, mainly useful for testing. inline float fast_exp(float x) { return vgetq_lane_f32(fast_exp(vdupq_n_f32(x)), 0); } template inline float fast_exp(fixed32 x) { return vgetq_lane_f32(fast_exp(vdupq_n_s32(x.raw_val())), 0); } // Returns the exponent of a fixed point number in floating point without ever // doing any conversions. Less accurate than the version that does conversions, // but still accurate to within 4% relative for x < 16. template inline float fast_exp_fixed(fixed32 x) { return vgetq_lane_f32(fast_exp_fixed(vdupq_n_s32(x.raw_val())), 0); } inline float fast_sigmoid(float x) { return vgetq_lane_f32(fast_sigmoid(vdupq_n_f32(x)), 0); } inline float fast_tanh(float x) { return vgetq_lane_f32(fast_tanh(vdupq_n_f32(x)), 0); } // Clips integer input to [-|kLimit|, |kLimit|]. // Input: register containins 4x fixed32 with mantissa_bits. // Output: register containing 4x fixed32 limited to // [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|]. template inline int32x4_t ClipToBounds(const int mantissa_bits, const int32x4_t x) { // Clip to the input bounds for this approximation. 
int32x4_t clip_limit = vdupq_n_s32(-(kLimit << mantissa_bits)); int32x4_t clipped_x = vmaxq_s32(x, clip_limit); clip_limit = vnegq_s32(clip_limit); return vminq_s32(clipped_x, clip_limit); } // Fixed32 sigmoid approximation via a quadratic refinement of the exponent // trick. // Input: Register containing 4x fixed32 with |mantissa_bits|. // Output: Register containing 4x float results. inline float32x4_t fixed32_sigmoid_float(const int mantissa_bits, const int32x4_t x) { int32x4_t input = vnegq_s32(x); float32x4_t y = vcvtq_f32_s32(ClipToBounds(mantissa_bits, input)); y = fixed32_exp_float_preclipped(mantissa_bits, y); float32x4_t one = vdupq_n_f32(1.0f); // Approximate reciprocal is not accurate enough - use full division. float32x4_t denom = vaddq_f32(y, one); float32x4_t recp = vrecpeq_f32(denom); // Newton-Raphson iteration, accuracy is important for audio quality recp = vmulq_f32(recp, vrecpsq_f32(recp, denom)); return recp; } template inline float32x4_t fast_sigmoid(int32x4_t x) { #if defined FASTER_TRANSCENDENTALS // Computation will fail to produce the right result if the input mantissa // bits exceeds the number in a float. 
static_assert(kFloatMantissaBits >= fixed32::kMantissaBits, "Mantissa bits must be at most 23!"); return fixed32_sigmoid_float(fixed32::kMantissaBits, x); #else return fast_sigmoid(vcvtq_n_f32_s32(x, fixed32::kMantissaBits)); #endif // FASTER_TRANSCENDENTALS } template inline float fast_sigmoid(fixed32 x) { return vgetq_lane_f32(fast_sigmoid(vdupq_n_s32(x.raw_val())), 0); } #else // defined __ARM_NEON || defined __aarch64__ inline float fast_exp(float x) { #ifdef FAST_TRANSCENDENTALS if (isnan(x)) return 0.0f; x = std::max(std::min(x, kMaxExpInput), kMinExpInput); float AConstant, BConstant; memcpy(&AConstant, &kAConstant, sizeof(int)); memcpy(&BConstant, &kBConstant, sizeof(int)); float y = x * AConstant + BConstant; int x_int = static_cast(y); float ret; memcpy(&ret, &x_int, sizeof(float)); return ret; #else return expf(x); #endif // FAST_TRANSCENDENTALS } template inline float fast_exp(fixed32 x) { return fast_exp(static_cast(x)); } template inline float fast_exp_fixed(fixed32 x) { static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits"); int matched_decimal = std::max(std::min(x.raw_val(), (80 << (31 - ExponentBits))), -(80 << (31 - ExponentBits))); // Convert 1 / log(2) to 16-bit fixed point with 1 exponent bit // (1 / log(2)) * (1 << 14), but then right shift by the appropriate amount to // line the decimal point up with the 32-bit float representation. // (MantissaBits of x) + (MantissaBits of constant) = 23 // 23 - (MantissaBits of x) = MantissaBits of constant // 23 - (31 - ExponentBits of x) = ... // (ExponentBits of x - 8) = MantissaBits of constant const int16_t A = (1.f / logf(2.f)) * (1 << (ExponentBits - 8)); // Same rationale as for floating point versions, bias exponent, subtract // 366000 to reduce error by centering approximation, instead of being // one-sided. 
const int B = (127 << 23) - 366000; matched_decimal = A * matched_decimal + B; float ret_val; memcpy(&ret_val, &matched_decimal, sizeof(float)); return ret_val; } inline float fast_tanh(float x) { #if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX // Doesn't do anything fancy, just a 13/6-degree rational interpolant which // is accurate up to a couple of ulp in the range [-9, 9], outside of which // fl(tanh(x)) = +/-1. x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput); // Since the polynomials are odd/even, we need x^2. float x2 = x * x; // Evaluate numerator. float p = kTanhAlpha11 + x2 * kTanhAlpha13; p = kTanhAlpha9 + x2 * p; p = kTanhAlpha7 + x2 * p; p = kTanhAlpha5 + x2 * p; p = kTanhAlpha3 + x2 * p; p = kTanhAlpha1 + x2 * p; p = x * p; // Evaluate denominator. float q = kTanhBeta4 + x2 * kTanhBeta6; q = kTanhBeta2 + x2 * q; q = kTanhBeta0 + x2 * q; return p / q; #elif defined FAST_TRANSCENDENTALS if (std::abs(x) < kTanhLinearRegion) { return x; } else { x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput); float positive = fast_exp(x); float negative = fast_exp(-x); return (positive - negative) / (positive + negative); } #else return tanhf(x); #endif // FAST_TRANSCENDENTALS } inline float fast_sigmoid(float x) { #ifdef SIGMOID_AS_TANH return .5f * fast_tanh(.5f * x) + .5f; #else #if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX // Doesn't do anything fancy, just a 9/10-degree rational interpolant which // interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulp in the range // [-18, 18], outside of which the fl(sigmoid(x)) = {0|1}. The shifted // sigmoid is interpolated because it was easier to make the fit converge. // See GenericPacketMath.h* in the open source Eigen library. x = std::max(std::min(x, kMaxSigmoidInput), kMinSigmoidInput); // Since the polynomials are odd/even, we need x^2. float x2 = x * x; // Evaluate numerator. 
float p = kSigmoidAlpha7 + x2 * kSigmoidAlpha9; p = kSigmoidAlpha5 + x2 * p; p = kSigmoidAlpha3 + x2 * p; p = kSigmoidAlpha1 + x2 * p; p = x * p; // Evaluate denominator. float q = kSigmoidBeta8 + x2 * kSigmoidBeta10; q = kSigmoidBeta6 + x2 * q; q = kSigmoidBeta4 + x2 * q; q = kSigmoidBeta2 + x2 * q; q = kSigmoidBeta0 + x2 * q; return p / q + 0.5f; #elif defined FAST_TRANSCENDENTALS if (std::abs(x) < kSigmoidLinearRegion) { return .245 * x + .5; } else { return 1.f / (1.f + fast_exp(-x)); } #else return 1.f / (1.f + expf(-x)); #endif // FAST_TRANSCENDENTALS #endif // SIGMOID_AS_TANH } template inline float fast_sigmoid(fixed32 x) { return fast_sigmoid(static_cast(x)); } #endif // defined __aarch64__ // Number of exponent bits to use for tanh. static constexpr int kNumTanhExpBits = 3; // Number of exponent bits to use for sigmoid. static constexpr int kNumSigmoidExpBits = 4; // Number of extra bits to shift sigmoid, due to its low gradient. static constexpr int kNumExtraSigmoidShiftBits = 1; // Returns (and builds if not done yet) a static data table (that is never // deleted, as per the style guide) that implements tanh on fixed32 input, // returning another fixed32 with the given number of mantissa bits (which is // assumed to be less than the input mantissa bits). // NOTE that this function is intended to be used only with fixed16 outputs that // are sign-extended to 32 bits for convenience, and will return a nullptr // if asked for more than |kMaxMantissaBits| of precision in the output table. const int* TanhTable(int num_mantissa_bits_out); // As TanhTable, but for Sigmoid. const int* SigmoidTable(int num_mantissa_bits_out); // Scalar/generic function to compute and return the fast approximation to exp // via a polynomial refinement of the floating point exponent trick. // TM_ORDER4_16BIT:Max relative error < 5e-6, absolute error < 1e-5 for x < 1. // TM_ORDER3_16BIT:Max relative error < 1.1e-4, absolute error < 3e-4 for x // < 1. 
template float fixed32_exp(fixed32 x) { constexpr int kMantissaBits = MantissaBitsOf>::value; // Clip x to min/max exp input to avoid infinities. int64_t clipped_x = std::max(std::min(x.raw_val(), kMaxExpInputInt << kMantissaBits), -(kMaxExpInputInt << kMantissaBits)); // First convert problem from e^x to 2^x by multiplying by 1/log(2). // To maximize precision, log_factor is shifted left the maximum amount to // keep within int32, and we shift x left a further amount such that the // binary point of the product sits in the correct place in the top 32 bits of // the result to be used directly as a float. We can't do that directly, as x // would overflow, so we have to shift by 1 bit less and shift the result by // 1 bit less to match. constexpr int kXShift = kFloatMantissaBits + 31 - kMaxLog2Shift - kMantissaBits; static_assert(kXShift >= 0, "Mantissa bits > kFloatMantissaBits + 31 - kMaxLog2Shift"); clipped_x <<= kXShift; int float_as_int = (kLogFactor * clipped_x >> 31) + kFloatExponentOffset; // Separate the resulting fixed-point into integer and fractional parts. int int_part = float_as_int & kFloatExponentMask; int float_part = float_as_int & kFloatMantissaMask; float fraction = static_cast(float_part) / (1 << kFloatMantissaBits); // Compute the mantissa = 2^fraction using: // fraction - fraction*(1-fraction)*(polynomial of fraction) // This guarantees exactness at 0 and 1, providing continuity of the error at // integer boundaries. 
float mantissa; if (kOrder == TM_ORDER4_16BIT || kOrder == TM_ORDER4_FLOAT) { mantissa = (kExpQuarticFactor2 * fraction + kExpQuarticFactor1) * fraction + kExpQuarticFactor0; } else if (kOrder == TM_ORDER3_16BIT) { mantissa = kExpCubicFactor1 * fraction + kExpCubicFactor0; } mantissa = fraction - fraction * (1.0f - fraction) * mantissa; // Since the function above guarantees to stay within [0, 1), we could do all // the above in fixed point if necessary, in which case, we can just stuff // the bottom kFloatMantissaBits in with the exponent and we are done. // In the floating point world, it is simpler to just multiply them together. float result; memcpy(&result, &int_part, sizeof(float)); return result * (1.0f + mantissa); } // Computes and returns tanh(x) fixed32->float using a polynomial refinement of // the floating point exponent trick. // kOrder=4: Absolute error < 1.8e-6. Relative error < 1.2e-4 for |x| > 0.01. // kOrder=3: Absolute error < 6e-5. Relative error < 3e-3 for |x| > 0.01 template float fixed32_tanh(fixed32 x) { float float_x = static_cast(x); if (std::abs(float_x) < 1.0f / 9.0f) { return float_x * (1 - float_x * float_x / 3.0f); } x = static_cast>(x.raw_val() * 2); float exp_2x = fixed32_exp(x); return (exp_2x - 1.0f) / (exp_2x + 1.0f); } // Computes and returns sigmoid(x) fixed32->float using a polynomial refinement // of the floating point exponent trick. // TM_ORDER4_16BIT: Absolute error < 9e-7, relative < 4e-6. // TM_ORDER3_16BIT: Absolute error < 3e-5, relative < 1.1e-4. template float fixed32_sigmoid(fixed32 x) { x = static_cast>(-x.raw_val()); float exp_x = fixed32_exp(x); return 1.0f / (exp_x + 1.0f); } #if defined __AVX2__ // Inline function to access an int32 data table by shifting |x| right by // |kNumShiftBits|, and adding |kTableOffset| to the result. |x| contains 8 // indices and 8 results are returned. The data table is of size // |kTableOffset| * 2 + 1. 
// Rounding is round-half-up: 1 << (kNumShiftBits - 1) is added before the
// arithmetic right shift, so kNumShiftBits must be at least 1. The resulting
// indices are clamped to [0, 2 * kTableOffset] before the AVX2 gather, so
// out-of-range inputs read the end entries of the table.
template inline __m256i index_data_table(const int32_t* data_table, const __m256i& x) {
  // Shift right with rounding to match input and output precision.
  __m256i shifted = _mm256_set1_epi32(1 << (kNumShiftBits - 1));
  shifted = _mm256_add_epi32(x, shifted);
  shifted = _mm256_srai_epi32(shifted, kNumShiftBits);
  // Add the offset, so that an input of 0 maps to the middle entry.
  __m256i addend = _mm256_set1_epi32(kTableOffset);
  shifted = _mm256_add_epi32(shifted, addend);
  // And clamp to the indices of the LUT, [0, 2 * kTableOffset]. Doubling the
  // offset register reuses it as the upper bound.
  addend = _mm256_add_epi32(addend, addend);
  shifted = _mm256_min_epi32(shifted, addend);
  shifted = _mm256_max_epi32(shifted, _mm256_setzero_si256());
  // Lookup the results in the table (gather scale 4 == sizeof(int32_t)).
  return _mm256_i32gather_epi32(data_table, shifted, 4);
}

// Fixed32 to fixed16-in-an-int32 tanh LUT function.
// Input: register containing 8x fixed32 with |NumInputMantissaBits|.
// Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but
// note that they are sign-extended to 32 bits and are therefore basically the
// same as fixed32 with |NumOutputMantissaBits|.
// Table indices step by 2^-NumOutputMantissaBits and the offset places index 0
// at input -2^kNumTanhExpBits, so |tanh_table| covers inputs in
// [-2^kNumTanhExpBits, 2^kNumTanhExpBits] with 2 * kTableOffset + 1 entries.
template inline __m256i fixed32_tanh_fixed16(const int* tanh_table, const __m256i& x) {
  // Lose the unnecessary input precision.
  constexpr int kNumShiftBits = NumInputMantissaBits - NumOutputMantissaBits;
  constexpr int kTableOffset = 1 << (NumOutputMantissaBits + kNumTanhExpBits);
  return index_data_table(tanh_table, x);
}

// Fixed32 to fixed16-in-an-int32 sigmoid LUT function.
// Input: register containing 8x fixed32 with |NumInputMantissaBits|.
// Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but
// note that they are sign-extended to 32 bits and are therefore basically the
// same as fixed32 with |NumOutputMantissaBits|.
// Compared with the tanh table, the index is shifted right by a further
// kNumExtraSigmoidShiftBits, so |sigmoid_table| is sampled at a coarser step
// of 2^-(NumOutputMantissaBits - kNumExtraSigmoidShiftBits). It still covers
// inputs in [-2^kNumSigmoidExpBits, 2^kNumSigmoidExpBits].
template inline __m256i fixed32_sigmoid_fixed16(const int* sigmoid_table, const __m256i& x) {
  // Lose the unnecessary input precision.
  constexpr int kNumShiftBits = kNumExtraSigmoidShiftBits + NumInputMantissaBits - NumOutputMantissaBits;
  constexpr int kTableOffset = 1 << (NumOutputMantissaBits + kNumSigmoidExpBits - kNumExtraSigmoidShiftBits);
  return index_data_table(sigmoid_table, x);
}

// Convert 2x registers of 8x float32 into 1 register of 16x16 bit fixed int,
// assuming that the floats are already scaled up.
// _mm256_cvtps_epi32 rounds using the current MXCSR mode (nearest-even by
// default), and _mm256_packs_epi32 saturates each value to int16.
inline __m256i PackFloatsToFixed16(const __m256& x0, const __m256& x1) {
  __m256i int0 = _mm256_cvtps_epi32(x0);
  __m256i int1 = _mm256_cvtps_epi32(x1);
  // packs works within each 128-bit lane, interleaving x0 and x1 halves.
  int0 = _mm256_packs_epi32(int0, int1);
  // Swap the middle 64 bit elements so the results are in the right order.
  return _mm256_permute4x64_epi64(int0, 0xd8);
}

// Clips integer input to [-|kLimit|, |kLimit|].
// Input: register containing 8x fixed32 with |mantissa_bits|.
// Output: register containing 8x fixed32 limited to
// [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|].
// kLimit << mantissa_bits must fit in an int32.
template inline __m256i ClipToBounds(const int mantissa_bits, const __m256i& x) {
  // Clip to the input bounds for this approximation.
  __m256i clip_limit = _mm256_set1_epi32(-(kLimit << mantissa_bits));
  __m256i clipped_x = _mm256_max_epi32(x, clip_limit);
  // This quickly negates the limit without having to load another constant:
  // _mm256_sign_epi32(v, v) negates every lane of v that is negative.
  clip_limit = _mm256_sign_epi32(clip_limit, clip_limit);
  return _mm256_min_epi32(clipped_x, clip_limit);
}

// Clips float input to [-|kLimit|, |kLimit|].
// Input: register containing 8x float.
// Output: register containing 8x float limited to [-|kLimit|, |kLimit|].
// If a lane of x is NaN, _mm256_min_ps returns its second operand, so that
// lane comes out as +kLimit.
inline __m256 ClipToFloatBounds(const float kLimit, const __m256& x) {
  __m256 clip_limit = _mm256_set1_ps(kLimit);
  __m256 clipped_x = _mm256_min_ps(x, clip_limit);
  clip_limit = _mm256_set1_ps(-kLimit);
  return _mm256_max_ps(clipped_x, clip_limit);
}

// Float to float power of 2 approximation, using a quartic refinement of
// the exponent trick.
For TM_ORDER4_16BIT and TM_ORDER3_16BIT, implementation // is entirely in integer, using 16x16=16 multiplication, using AVX2, which // enables 16 elements to be computed in parallel, hence the double register // input/output args. // The price paid for this speed is an increase in error over the (scalar) int32 // example implementations above by a variable factor of 4-10. // For the TM_ORDER4_FLOAT case, the computation is all done in float, solving // this lower precision problem. // NOTE: The input must have already been clipped to prevent overflow, which // sets the practical limit to +/-126 << kFloatMantissaBits. // NOTE: The input is a scaled float, as if converted raw from int, and the // scale factor is fixed at kFloatMantissaBits! // Input: 2x register containining 8x float * 1 << kFloatMantissaBits. // Output: 2x register containing 8x float. // TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1. // TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1. // TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1. template inline void float32_pow2(__m256& x0, __m256& x1) { // Convert straight to int. __m256i exp_int_x0 = _mm256_cvtps_epi32(x0); __m256i exp_int_x1 = _mm256_cvtps_epi32(x1); __m256i result_x0, result_x1; static_assert(kOrder == TM_ORDER4_FLOAT || kOrder == TM_ORDER4_16BIT || kOrder == TM_ORDER3_16BIT, "Invalid order."); if (kOrder == TM_ORDER4_FLOAT) { __m256i mantissa_mask = _mm256_set1_epi32(0x7fffff); __m256 float_factor = _mm256_set1_ps(1.0f / static_cast(1 << kFloatMantissaBits)); __m256i fract0 = _mm256_and_si256(mantissa_mask, exp_int_x0); __m256i fract1 = _mm256_and_si256(mantissa_mask, exp_int_x1); __m256 float0 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract0), float_factor); __m256 float1 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract1), float_factor); // Compute the polynomial of the fractional part. 
// Ordering these lines carefully makes it faster, as some of the multiply // operations can pipeline instead of waiting for the previous result. __m256 x_squared0 = _mm256_mul_ps(float0, float0); __m256 x_squared1 = _mm256_mul_ps(float1, float1); __m256 b = _mm256_set1_ps(kExpQuarticFactor1); __m256 b_x0 = _mm256_mul_ps(b, float0); __m256 b_x1 = _mm256_mul_ps(b, float1); __m256 a = _mm256_set1_ps(kExpQuarticFactor2); __m256 a_x_squared0 = _mm256_mul_ps(a, x_squared0); __m256 a_x_squared1 = _mm256_mul_ps(a, x_squared1); __m256 x_squared_minus_x0 = _mm256_sub_ps(x_squared0, float0); __m256 x_squared_minus_x1 = _mm256_sub_ps(x_squared1, float1); __m256 c = _mm256_set1_ps(kExpQuarticFactor0); b_x0 = _mm256_add_ps(b_x0, c); b_x1 = _mm256_add_ps(b_x1, c); float_factor = _mm256_set1_ps(static_cast(1 << kFloatMantissaBits)); a_x_squared0 = _mm256_add_ps(a_x_squared0, b_x0); a_x_squared1 = _mm256_add_ps(a_x_squared1, b_x1); a_x_squared0 = _mm256_mul_ps(a_x_squared0, x_squared_minus_x0); a_x_squared1 = _mm256_mul_ps(a_x_squared1, x_squared_minus_x1); result_x0 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared0, float_factor)); result_x1 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared1, float_factor)); } else { // Combine the fractional part of both inputs into a single register. // The representation is fixed16<0>, ie 15 mantissa bits. __m256i mantissa_mask = _mm256_set1_epi32(0x7fff00); __m256i x_01 = _mm256_srli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x0), 8); x_01 = _mm256_or_si256( x_01, _mm256_slli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x1), 8)); // Compute the polynomial of the fractional part. // Ordering these lines carefully makes it faster, as some of the multiply // operations can pipeline instead of waiting for the previous result. 
__m256i x_squared = _mm256_mulhrs_epi16(x_01, x_01); __m256i result, x_squared_minus_x; if (kOrder == TM_ORDER4_16BIT) { __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor1)); __m256i b_x = _mm256_mulhrs_epi16(b, x_01); __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor2)); __m256i a_x_squared = _mm256_mulhrs_epi16(a, x_squared); x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01); // LOG(INFO) << "x_squared_minus_x=" << // static_cast(_mm256_extract_epi16(x_squared_minus_x, 0)) / // 32768.0f; __m256i c = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor0)); b_x = _mm256_add_epi16(b_x, c); // LOG(INFO) << "bx+c=" << static_cast(_mm256_extract_epi16(b_x, // 0)) / 32768.0f; result = _mm256_add_epi16(a_x_squared, b_x); } else { // kOrder = TM_ORDER3_16BIT __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpCubicFactor1)); __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor0)); __m256i a_x = _mm256_mulhrs_epi16(a, x_01); x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01); result = _mm256_add_epi16(a_x, b); } result = _mm256_mulhrs_epi16(result, x_squared_minus_x); // Extract 16x16-bit results back to the separate sets of 8x32. result_x0 = _mm256_slli_epi32(result, 16); result_x0 = _mm256_srai_epi32(result_x0, 8); result_x1 = _mm256_srai_epi32(result, 16); result_x1 = _mm256_slli_epi32(result_x1, 8); } // Add the constant to normalize the exponent. __m256i exp_offset = _mm256_set1_epi32(kFloatExponentOffset); exp_int_x0 = _mm256_add_epi32(exp_int_x0, exp_offset); exp_int_x0 = _mm256_add_epi32(exp_int_x0, result_x0); exp_int_x1 = _mm256_add_epi32(exp_int_x1, exp_offset); exp_int_x1 = _mm256_add_epi32(exp_int_x1, result_x1); // Cast back to float, as we just computed the exponent and mantissa and // assembled them in IEEE format. x0 = _mm256_castsi256_ps(exp_int_x0); x1 = _mm256_castsi256_ps(exp_int_x1); } // Fixed32 to to float exp approximation, using a quartic/cubic refinement of // the exponent trick. 
// Implementation is entirely in integer, using 16x16=16
// multiplication, using AVX2, which enables 16 elements to be computed in
// parallel, hence the double register input/output args.
// The price paid for this speed is an increase in error over the (scalar) int32
// example implementations above by a variable factor of 4-10.
// The TM_ORDER4_FLOAT version uses floats and improves the precision.
// float_exp_float_preclipped:
//   Input: 2x registers containing 8x float, each a raw fixed-point value
//   converted to float (value * 2^kInputMantissaBits). The caller must
//   already have clipped them to keep float32_pow2 in range.
//   Output: y0/y1 are overwritten in place with 8x float32 e^value.
//   Requires kInputMantissaBits <= kFloatMantissaBits (see kXShift).
// fixed32_exp_float:
//   Input: 2x registers containing 8x fixed32 with kInputMantissaBits.
//   Output: 2x registers containing 8x float32.
// TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1.
// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1.
// TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1.
template inline void float_exp_float_preclipped(__m256& y0, __m256& y1) {
  // Divide by log 2 to convert problem to 2^x, and scale to match the
  // mantissa bits required by IEEE floats. Without a _mm256_mulhrs_epi32, it is
  // much easier to do this in float, even with the double conversion, as 16 bit
  // is not precise enough here.
  // This is the shift of the FP mantissa relative to the input mantissa.
  constexpr int kXShift = kFloatMantissaBits - kInputMantissaBits;
  // NOTE: this local kLogFactor (a float scale) shadows the kLogFactor used by
  // the scalar fixed32_exp.
  constexpr float kLogFactor = static_cast(1 << kXShift);
  __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2);
  y0 = _mm256_mul_ps(y0, factor);
  y1 = _mm256_mul_ps(y1, factor);
  // Now compute 2^x.
  float32_pow2(y0, y1);
}

template inline void fixed32_exp_float(const __m256i& x0, const __m256i& x1, __m256& y0, __m256& y1) {
  // Clip to acceptable bounds to prevent overflow, and convert to float.
  // NOTE(review): the ClipToBounds limit should keep |x| / ln 2 within
  // float32_pow2's +/-126 range; presumably kMaxExpInputInt. Confirm.
  y0 = _mm256_cvtepi32_ps(ClipToBounds(kInputMantissaBits, x0));
  y1 = _mm256_cvtepi32_ps(ClipToBounds(kInputMantissaBits, x1));
  float_exp_float_preclipped(y0, y1);
}

// Float->float tanh approximation via the exponent trick.
// Note that the input is scaled floats, as if converted raw from fixed16/32.
// Input: 2x registers containing 8x float scaled by input_mantissa_bits.
// Output: two registers containing 8x float.
// Inputs are clipped to +/-kMaxTanhInput. For |x| <= 1/9 the result comes from
// the Taylor series x - x^3/3 instead of the exponent trick.
// TM_ORDER4_FLOAT: Max relative error < 2.1e-5, absolute error < 2.3e-6.
// TM_ORDER4_16BIT: Max relative error < 1e-4, absolute error < 1.3e-5.
// TM_ORDER3_16BIT: Max relative error < 2.1e-3, absolute error < 3e-4.
template inline void float_tanh_float(const __m256& x0, const __m256& x1, __m256& y0, __m256& y1) {
  // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and
  // scale to the mantissa bits required by float32_pow2 all in one multiply.
  // This is the shift of the FP mantissa relative to the input mantissa.
  // Add one to double the input.
  const float kLogFactor = static_cast(1 << (kFloatMantissaBits - kInputMantissaBits + 1));
  __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2);
  // Clip to suitable input bounds for tanh (kMaxTanhInput in input units).
  __m256 clip_limit = _mm256_set1_ps(kMaxTanhInput * (1 << kInputMantissaBits));
  __m256 clip0 = _mm256_min_ps(x0, clip_limit);
  __m256 clip1 = _mm256_min_ps(x1, clip_limit);
  clip_limit = _mm256_set1_ps(-kMaxTanhInput * (1 << kInputMantissaBits));
  clip0 = _mm256_max_ps(clip0, clip_limit);
  clip1 = _mm256_max_ps(clip1, clip_limit);
  __m256 exp0 = _mm256_mul_ps(clip0, factor);
  __m256 exp1 = _mm256_mul_ps(clip1, factor);
  // Now compute 2^x (which equals e^2x for the clipped input).
  float32_pow2(exp0, exp1);
  // Now compute tanh using (e^2x - 1) / (e^2x + 1).
  __m256 one = _mm256_set1_ps(1.0f);
  __m256 numerator = _mm256_sub_ps(exp0, one);
  __m256 denominator = _mm256_add_ps(exp0, one);
  // Approximate reciprocal is not accurate enough - use full division.
  exp0 = _mm256_div_ps(numerator, denominator);
  numerator = _mm256_sub_ps(exp1, one);
  denominator = _mm256_add_ps(exp1, one);
  exp1 = _mm256_div_ps(numerator, denominator);
  // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low
  // relative error close to 0.
  // Normalize the inputs back to proper floats.
  factor = _mm256_set1_ps(1.0f / (1 << kInputMantissaBits));
  clip0 = _mm256_mul_ps(clip0, factor);
  clip1 = _mm256_mul_ps(clip1, factor);
  __m256 third = _mm256_set1_ps(-1.0f / 3.0f);
  __m256 taylor0 = _mm256_mul_ps(clip0, clip0);
  __m256 taylor1 = _mm256_mul_ps(clip1, clip1);
  taylor0 = _mm256_mul_ps(taylor0, clip0);
  taylor1 = _mm256_mul_ps(taylor1, clip1);
  // TODO(b/191497441): The next two pairs of instructions could be combined to
  // _mm256_fmadd_ps, but requires -mfma compilation option, eg:
  // taylor0 = _mm256_fmadd_ps(taylor0, third, clip0);
  taylor0 = _mm256_mul_ps(taylor0, third);
  taylor1 = _mm256_mul_ps(taylor1, third);
  taylor0 = _mm256_add_ps(clip0, taylor0);
  taylor1 = _mm256_add_ps(clip1, taylor1);
  // Test |x| <= 1/9, roughly where the errors cross over, without needing yet
  // another constant: (-1/3)^2 == 1/9.
  third = _mm256_mul_ps(third, third);
  // Clearing the sign bit (andnot with -0.0f) gives |x|.
  __m256 neg_zero = _mm256_set1_ps(-0.0f);
  clip0 = _mm256_andnot_ps(neg_zero, clip0);
  clip1 = _mm256_andnot_ps(neg_zero, clip1);
  __m256 cmp_results0 = _mm256_cmp_ps(clip0, third, _CMP_LE_OQ);
  __m256 cmp_results1 = _mm256_cmp_ps(clip1, third, _CMP_LE_OQ);
  // Select the Taylor result where |x| <= 1/9, else the exponent-trick result.
  y0 = _mm256_blendv_ps(exp0, taylor0, cmp_results0);
  y1 = _mm256_blendv_ps(exp1, taylor1, cmp_results1);
}

// Float->float sigmoid approximation via the AVX2 implementation of the
// exponent trick, using sigmoid(x) = 1 / (1 + e^-x).
// Input: 2x registers containing 8x float containing converted fixed32 scaled
// with kInputMantissaBits. They are overwritten in place with the result.
// Output: 2x registers containing 8x float.
// TM_ORDER4_FLOAT: Max relative error < 4e-6, absolute error < 1e-6.
// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 7e-6.
// TM_ORDER3_16BIT: Max relative error < 5.4e-4, absolute error < 1.4e-4.
template inline void float_sigmoid_float(__m256& y0, __m256& y1) {
  constexpr float kInputFactor = static_cast(1 << kInputMantissaBits);
  // Negate the inputs by flipping the sign bit (XOR with -0.0f).
  __m256 minus_zero = _mm256_set1_ps(-0.0f);
  y0 = _mm256_xor_ps(y0, minus_zero);
  y1 = _mm256_xor_ps(y1, minus_zero);
  // Clip to +/-kMaxSigmoidInput (in input units) so the exp cannot overflow.
  y0 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y0);
  y1 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y1);
  float_exp_float_preclipped(y0, y1);
  __m256 one = _mm256_set1_ps(1.0f);
  // Approximate reciprocal is not accurate enough - use full division.
  y0 = _mm256_div_ps(one, _mm256_add_ps(y0, one));
  y1 = _mm256_div_ps(one, _mm256_add_ps(y1, one));
}
#endif  // defined __AVX2__

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_