// LIVE / thrust /cub /block /block_radix_rank.cuh
// Xu Ma
// update
// 1c3c0d9
// raw
// history blame
// 25.2 kB
/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file
* cub::BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block
*/
#pragma once
#include <stdint.h>
#include "../thread/thread_reduce.cuh"
#include "../thread/thread_scan.cuh"
#include "../block/block_scan.cuh"
#include "../config.cuh"
#include "../util_ptx.cuh"
#include "../util_type.cuh"
/// Optional outer namespace(s)
CUB_NS_PREFIX
/// CUB namespace
namespace cub {
/**
* \brief BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block.
* \ingroup BlockModule
*
* \tparam BLOCK_DIM_X The thread block length in threads along the X dimension
* \tparam RADIX_BITS The number of radix bits per digit place
* \tparam IS_DESCENDING Whether or not the sorted-order is high-to-low
* \tparam MEMOIZE_OUTER_SCAN <b>[optional]</b> Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: true for architectures SM35 and newer, false otherwise). See BlockScanAlgorithm::BLOCK_SCAN_RAKING_MEMOIZE for more details.
* \tparam INNER_SCAN_ALGORITHM <b>[optional]</b> The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS)
* \tparam SMEM_CONFIG <b>[optional]</b> Shared memory bank mode (default: \p cudaSharedMemBankSizeFourByte)
* \tparam BLOCK_DIM_Y <b>[optional]</b> The thread block length in threads along the Y dimension (default: 1)
* \tparam BLOCK_DIM_Z <b>[optional]</b> The thread block length in threads along the Z dimension (default: 1)
* \tparam PTX_ARCH <b>[optional]</b> \ptxversion
*
* \par Overview
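* For each key held by the calling threads, BlockRadixRank computes the key's
* local rank within the tile when the tile's keys are ordered by a single radix
* digit (the \p num_bits bits starting at \p current_bit), in ascending or
* descending digit order according to \p IS_DESCENDING.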
* - Keys must be in a form suitable for radix ranking (i.e., unsigned bits).
* - \blocked
*
* \par Performance Considerations
* - \granularity
*
* \par Examples
* \par
* - <b>Example 1:</b> Simple radix rank of 32-bit integer keys
* \code
* #include <cub/cub.cuh>
*
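* template <int BLOCK_THREADS>
* __global__ void ExampleKernel(...)
* {
*     // Specialize BlockRadixRank for a 1D block of BLOCK_THREADS threads
*     // ranking 4-bit digits in ascending order
*     typedef cub::BlockRadixRank<BLOCK_THREADS, 4, false> BlockRadixRankT;
*
*     // Allocate shared memory for BlockRadixRank
*     __shared__ typename BlockRadixRankT::TempStorage temp_storage;
*
*     // Obtain a segment of consecutive keys that are blocked across threads
*     unsigned int keys[2];
*     int ranks[2];
*     ...
*
*     // Rank the keys by the digit held in bits [0, 4)
*     BlockRadixRankT(temp_storage).RankKeys(keys, ranks, 0, 4);
* }
* \endcode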
*/
template <
    int                     BLOCK_DIM_X,
    int                     RADIX_BITS,
    bool                    IS_DESCENDING,
    bool                    MEMOIZE_OUTER_SCAN      = (CUB_PTX_ARCH >= 350) ? true : false,
    BlockScanAlgorithm      INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
    cudaSharedMemConfig     SMEM_CONFIG             = cudaSharedMemBankSizeFourByte,
    int                     BLOCK_DIM_Y             = 1,
    int                     BLOCK_DIM_Z             = 1,
    int                     PTX_ARCH                = CUB_PTX_ARCH>
class BlockRadixRank
{
private:

    /******************************************************************************
     * Type definitions and constants
     ******************************************************************************/

    // Integer type for a single digit counter.  Several of these are packed
    // side by side into each PackedCounter word.
    typedef unsigned short DigitCounter;

    // Integer type for packing DigitCounters into columns of shared memory banks:
    // a 64-bit word for eight-byte bank mode, a 32-bit word otherwise.
    typedef typename If<(SMEM_CONFIG == cudaSharedMemBankSizeEightByte),
        unsigned long long,
        unsigned int>::Type PackedCounter;

    enum
    {
        // The thread block size in threads
        BLOCK_THREADS           = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

        RADIX_DIGITS            = 1 << RADIX_BITS,

        LOG_WARP_THREADS        = CUB_LOG_WARP_THREADS(PTX_ARCH),
        WARP_THREADS            = 1 << LOG_WARP_THREADS,
        WARPS                   = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        BYTES_PER_COUNTER       = sizeof(DigitCounter),
        LOG_BYTES_PER_COUNTER   = Log2<BYTES_PER_COUNTER>::VALUE,

        // Number of DigitCounter fields packed into one PackedCounter
        PACKING_RATIO           = sizeof(PackedCounter) / sizeof(DigitCounter),
        LOG_PACKING_RATIO       = Log2<PACKING_RATIO>::VALUE,

        // Number of rows ("lanes") of packed counters needed to cover every digit.
        // Clamped so there is always at least one lane.
        LOG_COUNTER_LANES       = CUB_MAX((RADIX_BITS - LOG_PACKING_RATIO), 0),
        COUNTER_LANES           = 1 << LOG_COUNTER_LANES,

        // The number of packed counters per thread (plus one for padding)
        PADDED_COUNTER_LANES    = COUNTER_LANES + 1,
        RAKING_SEGMENT          = PADDED_COUNTER_LANES,
    };

public:

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
    };

private:

    /// BlockScan type used to scan each thread's raking partial
    typedef BlockScan<
            PackedCounter,
            BLOCK_DIM_X,
            INNER_SCAN_ALGORITHM,
            BLOCK_DIM_Y,
            BLOCK_DIM_Z,
            PTX_ARCH>
        BlockScan;

    /// Shared memory storage layout type for BlockRadixRank
    struct __align__(16) _TempStorage
    {
        union Aliasable
        {
            // Counting view: [lane][thread][packed sub-counter]
            DigitCounter            digit_counters[PADDED_COUNTER_LANES][BLOCK_THREADS][PACKING_RATIO];
            // Scanning view of the same bytes: one row of RAKING_SEGMENT packed counters per thread
            PackedCounter           raking_grid[BLOCK_THREADS][RAKING_SEGMENT];

        } aliasable;

        // Storage for scanning local ranks
        typename BlockScan::TempStorage block_scan;
    };

    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Shared storage reference
    _TempStorage &temp_storage;

    /// Linear thread-id
    unsigned int linear_tid;

    /// Copy of raking segment, promoted to registers
    PackedCounter cached_segment[RAKING_SEGMENT];

    /******************************************************************************
     * Utility methods
     ******************************************************************************/

    /**
     * Internal storage allocator (a statically-declared __shared__ instance)
     */
    __device__ __forceinline__ _TempStorage& PrivateStorage()
    {
        __shared__ _TempStorage private_storage;
        return private_storage;
    }

    /**
     * Performs upsweep raking reduction over this thread's raking row,
     * returning the row's aggregate.  When MEMOIZE_OUTER_SCAN is set the row
     * is first copied into cached_segment so the downsweep can reuse it.
     */
    __device__ __forceinline__ PackedCounter Upsweep()
    {
        PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];
        PackedCounter *raking_ptr;

        if (MEMOIZE_OUTER_SCAN)
        {
            // Copy data into registers
            #pragma unroll
            for (int i = 0; i < RAKING_SEGMENT; i++)
            {
                cached_segment[i] = smem_raking_ptr[i];
            }
            raking_ptr = cached_segment;
        }
        else
        {
            raking_ptr = smem_raking_ptr;
        }

        return internal::ThreadReduce<RAKING_SEGMENT>(raking_ptr, Sum());
    }

    /**
     * Performs exclusive downsweep raking scan of this thread's raking row,
     * seeded with \p raking_partial, and leaves the result in shared memory.
     */
    __device__ __forceinline__ void ExclusiveDownsweep(
        PackedCounter raking_partial)
    {
        PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];

        PackedCounter *raking_ptr = (MEMOIZE_OUTER_SCAN) ?
            cached_segment :
            smem_raking_ptr;

        // Exclusive raking downsweep scan
        internal::ThreadScanExclusive<RAKING_SEGMENT>(raking_ptr, raking_ptr, Sum(), raking_partial);

        if (MEMOIZE_OUTER_SCAN)
        {
            // Copy data back to smem
            #pragma unroll
            for (int i = 0; i < RAKING_SEGMENT; i++)
            {
                smem_raking_ptr[i] = cached_segment[i];
            }
        }
    }

    /**
     * Reset this thread's column of shared memory digit counters.  Each store
     * writes a whole PackedCounter, zeroing all PACKING_RATIO sub-counters of
     * one lane at once.
     */
    __device__ __forceinline__ void ResetCounters()
    {
        #pragma unroll
        for (int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++)
        {
            *((PackedCounter*) temp_storage.aliasable.digit_counters[LANE][linear_tid]) = 0;
        }
    }

    /**
     * Block-scan prefix callback.
     *
     * Given the block-wide aggregate of packed counters, returns a packed
     * prefix in which sub-field k holds the sum of aggregate sub-fields 0..k-1,
     * so counts of earlier sub-counter columns carry into later ones.
     */
    struct PrefixCallBack
    {
        __device__ __forceinline__ PackedCounter operator()(PackedCounter block_aggregate)
        {
            PackedCounter block_prefix = 0;

            // Propagate totals in packed fields
            #pragma unroll
            for (int PACKED = 1; PACKED < PACKING_RATIO; PACKED++)
            {
                block_prefix += block_aggregate << (sizeof(DigitCounter) * 8 * PACKED);
            }

            return block_prefix;
        }
    };

    /**
     * Exclusive-scan the shared memory digit counters across the whole block
     * (raking upsweep, block-wide scan, raking downsweep).
     */
    __device__ __forceinline__ void ScanCounters()
    {
        // Upsweep scan
        PackedCounter raking_partial = Upsweep();

        // Compute exclusive sum
        PackedCounter exclusive_partial;
        PrefixCallBack prefix_call_back;
        BlockScan(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back);

        // Downsweep scan with exclusive partial
        ExclusiveDownsweep(exclusive_partial);
    }

    /**
     * Map a digit value to the coordinates of the shared-memory counter that
     * tallies it.
     *
     * The low LOG_COUNTER_LANES bits pick the counter lane, the remaining high
     * bits pick the DigitCounter sub-field inside a PackedCounter.  For
     * descending order both coordinates are mirrored so larger digits are
     * scanned first.  Both RankKeys overloads use this single mapping so the
     * ranking pass and the digit-prefix lookup always agree.  Previously the
     * lookup mirrored the digit value before splitting it, which diverges when
     * RADIX_BITS < LOG_PACKING_RATIO.
     */
    __device__ __forceinline__ void DigitToCounter(
        unsigned int    digit,
        unsigned int    &counter_lane,
        unsigned int    &sub_counter)
    {
        sub_counter  = digit >> LOG_COUNTER_LANES;
        counter_lane = digit & (COUNTER_LANES - 1);

        if (IS_DESCENDING)
        {
            sub_counter  = PACKING_RATIO - 1 - sub_counter;
            counter_lane = COUNTER_LANES - 1 - counter_lane;
        }
    }

public:

    /// \smemstorage{BlockScan}
    struct TempStorage : Uninitialized<_TempStorage> {};

    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{

    /**
     * \brief Collective constructor using a private static allocation of shared memory as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRank()
    :
        temp_storage(PrivateStorage()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}

    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRank(
        TempStorage &temp_storage)             ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}

    //@}  end member group
    /******************************************************************//**
     * \name Raking
     *********************************************************************/
    //@{

    /**
     * \brief Rank keys.
     *
     * Contains CTA_SYNC() barriers, so every thread of the block must call it.
     * Counts are accumulated in DigitCounter (unsigned short) fields; tiles are
     * presumably expected to hold fewer keys than that type can count (TODO
     * confirm against callers).
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits)                           ///< [in] The number of bits in the current digit
    {
        DigitCounter    thread_prefixes[KEYS_PER_THREAD];   // For each key, the count of earlier keys held by this thread having the same digit
        DigitCounter*   digit_counters[KEYS_PER_THREAD];    // For each key, a pointer to its digit counter in smem

        // Reset shared memory digit counters
        ResetCounters();

        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // Get digit
            unsigned int digit = BFE(keys[ITEM], current_bit, num_bits);

            // Locate this digit's counter (lane + packed sub-field)
            unsigned int counter_lane, sub_counter;
            DigitToCounter(digit, counter_lane, sub_counter);

            // Pointer to smem digit counter
            digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane][linear_tid][sub_counter];

            // Load thread-exclusive prefix
            thread_prefixes[ITEM] = *digit_counters[ITEM];

            // Store inclusive prefix
            *digit_counters[ITEM] = thread_prefixes[ITEM] + 1;
        }

        CTA_SYNC();

        // Scan shared memory counters
        ScanCounters();

        CTA_SYNC();

        // Extract the local ranks of each key
        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // Add in thread block exclusive prefix
            ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM];
        }
    }

    /**
     * \brief Rank keys.  For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits,                           ///< [in] The number of bits in the current digit
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])  ///< [out] The exclusive prefix sum for the digits [(linear_tid * BINS_TRACKED_PER_THREAD) ... (linear_tid * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
    {
        // Rank keys
        RankKeys(keys, ranks, current_bit, num_bits);

        // Get the exclusive digit totals corresponding to the calling thread.
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                // Locate the counter exactly as RankKeys() did for this digit
                unsigned int counter_lane, sub_counter;
                DigitToCounter(bin_idx, counter_lane, sub_counter);

                // Thread 0's scanned counter holds the number of keys whose digit
                // orders before this one.  (Unfortunately these all reside in the
                // first counter column, resulting in unavoidable bank conflicts.)
                exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counters[counter_lane][0][sub_counter];
            }
        }
    }
};
/**
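* Radix ranking within a CUDA thread block that uses the warp-level match.any
* operation (via MatchAny) to find the lanes of a warp sharing the same digit.
* The PTX match.sync instruction is only available on SM70 and newer.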
*/
template <
    int                     BLOCK_DIM_X,
    int                     RADIX_BITS,
    bool                    IS_DESCENDING,
    BlockScanAlgorithm      INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
    int                     BLOCK_DIM_Y             = 1,
    int                     BLOCK_DIM_Z             = 1,
    int                     PTX_ARCH                = CUB_PTX_ARCH>
class BlockRadixRankMatch
{
private:

    /******************************************************************************
     * Type definitions and constants
     ******************************************************************************/

    typedef int32_t    RankT;
    typedef int32_t    DigitCounterT;

    enum
    {
        // The thread block size in threads
        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

        RADIX_DIGITS                = 1 << RADIX_BITS,

        LOG_WARP_THREADS            = CUB_LOG_WARP_THREADS(PTX_ARCH),
        WARP_THREADS                = 1 << LOG_WARP_THREADS,
        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        // Warp count rounded up to an odd number (row stride of warp_digit_counters)
        PADDED_WARPS                = ((WARPS & 0x1) == 0) ?
                                        WARPS + 1 :
                                        WARPS,

        // Total number of per-warp digit counters
        COUNTERS                    = PADDED_WARPS * RADIX_DIGITS,

        // Counters scanned by each thread, rounded up to an odd number
        RAKING_SEGMENT              = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS,
        PADDED_RAKING_SEGMENT       = ((RAKING_SEGMENT & 0x1) == 0) ?
                                        RAKING_SEGMENT + 1 :
                                        RAKING_SEGMENT,
    };

public:

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
    };

private:

    /// BlockScan type.  The second template argument is the block's X
    /// dimension (BLOCK_DIM_Y/BLOCK_DIM_Z are passed separately), as in
    /// BlockRadixRank.  Passing BLOCK_THREADS here over-counted the block
    /// size by a factor of BLOCK_DIM_Y * BLOCK_DIM_Z for 2D/3D blocks.
    typedef BlockScan<
            DigitCounterT,
            BLOCK_DIM_X,
            INNER_SCAN_ALGORITHM,
            BLOCK_DIM_Y,
            BLOCK_DIM_Z,
            PTX_ARCH>
        BlockScanT;

    /// Shared memory storage layout type for BlockRadixRankMatch
    struct __align__(16) _TempStorage
    {
        typename BlockScanT::TempStorage            block_scan;

        union __align__(16) Aliasable
        {
            // Counting view: one counter per (digit, warp)
            volatile DigitCounterT                  warp_digit_counters[RADIX_DIGITS][PADDED_WARPS];
            // Scanning view of the same bytes: PADDED_RAKING_SEGMENT counters per thread
            DigitCounterT                           raking_grid[BLOCK_THREADS][PADDED_RAKING_SEGMENT];

        } aliasable;
    };

    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Shared storage reference
    _TempStorage &temp_storage;

    /// Linear thread-id
    unsigned int linear_tid;

public:

    /// \smemstorage{BlockScan}
    struct TempStorage : Uninitialized<_TempStorage> {};

    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{

    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRankMatch(
        TempStorage &temp_storage)             ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}

    //@}  end member group
    /******************************************************************//**
     * \name Raking
     *********************************************************************/
    //@{

    /**
     * \brief Rank keys.
     *
     * Contains CTA_SYNC() barriers, so every thread of the block must call it.
     * The warp-level steps use a full 0xFFFFFFFF lane mask, so this presumably
     * requires every warp of the block to be fully populated (BLOCK_THREADS a
     * multiple of WARP_THREADS) — NOTE(review): confirm with MatchAny's contract.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits)                           ///< [in] The number of bits in the current digit
    {
        // Initialize shared digit counters (via the raking view, which covers them)
        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            temp_storage.aliasable.raking_grid[linear_tid][ITEM] = 0;

        CTA_SYNC();

        // Each warp will strip-mine its section of input, one strip at a time

        volatile DigitCounterT  *digit_counters[KEYS_PER_THREAD];
        uint32_t                warp_id         = linear_tid >> LOG_WARP_THREADS;
        uint32_t                lane_mask_lt    = LaneMaskLt();

        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // My digit
            uint32_t digit = BFE(keys[ITEM], current_bit, num_bits);

            if (IS_DESCENDING)
                digit = RADIX_DIGITS - digit - 1;

            // Mask of peers who have same digit as me
            uint32_t peer_mask = MatchAny<RADIX_BITS>(digit);

            // Pointer to smem digit counter for this key
            digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit][warp_id];

            // Number of occurrences in previous strips
            DigitCounterT warp_digit_prefix = *digit_counters[ITEM];

            // Warp-sync: every lane has read the counter before any lane updates it
            WARP_SYNC(0xFFFFFFFF);

            // Number of peers having same digit as me
            int32_t digit_count = __popc(peer_mask);

            // Number of lower-ranked peers having same digit seen so far
            int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt);

            if (peer_digit_prefix == 0)
            {
                // First thread for each digit updates the shared warp counter
                *digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count);
            }

            // Warp-sync: the update is visible before the next strip reads it
            WARP_SYNC(0xFFFFFFFF);

            // Number of prior keys having same digit
            ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix);
        }

        CTA_SYNC();

        // Scan warp counters (digit-major, then warp) across the whole block

        DigitCounterT scan_counters[PADDED_RAKING_SEGMENT];

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid][ITEM];

        BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters);

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            temp_storage.aliasable.raking_grid[linear_tid][ITEM] = scan_counters[ITEM];

        CTA_SYNC();

        // Seed ranks with counter values from previous warps
        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
            ranks[ITEM] += *digit_counters[ITEM];
    }

    /**
     * \brief Rank keys.  For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits,                           ///< [in] The number of bits in the current digit
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])  ///< [out] The exclusive prefix sum for the digits [(linear_tid * BINS_TRACKED_PER_THREAD) ... (linear_tid * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
    {
        RankKeys(keys, ranks, current_bit, num_bits);

        // Get exclusive count for each digit: warp 0's scanned counter for a
        // digit holds the number of keys whose digit orders before it.
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                if (IS_DESCENDING)
                    bin_idx = RADIX_DIGITS - bin_idx - 1;

                exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx][0];
            }
        }
    }
};
} // CUB namespace
CUB_NS_POSTFIX // Optional outer namespace(s)