/****************************************************************************** | |
* Copyright (c) 2011, Duane Merrill. All rights reserved. | |
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. | |
* | |
* Redistribution and use in source and binary forms, with or without | |
* modification, are permitted provided that the following conditions are met: | |
* * Redistributions of source code must retain the above copyright | |
* notice, this list of conditions and the following disclaimer. | |
* * Redistributions in binary form must reproduce the above copyright | |
* notice, this list of conditions and the following disclaimer in the | |
* documentation and/or other materials provided with the distribution. | |
* * Neither the name of the NVIDIA CORPORATION nor the | |
* names of its contributors may be used to endorse or promote products | |
* derived from this software without specific prior written permission. | |
* | |
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | |
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | |
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | |
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | |
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | |
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
* | |
******************************************************************************/ | |
/** | |
* \file | |
* Callback operator types for supplying BlockScan prefixes | |
*/ | |
#pragma once | |
#include <iterator> | |
#include "../thread/thread_load.cuh" | |
#include "../thread/thread_store.cuh" | |
#include "../warp/warp_reduce.cuh" | |
#include "../config.cuh" | |
#include "../util_device.cuh" | |
/// Optional outer namespace(s) | |
CUB_NS_PREFIX | |
/// CUB namespace | |
namespace cub { | |
/****************************************************************************** | |
* Prefix functor type for maintaining a running prefix while scanning a | |
* region independent of other thread blocks | |
******************************************************************************/ | |
/**
 * Stateful callback functor for supplying BlockScan with a running prefix.
 * The prefix accumulated so far is combined with each new block-wide
 * aggregate, allowing consecutive BlockScan invocations to be chained.
 */
template <
    typename T,                 ///< BlockScan value type
    typename ScanOpT>           ///< Wrapped scan operator type
struct BlockScanRunningPrefixOp
{
    ScanOpT op;                 ///< Wrapped scan operator
    T       running_total;      ///< Running block-wide prefix

    /// Construct with the scan operator only (running_total left uninitialized)
    __device__ __forceinline__ BlockScanRunningPrefixOp(ScanOpT op)
    :
        op(op)
    {}

    /// Construct with an initial prefix and the scan operator
    __device__ __forceinline__ BlockScanRunningPrefixOp(
        T       starting_prefix,
        ScanOpT op)
    :
        op(op),
        running_total(starting_prefix)
    {}

    /**
     * Prefix callback operator.  Returns the running total prior to
     * incorporating \p block_aggregate (i.e., the exclusive prefix for the
     * current BlockScan), then folds the aggregate into the running total.
     */
    __device__ __forceinline__ T operator()(
        const T &block_aggregate)   ///< The aggregate sum of the BlockScan inputs
    {
        T prefix = running_total;
        running_total = op(running_total, block_aggregate);
        return prefix;
    }
};
/****************************************************************************** | |
* Generic tile status interface types for block-cooperative scans | |
******************************************************************************/ | |
/**
 * Enumerations of tile status (written into the per-tile status words
 * managed by the ScanTileState interfaces below)
 */
enum ScanTileStatus
{
    SCAN_TILE_OOB,          // Out-of-bounds (e.g., padding)
    SCAN_TILE_INVALID = 99, // Not yet processed (NOTE(review): 99 appears chosen to differ from zero-initialized storage -- confirm)
    SCAN_TILE_PARTIAL,      // Tile aggregate is available
    SCAN_TILE_INCLUSIVE,    // Inclusive tile prefix is available
};
/**
 * Tile status interface.
 *
 * Specialized on whether the (status, value) pair can be packed into a
 * single machine word (defaulting to true for primitive value types).
 */
template <
    typename T,
    bool SINGLE_WORD = Traits<T>::PRIMITIVE>
struct ScanTileState;
/**
 * Tile status interface specialized for scan status and value types
 * that can be combined into one machine word that can be
 * read/written coherently in a single access.
 */
template <typename T>
struct ScanTileState<T, true>
{
    // Status word type: an integer type the same size as T, so the
    // (status, value) pair packs into exactly 2 * sizeof(T) bytes
    typedef typename If<(sizeof(T) == 8),
        long long,
        typename If<(sizeof(T) == 4),
            int,
            typename If<(sizeof(T) == 2),
                short,
                char>::Type>::Type>::Type StatusWord;

    // Unit word type: a single machine word wide enough to hold an entire
    // TileDescriptor, so status and value are read/written in one access
    typedef typename If<(sizeof(T) == 8),
        longlong2,
        typename If<(sizeof(T) == 4),
            int2,
            typename If<(sizeof(T) == 2),
                int,
                uchar2>::Type>::Type>::Type TxnWord;

    // Device word type: the logical layout aliased onto a TxnWord
    struct TileDescriptor
    {
        StatusWord  status;     // One of ScanTileStatus
        T           value;      // Tile aggregate (partial) or inclusive prefix
    };

    // Constants
    enum
    {
        // Number of leading SCAN_TILE_OOB descriptors, letting a full warp
        // inspect a window of predecessor tiles without bounds checks
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };

    // Device storage: (num_tiles + TILE_STATUS_PADDING) descriptors
    TxnWord *d_tile_descriptors;

    /// Constructor
    __host__ __device__ __forceinline__
    ScanTileState()
    :
        d_tile_descriptors(NULL)
    {}

    /// Initializer.  Points this interface at caller-provided temporary
    /// storage (sized via AllocationSize()).
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     /*num_tiles*/,                      ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  /*temp_storage_bytes*/)             ///< [in] Size in bytes of \t d_temp_storage allocation
    {
        d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
        return cudaSuccess;
    }

    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \t d_temp_storage allocation
    {
        temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TileDescriptor);    // bytes needed for tile status descriptors
        return cudaSuccess;
    }

    /**
     * Initialize (from device).  One thread per tile (plus padding entries)
     * writes its descriptor's starting status.
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;

        // Build the descriptor in a register-local word, then store it whole
        TxnWord         val = TxnWord();
        TileDescriptor  *descriptor = reinterpret_cast<TileDescriptor*>(&val);

        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            descriptor->status = StatusWord(SCAN_TILE_INVALID);
            d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            descriptor->status = StatusWord(SCAN_TILE_OOB);
            d_tile_descriptors[threadIdx.x] = val;
        }
    }

    /**
     * Update the specified tile's inclusive value and corresponding status.
     * Status and value are published together in one word store (see class
     * comment), so no separate fence is used here.
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_INCLUSIVE;
        tile_descriptor.value = tile_inclusive;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;

        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Update the specified tile's partial value and corresponding status.
     * Status and value are published together in one word store (see class
     * comment), so no separate fence is used here.
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_PARTIAL;
        tile_descriptor.value = tile_partial;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;

        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Wait for the corresponding tile to become non-invalid.  Warp-synchronous:
     * all lanes keep spinning until every lane's descriptor is valid.
     */
    __device__ __forceinline__ void WaitForValid(
        int         tile_idx,
        StatusWord  &status,
        T           &value)
    {
        TileDescriptor tile_descriptor;
        do
        {
            __threadfence_block(); // prevent hoisting loads from loop
            TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

        } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

        status = tile_descriptor.status;
        value = tile_descriptor.value;
    }
};
/**
 * Tile status interface specialized for scan status and value types that
 * cannot be combined into one machine word.
 *
 * Status words, partial values, and inclusive values live in three separate
 * arrays, so explicit memory fences are needed to order each value store
 * before the status store that publishes it.
 */
template <typename T>
struct ScanTileState<T, false>
{
    // Status word type
    typedef char StatusWord;

    // Constants
    enum
    {
        // Number of leading SCAN_TILE_OOB entries, letting a full warp
        // inspect a window of predecessor tiles without bounds checks
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };

    // Device storage (three parallel arrays, each padded by TILE_STATUS_PADDING)
    StatusWord  *d_tile_status;     ///< Per-tile status words
    T           *d_tile_partial;    ///< Per-tile partial aggregates
    T           *d_tile_inclusive;  ///< Per-tile inclusive prefixes

    /// Constructor
    __host__ __device__ __forceinline__
    ScanTileState()
    :
        d_tile_status(NULL),
        d_tile_partial(NULL),
        d_tile_inclusive(NULL)
    {}

    /// Initializer.  Carves the three arrays out of the caller-provided blob
    /// (which must be at least AllocationSize() bytes).
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     num_tiles,                          ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  temp_storage_bytes)                 ///< [in] Size in bytes of \t d_temp_storage allocation
    {
        cudaError_t error = cudaSuccess;
        do
        {
            void*   allocations[3] = {};
            size_t  allocation_sizes[3];

            allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);           // bytes needed for tile status descriptors
            allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for partials
            allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for inclusives

            // Compute allocation pointers into the single storage blob
            if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break;

            // Alias the offsets
            d_tile_status       = reinterpret_cast<StatusWord*>(allocations[0]);
            d_tile_partial      = reinterpret_cast<T*>(allocations[1]);
            d_tile_inclusive    = reinterpret_cast<T*>(allocations[2]);
        }
        while (0);

        return error;
    }

    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \t d_temp_storage allocation
    {
        // Specify storage allocation requirements
        size_t allocation_sizes[3];
        allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);           // bytes needed for tile status descriptors
        allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for partials
        allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for inclusives

        // Set the necessary size of the blob
        void* allocations[3] = {};
        return CubDebug(AliasTemporaries(NULL, temp_storage_bytes, allocations, allocation_sizes));
    }

    /**
     * Initialize (from device).  One thread per tile (plus padding entries)
     * writes its starting status.
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            d_tile_status[TILE_STATUS_PADDING + tile_idx] = StatusWord(SCAN_TILE_INVALID);
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            d_tile_status[threadIdx.x] = StatusWord(SCAN_TILE_OOB);
        }
    }

    /**
     * Update the specified tile's inclusive value and corresponding status.
     * The fence orders the value store before the status store that
     * publishes it to peer thread blocks.
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive)
    {
        // Update tile inclusive value
        ThreadStore<STORE_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx, tile_inclusive);

        // Fence
        __threadfence();

        // Update tile status
        ThreadStore<STORE_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_INCLUSIVE));
    }

    /**
     * Update the specified tile's partial value and corresponding status.
     * The fence orders the value store before the status store that
     * publishes it to peer thread blocks.
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial)
    {
        // Update tile partial value
        ThreadStore<STORE_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx, tile_partial);

        // Fence
        __threadfence();

        // Update tile status
        ThreadStore<STORE_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_PARTIAL));
    }

    /**
     * Wait for the corresponding tile to become non-invalid, then fetch the
     * published value (partial or inclusive) according to the status.
     */
    __device__ __forceinline__ void WaitForValid(
        int         tile_idx,
        StatusWord  &status,
        T           &value)
    {
        do {
            status = ThreadLoad<LOAD_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx);

            __threadfence();    // prevent hoisting loads from loop or loads below above this one

        } while (status == SCAN_TILE_INVALID);

        if (status == StatusWord(SCAN_TILE_PARTIAL))
            value = ThreadLoad<LOAD_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx);
        else
            value = ThreadLoad<LOAD_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx);
    }
};
/****************************************************************************** | |
* ReduceByKey tile status interface types for block-cooperative scans | |
******************************************************************************/ | |
/**
 * Tile status interface for reduction by key.
 *
 * Specialized on whether the packed (key, value, status) descriptor fits in
 * a single machine word (see the SINGLE_WORD partial specializations below).
 */
template <
    typename ValueT,
    typename KeyT,
    bool SINGLE_WORD = (Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16)>
struct ReduceByKeyScanTileState;
/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * cannot be combined into one machine word.
 *
 * Simply inherits the generic multi-word ScanTileState, instantiated over
 * the KeyValuePair aggregate type.
 */
template <
    typename ValueT,
    typename KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, false> :
    ScanTileState<KeyValuePair<KeyT, ValueT> >
{
    typedef ScanTileState<KeyValuePair<KeyT, ValueT> > SuperClass;

    /// Constructor
    __host__ __device__ __forceinline__
    ReduceByKeyScanTileState() : SuperClass() {}
};
/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * can be combined into one machine word that can be read/written coherently in a single access.
 */
template <
    typename ValueT,
    typename KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, true>
{
    typedef KeyValuePair<KeyT, ValueT>KeyValuePairT;

    // Constants
    enum
    {
        // Size of the packed (key, value) payload
        PAIR_SIZE           = sizeof(ValueT) + sizeof(KeyT),

        // Payload size rounded up to a power-of-two transaction word, leaving
        // at least one byte over for the status (assumes Log2 rounds up, so
        // TXN_WORD_SIZE > PAIR_SIZE -- TODO confirm Log2 semantics)
        TXN_WORD_SIZE       = 1 << Log2<PAIR_SIZE + 1>::VALUE,

        // Bytes left in the transaction word for the status
        STATUS_WORD_SIZE    = TXN_WORD_SIZE - PAIR_SIZE,

        // Number of leading SCAN_TILE_OOB descriptors, letting a full warp
        // inspect a window of predecessor tiles without bounds checks
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };

    // Status word type
    typedef typename If<(STATUS_WORD_SIZE == 8),
        long long,
        typename If<(STATUS_WORD_SIZE == 4),
            int,
            typename If<(STATUS_WORD_SIZE == 2),
                short,
                char>::Type>::Type>::Type StatusWord;

    // Transaction word type (single coherent load/store unit)
    typedef typename If<(TXN_WORD_SIZE == 16),
        longlong2,
        typename If<(TXN_WORD_SIZE == 8),
            long long,
            int>::Type>::Type TxnWord;

    // Device word type (for when sizeof(ValueT) == sizeof(KeyT))
    struct TileDescriptorBigStatus
    {
        KeyT        key;
        ValueT      value;
        StatusWord  status;
    };

    // Device word type (for when sizeof(ValueT) != sizeof(KeyT))
    struct TileDescriptorLittleStatus
    {
        ValueT      value;
        StatusWord  status;
        KeyT        key;
    };

    // Device word type: layout chosen by relative key/value sizes
    typedef typename If<
            (sizeof(ValueT) == sizeof(KeyT)),
            TileDescriptorBigStatus,
            TileDescriptorLittleStatus>::Type
        TileDescriptor;

    // Device storage: (num_tiles + TILE_STATUS_PADDING) descriptors
    TxnWord *d_tile_descriptors;

    /// Constructor
    __host__ __device__ __forceinline__
    ReduceByKeyScanTileState()
    :
        d_tile_descriptors(NULL)
    {}

    /// Initializer.  Points this interface at caller-provided temporary
    /// storage (sized via AllocationSize()).
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     /*num_tiles*/,                      ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  /*temp_storage_bytes*/)             ///< [in] Size in bytes of \t d_temp_storage allocation
    {
        d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
        return cudaSuccess;
    }

    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \t d_temp_storage allocation
    {
        temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TileDescriptor);    // bytes needed for tile status descriptors
        return cudaSuccess;
    }

    /**
     * Initialize (from device).  One thread per tile (plus padding entries)
     * writes its descriptor's starting status.
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;

        // Build the descriptor in a register-local word, then store it whole
        TxnWord         val = TxnWord();
        TileDescriptor  *descriptor = reinterpret_cast<TileDescriptor*>(&val);

        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            descriptor->status = StatusWord(SCAN_TILE_INVALID);
            d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            descriptor->status = StatusWord(SCAN_TILE_OOB);
            d_tile_descriptors[threadIdx.x] = val;
        }
    }

    /**
     * Update the specified tile's inclusive value and corresponding status.
     * Status, key, and value are published together in one word store (see
     * class comment), so no separate fence is used here.
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, KeyValuePairT tile_inclusive)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status  = SCAN_TILE_INCLUSIVE;
        tile_descriptor.value   = tile_inclusive.value;
        tile_descriptor.key     = tile_inclusive.key;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;

        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Update the specified tile's partial value and corresponding status.
     * Status, key, and value are published together in one word store (see
     * class comment), so no separate fence is used here.
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, KeyValuePairT tile_partial)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status  = SCAN_TILE_PARTIAL;
        tile_descriptor.value   = tile_partial.value;
        tile_descriptor.key     = tile_partial.key;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;

        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Wait for the corresponding tile to become non-invalid.  Warp-synchronous:
     * all lanes keep spinning until every lane's descriptor is valid.
     */
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        KeyValuePairT   &value)
    {
        TileDescriptor tile_descriptor;
        do
        {
            __threadfence_block(); // prevent hoisting loads from loop
            TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

        } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

        status      = tile_descriptor.status;
        value.value = tile_descriptor.value;
        value.key   = tile_descriptor.key;
    }
};
/****************************************************************************** | |
* Prefix call-back operator for coupling local block scan within a | |
* block-cooperative scan | |
******************************************************************************/ | |
/**
 * Stateful block-scan prefix functor.  Provides the running prefix for
 * the current tile by using the call-back warp to wait on
 * aggregates/prefixes from predecessor tiles to become available.
 */
template <
    typename T,
    typename ScanOpT,
    typename ScanTileStateT,
    int      PTX_ARCH = CUB_PTX_ARCH>
struct TilePrefixCallbackOp
{
    // Parameterized warp reduce
    typedef WarpReduce<T, CUB_PTX_WARP_THREADS, PTX_ARCH> WarpReduceT;

    // Temporary storage type
    struct _TempStorage
    {
        typename WarpReduceT::TempStorage   warp_reduce;        // Scratch for warp-wide window reductions
        T                                   exclusive_prefix;   // Written by thread0 in operator()
        T                                   inclusive_prefix;   // Written by thread0 in operator()
        T                                   block_aggregate;    // Written by thread0 in operator()
    };

    // Alias wrapper allowing temporary storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};

    // Type of status word
    typedef typename ScanTileStateT::StatusWord StatusWord;

    // Fields
    _TempStorage&   temp_storage;       ///< Reference to temporary storage (warp reduction scratch + stashed prefixes)
    ScanTileStateT& tile_status;        ///< Interface to tile status
    ScanOpT         scan_op;            ///< Binary scan operator
    int             tile_idx;           ///< The current tile index
    T               exclusive_prefix;   ///< Exclusive prefix for the tile (set in operator())
    T               inclusive_prefix;   ///< Inclusive prefix for the tile (set in operator() by thread0)

    // Constructor
    __device__ __forceinline__
    TilePrefixCallbackOp(
        ScanTileStateT      &tile_status,
        TempStorage         &temp_storage,
        ScanOpT             scan_op,
        int                 tile_idx)
    :
        temp_storage(temp_storage.Alias()),
        tile_status(tile_status),
        scan_op(scan_op),
        tile_idx(tile_idx) {}

    // Block until all predecessors within the warp-wide window have non-invalid status,
    // then reduce their contributions into a single window aggregate
    __device__ __forceinline__
    void ProcessWindow(
        int         predecessor_idx,        ///< Preceding tile index to inspect
        StatusWord  &predecessor_status,    ///< [out] Preceding tile status
        T           &window_aggregate)      ///< [out] Relevant partial reduction from this window of preceding tiles
    {
        T value;
        tile_status.WaitForValid(predecessor_idx, predecessor_status, value);

        // Perform a segmented reduction to get the prefix for the current window.
        // Use the swizzled scan operator because we are now scanning *down* towards thread0.
        int tail_flag = (predecessor_status == StatusWord(SCAN_TILE_INCLUSIVE));
        window_aggregate = WarpReduceT(temp_storage.warp_reduce).TailSegmentedReduce(
            value,
            tail_flag,
            SwizzleScanOp<ScanOpT>(scan_op));
    }

    // BlockScan prefix callback functor (called by the first warp)
    __device__ __forceinline__
    T operator()(T block_aggregate)
    {
        // Update our status with our tile-aggregate
        if (threadIdx.x == 0)
        {
            temp_storage.block_aggregate = block_aggregate;
            tile_status.SetPartial(tile_idx, block_aggregate);
        }

        // Each lane of the warp inspects a different predecessor tile
        int         predecessor_idx = tile_idx - threadIdx.x - 1;
        StatusWord  predecessor_status;
        T           window_aggregate;

        // Wait for the warp-wide window of predecessor tiles to become valid
        ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);

        // The exclusive tile prefix starts out as the current window aggregate
        exclusive_prefix = window_aggregate;

        // Keep sliding the window back until we come across a tile whose inclusive prefix is known
        while (WARP_ALL((predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE)), 0xffffffff))
        {
            predecessor_idx -= CUB_PTX_WARP_THREADS;

            // Update exclusive tile prefix with the window prefix
            ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);
            exclusive_prefix = scan_op(window_aggregate, exclusive_prefix);
        }

        // Compute the inclusive tile prefix and update the status for this tile
        if (threadIdx.x == 0)
        {
            inclusive_prefix = scan_op(exclusive_prefix, block_aggregate);
            tile_status.SetInclusive(tile_idx, inclusive_prefix);

            // Stash the prefixes so other threads can fetch them via the getters below
            temp_storage.exclusive_prefix = exclusive_prefix;
            temp_storage.inclusive_prefix = inclusive_prefix;
        }

        // Return exclusive_prefix
        return exclusive_prefix;
    }

    // Get the exclusive prefix stored in temporary storage
    __device__ __forceinline__
    T GetExclusivePrefix()
    {
        return temp_storage.exclusive_prefix;
    }

    // Get the inclusive prefix stored in temporary storage
    __device__ __forceinline__
    T GetInclusivePrefix()
    {
        return temp_storage.inclusive_prefix;
    }

    // Get the block aggregate stored in temporary storage
    __device__ __forceinline__
    T GetBlockAggregate()
    {
        return temp_storage.block_aggregate;
    }
};
} // CUB namespace | |
CUB_NS_POSTFIX // Optional outer namespace(s) | |