/****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ #pragma once /** * \file * Utilities for interacting with the opaque CUDA __half type */ #include #include #include #include #ifdef __GNUC__ // There's a ton of type-punning going on in this file. #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif /****************************************************************************** * half_t ******************************************************************************/ /** * Host-based fp16 data type compatible and convertible with __half */ struct half_t { uint16_t __x; /// Constructor from __half __host__ __device__ __forceinline__ half_t(const __half &other) { __x = reinterpret_cast(other); } /// Constructor from integer __host__ __device__ __forceinline__ half_t(int a) { *this = half_t(float(a)); } /// Default constructor __host__ __device__ __forceinline__ half_t() : __x(0) {} /// Constructor from float __host__ __device__ __forceinline__ half_t(float a) { // Stolen from Norbert Juffa uint32_t ia = *reinterpret_cast(&a); uint16_t ir; ir = (ia >> 16) & 0x8000; if ((ia & 0x7f800000) == 0x7f800000) { if ((ia & 0x7fffffff) == 0x7f800000) { ir |= 0x7c00; /* infinity */ } else { ir = 0x7fff; /* canonical NaN */ } } else if ((ia & 0x7f800000) >= 0x33000000) { int32_t shift = (int32_t) ((ia >> 23) & 0xff) - 127; if (shift > 15) { ir |= 0x7c00; /* infinity */ } else { ia = (ia & 0x007fffff) | 0x00800000; /* extract mantissa */ if (shift < -14) { /* denormal */ ir |= ia >> (-1 - shift); ia = ia << (32 - (-1 - shift)); } else { /* normal */ ir |= ia >> (24 - 11); ia = ia << (32 - (24 - 11)); ir = ir + ((14 + shift) << 10); } /* IEEE-754 round to nearest of even */ if ((ia > 0x80000000) || ((ia == 0x80000000) && (ir & 1))) { ir++; } } } this->__x = ir; } /// Cast to __half __host__ __device__ __forceinline__ operator __half() const { return reinterpret_cast(__x); } /// Cast to float __host__ __device__ __forceinline__ operator float() const { // Stolen from Andrew Kerr int sign = ((this->__x >> 15) & 1); int exp = ((this->__x >> 10) & 0x1f); int mantissa = (this->__x & 0x3ff); uint32_t f = 0; if (exp > 0 && exp < 31) { // normal exp += 112; f = (sign << 31) | (exp << 23) | (mantissa << 13); } else if (exp == 0) { if (mantissa) { // subnormal exp += 113; while ((mantissa & (1 << 10)) == 0) { mantissa <<= 1; exp--; } mantissa &= 0x3ff; f = (sign << 31) | (exp << 23) | (mantissa << 13); } else if (sign) { f = 0x80000000; // negative zero } else { f = 0x0; // zero } } else if (exp == 31) { if (mantissa) { f = 0x7fffffff; // not a number } else { f = (0xff << 23) | (sign << 31); // inf } } return *reinterpret_cast(&f); } /// Get raw storage __host__ __device__ __forceinline__ uint16_t raw() { return this->__x; } /// Equality __host__ __device__ __forceinline__ bool operator ==(const half_t &other) { return (this->__x == other.__x); } /// Inequality __host__ __device__ __forceinline__ bool operator !=(const half_t &other) { return (this->__x != other.__x); } /// Assignment by sum __host__ __device__ __forceinline__ half_t& operator +=(const half_t &rhs) { *this = half_t(float(*this) + float(rhs)); return *this; } /// Multiply __host__ __device__ __forceinline__ half_t operator*(const half_t &other) { return half_t(float(*this) * float(other)); } /// Add __host__ __device__ __forceinline__ half_t operator+(const half_t &other) { return half_t(float(*this) + float(other)); } /// Less-than __host__ __device__ __forceinline__ bool operator<(const half_t &other) const { return float(*this) < float(other); } /// Less-than-equal __host__ __device__ __forceinline__ bool operator<=(const half_t &other) const { return float(*this) <= float(other); } /// Greater-than __host__ __device__ __forceinline__ bool operator>(const half_t &other) const { return float(*this) > float(other); } /// Greater-than-equal __host__ __device__ __forceinline__ bool operator>=(const half_t &other) const { return float(*this) >= float(other); } /// numeric_traits::max __host__ __device__ __forceinline__ static half_t max() { uint16_t max_word = 0x7BFF; return reinterpret_cast(max_word); } /// numeric_traits::lowest __host__ __device__ __forceinline__ static half_t lowest() { uint16_t lowest_word = 0xFBFF; return reinterpret_cast(lowest_word); } }; /****************************************************************************** * I/O stream overloads ******************************************************************************/ /// Insert formatted \p half_t into the output stream std::ostream& operator<<(std::ostream &out, const half_t &x) { out << (float)x; return out; } /// Insert formatted \p __half into the output stream std::ostream& operator<<(std::ostream &out, const __half &x) { return out << half_t(x); } /****************************************************************************** * Traits overloads ******************************************************************************/ template <> struct cub::FpLimits { static __host__ __device__ __forceinline__ half_t Max() { return half_t::max(); } static __host__ __device__ __forceinline__ half_t Lowest() { return half_t::lowest(); } }; template <> struct cub::NumericTraits : cub::BaseTraits {}; #ifdef __GNUC__ #pragma GCC diagnostic pop #endif