Xu Ma
update
1c3c0d9
raw
history blame
No virus
9.31 kB
/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Utilities for interacting with the opaque CUDA __half type
*/
#include <stdint.h>
#include <cuda_fp16.h>
#include <iosfwd>
#include <cub/util_type.cuh>
#ifdef __GNUC__
// There's a ton of type-punning going on in this file.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
/******************************************************************************
* half_t
******************************************************************************/
/**
* Host-based fp16 data type compatible and convertible with __half
*/
struct half_t
{
    // Raw 16-bit storage: 1 sign bit, 5 exponent bits, 10 mantissa bits.
    // Arithmetic and ordered comparisons are done by converting to float;
    // equality and inequality compare the raw bit patterns.
    uint16_t __x;

    /// Constructor from __half (copies the 16-bit pattern unchanged)
    __host__ __device__ __forceinline__
    half_t(const __half &other)
    {
        __x = reinterpret_cast<const uint16_t&>(other);
    }

    /// Constructor from integer (converts via float, then rounds to half)
    __host__ __device__ __forceinline__
    half_t(int a) : __x(half_t(float(a)).__x)
    {}

    /// Default constructor (positive zero)
    __host__ __device__ __forceinline__
    half_t() : __x(0)
    {}

    /// Constructor from float (round-to-nearest, ties-to-even)
    __host__ __device__ __forceinline__
    half_t(float a)
    {
        // Stolen from Norbert Juffa
        uint32_t ia = *reinterpret_cast<uint32_t*>(&a);
        uint16_t ir;

        // Start with just the sign bit moved into half position
        ir = (ia >> 16) & 0x8000;

        if ((ia & 0x7f800000) == 0x7f800000)
        {
            // Float exponent is all ones: infinity or NaN
            if ((ia & 0x7fffffff) == 0x7f800000)
            {
                ir |= 0x7c00; /* infinity */
            }
            else
            {
                ir = 0x7fff; /* canonical NaN */
            }
        }
        else if ((ia & 0x7f800000) >= 0x33000000)
        {
            // Biased float exponent >= 0x66, i.e. magnitude >= 2^-25.
            // Anything smaller skips this branch and stays a signed zero.
            int32_t shift = (int32_t) ((ia >> 23) & 0xff) - 127;
            if (shift > 15)
            {
                ir |= 0x7c00; /* infinity */
            }
            else
            {
                ia = (ia & 0x007fffff) | 0x00800000; /* extract mantissa */
                if (shift < -14)
                { /* denormal */
                    ir |= ia >> (-1 - shift);
                    ia = ia << (32 - (-1 - shift));
                }
                else
                { /* normal */
                    ir |= ia >> (24 - 11);
                    ia = ia << (32 - (24 - 11));
                    // The implicit leading one already landed in bit 10,
                    // contributing 1 to the exponent field; add the rest.
                    ir = ir + ((14 + shift) << 10);
                }
                /* IEEE-754 round to nearest, ties to even: ia now holds the
                   discarded bits left-aligned in 32 bits */
                if ((ia > 0x80000000) || ((ia == 0x80000000) && (ir & 1)))
                {
                    ir++;
                }
            }
        }

        this->__x = ir;
    }

    /// Cast to __half (reinterprets the 16-bit pattern)
    __host__ __device__ __forceinline__
    operator __half() const
    {
        return reinterpret_cast<const __half&>(__x);
    }

    /// Cast to float (exact: every half value is representable as a float)
    __host__ __device__ __forceinline__
    operator float() const
    {
        // Stolen from Andrew Kerr
        // Unsigned fields keep `sign << 31` well-defined (no shift into the
        // sign bit of a signed int).
        const uint32_t sign = (this->__x >> 15) & 1u;
        uint32_t exp        = (this->__x >> 10) & 0x1fu;
        uint32_t mantissa   = (this->__x & 0x3ffu);
        uint32_t f = 0;

        if (exp > 0 && exp < 31)
        {
            // normal: rebias exponent from 15 to 127
            exp += 112;
            f = (sign << 31) | (exp << 23) | (mantissa << 13);
        }
        else if (exp == 0)
        {
            if (mantissa)
            {
                // subnormal: normalize by shifting until the leading one
                // reaches bit 10, adjusting the exponent as we go
                exp += 113;
                while ((mantissa & (1u << 10)) == 0)
                {
                    mantissa <<= 1;
                    exp--;
                }
                mantissa &= 0x3ffu;
                f = (sign << 31) | (exp << 23) | (mantissa << 13);
            }
            else if (sign)
            {
                f = 0x80000000u; // negative zero
            }
            else
            {
                f = 0x0; // zero
            }
        }
        else if (exp == 31)
        {
            if (mantissa)
            {
                f = 0x7fffffffu; // not a number
            }
            else
            {
                f = (0xffu << 23) | (sign << 31); // inf
            }
        }

        return *reinterpret_cast<float const *>(&f);
    }

    /// Get raw storage
    __host__ __device__ __forceinline__
    uint16_t raw() const
    {
        return this->__x;
    }

    /// Equality (bitwise comparison of the raw storage)
    __host__ __device__ __forceinline__
    bool operator ==(const half_t &other) const
    {
        return (this->__x == other.__x);
    }

    /// Inequality (bitwise comparison of the raw storage)
    __host__ __device__ __forceinline__
    bool operator !=(const half_t &other) const
    {
        return (this->__x != other.__x);
    }

    /// Assignment by sum (computed in float, rounded back to half)
    __host__ __device__ __forceinline__
    half_t& operator +=(const half_t &rhs)
    {
        *this = half_t(float(*this) + float(rhs));
        return *this;
    }

    /// Multiply (computed in float, rounded back to half)
    __host__ __device__ __forceinline__
    half_t operator*(const half_t &other) const
    {
        return half_t(float(*this) * float(other));
    }

    /// Add (computed in float, rounded back to half)
    __host__ __device__ __forceinline__
    half_t operator+(const half_t &other) const
    {
        return half_t(float(*this) + float(other));
    }

    /// Less-than
    __host__ __device__ __forceinline__
    bool operator<(const half_t &other) const
    {
        return float(*this) < float(other);
    }

    /// Less-than-equal
    __host__ __device__ __forceinline__
    bool operator<=(const half_t &other) const
    {
        return float(*this) <= float(other);
    }

    /// Greater-than
    __host__ __device__ __forceinline__
    bool operator>(const half_t &other) const
    {
        return float(*this) > float(other);
    }

    /// Greater-than-equal
    __host__ __device__ __forceinline__
    bool operator>=(const half_t &other) const
    {
        return float(*this) >= float(other);
    }

    /// numeric_traits<half_t>::max (bit pattern 0x7BFF)
    __host__ __device__ __forceinline__
    static half_t max() {
        half_t result;
        result.__x = 0x7BFF;
        return result;
    }

    /// numeric_traits<half_t>::lowest (bit pattern 0xFBFF)
    __host__ __device__ __forceinline__
    static half_t lowest() {
        half_t result;
        result.__x = 0xFBFF;
        return result;
    }
};
/******************************************************************************
* I/O stream overloads
******************************************************************************/
/// Insert formatted \p half_t into the output stream
// Declared inline: this function is defined in a header, so without it every
// translation unit that includes the header emits its own definition and the
// link fails with duplicate symbols.
// NOTE(review): this header only includes <iosfwd>; the body needs the full
// std::ostream definition, so includers must also pull in <ostream>/<iostream>.
inline std::ostream& operator<<(std::ostream &out, const half_t &x)
{
    // Print the value as a float rather than the raw 16-bit pattern
    out << (float)x;
    return out;
}
/// Insert formatted \p __half into the output stream
// Declared inline so that including this header from several translation
// units does not produce multiple definitions at link time.
inline std::ostream& operator<<(std::ostream &out, const __half &x)
{
    // Wrap in half_t so the value is formatted via half_t's stream inserter
    return out << half_t(x);
}
/******************************************************************************
* Traits overloads
******************************************************************************/
/// FpLimits specialization for half_t; both limits forward to half_t's own
/// static accessors.
template <>
struct cub::FpLimits<half_t>
{
    /// Returns half_t::max()
    static __host__ __device__ __forceinline__ half_t Max()
    {
        return half_t::max();
    }

    /// Returns half_t::lowest()
    static __host__ __device__ __forceinline__ half_t Lowest()
    {
        return half_t::lowest();
    }
};
/// NumericTraits specialization classifying half_t as FLOATING_POINT.
/// NOTE(review): the meaning of the `true, false` flags and of the
/// `unsigned short` argument (presumably the same-width unsigned bits type)
/// is defined by cub::BaseTraits, which is not visible here - confirm there.
template <>
struct cub::NumericTraits<half_t>
    : cub::BaseTraits<FLOATING_POINT, true, false, unsigned short, half_t>
{};
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif