LIVE / thrust /cub /agent /agent_segment_fixup.cuh
Xu Ma
update
1c3c0d9
raw
history blame
16.6 kB
/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file
* cub::AgentSegmentFixup implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key.
*/
#pragma once
#include <iterator>
#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_discontinuity.cuh"
#include "../config.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../iterator/constant_input_iterator.cuh"
/// Optional outer namespace(s)
CUB_NS_PREFIX
/// CUB namespace
namespace cub {
/******************************************************************************
* Tuning policy types
******************************************************************************/
/**
* Parameterizable tuning policy type for AgentSegmentFixup
*/
template <
int _BLOCK_THREADS, ///< Threads per thread block
int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input)
BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use
CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements
BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use
struct AgentSegmentFixupPolicy
{
enum
{
BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block
ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input)
};
static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use
static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements
static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use
};
/******************************************************************************
* Thread block abstractions
******************************************************************************/
/**
* \brief AgentSegmentFixup implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key
*/
template <
typename AgentSegmentFixupPolicyT, ///< Parameterized AgentSegmentFixupPolicy tuning policy type
typename PairsInputIteratorT, ///< Random-access input iterator type for keys
typename AggregatesOutputIteratorT, ///< Random-access output iterator type for values
typename EqualityOpT, ///< KeyT equality operator type
typename ReductionOpT, ///< ValueT reduction operator type
typename OffsetT> ///< Signed integer type for global offsets
struct AgentSegmentFixup
{
//---------------------------------------------------------------------
// Types and constants
//---------------------------------------------------------------------
// Data type of key-value input iterator
typedef typename std::iterator_traits<PairsInputIteratorT>::value_type KeyValuePairT;
// Value type
typedef typename KeyValuePairT::Value ValueT;
// Tile status descriptor interface type
typedef ReduceByKeyScanTileState<ValueT, OffsetT> ScanTileStateT;
// Constants
enum
{
BLOCK_THREADS = AgentSegmentFixupPolicyT::BLOCK_THREADS,
ITEMS_PER_THREAD = AgentSegmentFixupPolicyT::ITEMS_PER_THREAD,
TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD,
// Whether or not do fixup using RLE + global atomics
USE_ATOMIC_FIXUP = (CUB_PTX_ARCH >= 350) &&
(Equals<ValueT, float>::VALUE ||
Equals<ValueT, int>::VALUE ||
Equals<ValueT, unsigned int>::VALUE ||
Equals<ValueT, unsigned long long>::VALUE),
// Whether or not the scan operation has a zero-valued identity value (true if we're performing addition on a primitive type)
HAS_IDENTITY_ZERO = (Equals<ReductionOpT, cub::Sum>::VALUE) && (Traits<ValueT>::PRIMITIVE),
};
// Cache-modified Input iterator wrapper type (for applying cache modifier) for keys
typedef typename If<IsPointer<PairsInputIteratorT>::VALUE,
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator
PairsInputIteratorT>::Type // Directly use the supplied input iterator type
WrappedPairsInputIteratorT;
// Cache-modified Input iterator wrapper type (for applying cache modifier) for fixup values
typedef typename If<IsPointer<AggregatesOutputIteratorT>::VALUE,
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator
AggregatesOutputIteratorT>::Type // Directly use the supplied input iterator type
WrappedFixupInputIteratorT;
// Reduce-value-by-segment scan operator
typedef ReduceByKeyOp<cub::Sum> ReduceBySegmentOpT;
// Parameterized BlockLoad type for pairs
typedef BlockLoad<
KeyValuePairT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentSegmentFixupPolicyT::LOAD_ALGORITHM>
BlockLoadPairs;
// Parameterized BlockScan type
typedef BlockScan<
KeyValuePairT,
BLOCK_THREADS,
AgentSegmentFixupPolicyT::SCAN_ALGORITHM>
BlockScanT;
// Callback type for obtaining tile prefix during block scan
typedef TilePrefixCallbackOp<
KeyValuePairT,
ReduceBySegmentOpT,
ScanTileStateT>
TilePrefixCallbackOpT;
// Shared memory type for this thread block
union _TempStorage
{
struct
{
typename BlockScanT::TempStorage scan; // Smem needed for tile scanning
typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback
};
// Smem needed for loading keys
typename BlockLoadPairs::TempStorage load_pairs;
};
// Alias wrapper allowing storage to be unioned
struct TempStorage : Uninitialized<_TempStorage> {};
//---------------------------------------------------------------------
// Per-thread fields
//---------------------------------------------------------------------
_TempStorage& temp_storage; ///< Reference to temp_storage
WrappedPairsInputIteratorT d_pairs_in; ///< Input keys
AggregatesOutputIteratorT d_aggregates_out; ///< Output value aggregates
WrappedFixupInputIteratorT d_fixup_in; ///< Fixup input values
InequalityWrapper<EqualityOpT> inequality_op; ///< KeyT inequality operator
ReductionOpT reduction_op; ///< Reduction operator
ReduceBySegmentOpT scan_op; ///< Reduce-by-segment scan operator
//---------------------------------------------------------------------
// Constructor
//---------------------------------------------------------------------
// Constructor
__device__ __forceinline__
AgentSegmentFixup(
TempStorage& temp_storage, ///< Reference to temp_storage
PairsInputIteratorT d_pairs_in, ///< Input keys
AggregatesOutputIteratorT d_aggregates_out, ///< Output value aggregates
EqualityOpT equality_op, ///< KeyT equality operator
ReductionOpT reduction_op) ///< ValueT reduction operator
:
temp_storage(temp_storage.Alias()),
d_pairs_in(d_pairs_in),
d_aggregates_out(d_aggregates_out),
d_fixup_in(d_aggregates_out),
inequality_op(equality_op),
reduction_op(reduction_op),
scan_op(reduction_op)
{}
//---------------------------------------------------------------------
// Cooperatively scan a device-wide sequence of tiles with other CTAs
//---------------------------------------------------------------------
/**
* Process input tile. Specialized for atomic-fixup
*/
template <bool IS_LAST_TILE>
__device__ __forceinline__ void ConsumeTile(
OffsetT num_remaining, ///< Number of global input items remaining (including this tile)
int tile_idx, ///< Tile index
OffsetT tile_offset, ///< Tile offset
ScanTileStateT& tile_state, ///< Global tile state descriptor
Int2Type<true> use_atomic_fixup) ///< Marker whether to use atomicAdd (instead of reduce-by-key)
{
KeyValuePairT pairs[ITEMS_PER_THREAD];
// Load pairs
KeyValuePairT oob_pair;
oob_pair.key = -1;
if (IS_LAST_TILE)
BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs, num_remaining, oob_pair);
else
BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs);
// RLE
#pragma unroll
for (int ITEM = 1; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
ValueT* d_scatter = d_aggregates_out + pairs[ITEM - 1].key;
if (pairs[ITEM].key != pairs[ITEM - 1].key)
atomicAdd(d_scatter, pairs[ITEM - 1].value);
else
pairs[ITEM].value = reduction_op(pairs[ITEM - 1].value, pairs[ITEM].value);
}
// Flush last item if valid
ValueT* d_scatter = d_aggregates_out + pairs[ITEMS_PER_THREAD - 1].key;
if ((!IS_LAST_TILE) || (pairs[ITEMS_PER_THREAD - 1].key >= 0))
atomicAdd(d_scatter, pairs[ITEMS_PER_THREAD - 1].value);
}
/**
* Process input tile. Specialized for reduce-by-key fixup
*/
template <bool IS_LAST_TILE>
__device__ __forceinline__ void ConsumeTile(
OffsetT num_remaining, ///< Number of global input items remaining (including this tile)
int tile_idx, ///< Tile index
OffsetT tile_offset, ///< Tile offset
ScanTileStateT& tile_state, ///< Global tile state descriptor
Int2Type<false> use_atomic_fixup) ///< Marker whether to use atomicAdd (instead of reduce-by-key)
{
KeyValuePairT pairs[ITEMS_PER_THREAD];
KeyValuePairT scatter_pairs[ITEMS_PER_THREAD];
// Load pairs
KeyValuePairT oob_pair;
oob_pair.key = -1;
if (IS_LAST_TILE)
BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs, num_remaining, oob_pair);
else
BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs);
CTA_SYNC();
KeyValuePairT tile_aggregate;
if (tile_idx == 0)
{
// Exclusive scan of values and segment_flags
BlockScanT(temp_storage.scan).ExclusiveScan(pairs, scatter_pairs, scan_op, tile_aggregate);
// Update tile status if this is not the last tile
if (threadIdx.x == 0)
{
// Set first segment id to not trigger a flush (invalid from exclusive scan)
scatter_pairs[0].key = pairs[0].key;
if (!IS_LAST_TILE)
tile_state.SetInclusive(0, tile_aggregate);
}
}
else
{
// Exclusive scan of values and segment_flags
TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx);
BlockScanT(temp_storage.scan).ExclusiveScan(pairs, scatter_pairs, scan_op, prefix_op);
tile_aggregate = prefix_op.GetBlockAggregate();
}
// Scatter updated values
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
if (scatter_pairs[ITEM].key != pairs[ITEM].key)
{
// Update the value at the key location
ValueT value = d_fixup_in[scatter_pairs[ITEM].key];
value = reduction_op(value, scatter_pairs[ITEM].value);
d_aggregates_out[scatter_pairs[ITEM].key] = value;
}
}
// Finalize the last item
if (IS_LAST_TILE)
{
// Last thread will output final count and last item, if necessary
if (threadIdx.x == BLOCK_THREADS - 1)
{
// If the last tile is a whole tile, the inclusive prefix contains accumulated value reduction for the last segment
if (num_remaining == TILE_ITEMS)
{
// Update the value at the key location
OffsetT last_key = pairs[ITEMS_PER_THREAD - 1].key;
d_aggregates_out[last_key] = reduction_op(tile_aggregate.value, d_fixup_in[last_key]);
}
}
}
}
/**
* Scan tiles of items as part of a dynamic chained scan
*/
__device__ __forceinline__ void ConsumeRange(
int num_items, ///< Total number of input items
int num_tiles, ///< Total number of input tiles
ScanTileStateT& tile_state) ///< Global tile state descriptor
{
// Blocks are launched in increasing order, so just assign one tile per block
int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index
OffsetT tile_offset = tile_idx * TILE_ITEMS; // Global offset for the current tile
OffsetT num_remaining = num_items - tile_offset; // Remaining items (including this tile)
if (num_remaining > TILE_ITEMS)
{
// Not the last tile (full)
ConsumeTile<false>(num_remaining, tile_idx, tile_offset, tile_state, Int2Type<USE_ATOMIC_FIXUP>());
}
else if (num_remaining > 0)
{
// The last tile (possibly partially-full)
ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state, Int2Type<USE_ATOMIC_FIXUP>());
}
}
};
} // CUB namespace
CUB_NS_POSTFIX // Optional outer namespace(s)