Spaces:
Runtime error
Runtime error
/****************************************************************************** | |
* Copyright (c) 2011, Duane Merrill. All rights reserved. | |
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. | |
* | |
* Redistribution and use in source and binary forms, with or without | |
* modification, are permitted provided that the following conditions are met: | |
* * Redistributions of source code must retain the above copyright | |
* notice, this list of conditions and the following disclaimer. | |
* * Redistributions in binary form must reproduce the above copyright | |
* notice, this list of conditions and the following disclaimer in the | |
* documentation and/or other materials provided with the distribution. | |
* * Neither the name of the NVIDIA CORPORATION nor the | |
* names of its contributors may be used to endorse or promote products | |
* derived from this software without specific prior written permission. | |
* | |
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | |
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | |
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | |
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | |
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | |
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
* | |
******************************************************************************/ | |
/** | |
* \file | |
* cub::AgentSpmv implements a stateful abstraction of CUDA thread blocks for participating in device-wide SpMV. | |
*/ | |
#pragma once | |
#include <iterator> | |
#include "../util_type.cuh" | |
#include "../block/block_reduce.cuh" | |
#include "../block/block_scan.cuh" | |
#include "../block/block_exchange.cuh" | |
#include "../config.cuh" | |
#include "../thread/thread_search.cuh" | |
#include "../thread/thread_operators.cuh" | |
#include "../iterator/cache_modified_input_iterator.cuh" | |
#include "../iterator/counting_input_iterator.cuh" | |
#include "../iterator/tex_ref_input_iterator.cuh" | |
/// Optional outer namespace(s) | |
CUB_NS_PREFIX | |
/// CUB namespace | |
namespace cub { | |
/****************************************************************************** | |
* Tuning policy | |
******************************************************************************/ | |
/** | |
* Parameterizable tuning policy type for AgentSpmv | |
*/ | |
template < | |
int _BLOCK_THREADS, ///< Threads per thread block | |
int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) | |
CacheLoadModifier _ROW_OFFSETS_SEARCH_LOAD_MODIFIER, ///< Cache load modifier for reading CSR row-offsets during search | |
CacheLoadModifier _ROW_OFFSETS_LOAD_MODIFIER, ///< Cache load modifier for reading CSR row-offsets | |
CacheLoadModifier _COLUMN_INDICES_LOAD_MODIFIER, ///< Cache load modifier for reading CSR column-indices | |
CacheLoadModifier _VALUES_LOAD_MODIFIER, ///< Cache load modifier for reading CSR values | |
CacheLoadModifier _VECTOR_VALUES_LOAD_MODIFIER, ///< Cache load modifier for reading vector values | |
bool _DIRECT_LOAD_NONZEROS, ///< Whether to load nonzeros directly from global during sequential merging (vs. pre-staged through shared memory) | |
BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use | |
struct AgentSpmvPolicy | |
{ | |
enum | |
{ | |
BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block | |
ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) | |
DIRECT_LOAD_NONZEROS = _DIRECT_LOAD_NONZEROS, ///< Whether to load nonzeros directly from global during sequential merging (pre-staged through shared memory) | |
}; | |
static const CacheLoadModifier ROW_OFFSETS_SEARCH_LOAD_MODIFIER = _ROW_OFFSETS_SEARCH_LOAD_MODIFIER; ///< Cache load modifier for reading CSR row-offsets | |
static const CacheLoadModifier ROW_OFFSETS_LOAD_MODIFIER = _ROW_OFFSETS_LOAD_MODIFIER; ///< Cache load modifier for reading CSR row-offsets | |
static const CacheLoadModifier COLUMN_INDICES_LOAD_MODIFIER = _COLUMN_INDICES_LOAD_MODIFIER; ///< Cache load modifier for reading CSR column-indices | |
static const CacheLoadModifier VALUES_LOAD_MODIFIER = _VALUES_LOAD_MODIFIER; ///< Cache load modifier for reading CSR values | |
static const CacheLoadModifier VECTOR_VALUES_LOAD_MODIFIER = _VECTOR_VALUES_LOAD_MODIFIER; ///< Cache load modifier for reading vector values | |
static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use | |
}; | |
/****************************************************************************** | |
* Thread block abstractions | |
******************************************************************************/ | |
template < | |
typename ValueT, ///< Matrix and vector value type | |
typename OffsetT> ///< Signed integer type for sequence offsets | |
struct SpmvParams | |
{ | |
ValueT* d_values; ///< Pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix <b>A</b>. | |
OffsetT* d_row_end_offsets; ///< Pointer to the array of \p m offsets demarcating the end of every row in \p d_column_indices and \p d_values | |
OffsetT* d_column_indices; ///< Pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix <b>A</b>. (Indices are zero-valued.) | |
ValueT* d_vector_x; ///< Pointer to the array of \p num_cols values corresponding to the dense input vector <em>x</em> | |
ValueT* d_vector_y; ///< Pointer to the array of \p num_rows values corresponding to the dense output vector <em>y</em> | |
int num_rows; ///< Number of rows of matrix <b>A</b>. | |
int num_cols; ///< Number of columns of matrix <b>A</b>. | |
int num_nonzeros; ///< Number of nonzero elements of matrix <b>A</b>. | |
ValueT alpha; ///< Alpha multiplicand | |
ValueT beta; ///< Beta addend-multiplicand | |
TexRefInputIterator<ValueT, 66778899, OffsetT> t_vector_x; | |
}; | |
/** | |
* \brief AgentSpmv implements a stateful abstraction of CUDA thread blocks for participating in device-wide SpMV. | |
*/ | |
template < | |
typename AgentSpmvPolicyT, ///< Parameterized AgentSpmvPolicy tuning policy type | |
typename ValueT, ///< Matrix and vector value type | |
typename OffsetT, ///< Signed integer type for sequence offsets | |
bool HAS_ALPHA, ///< Whether the input parameter \p alpha is 1 | |
bool HAS_BETA, ///< Whether the input parameter \p beta is 0 | |
int PTX_ARCH = CUB_PTX_ARCH> ///< PTX compute capability | |
struct AgentSpmv | |
{ | |
//--------------------------------------------------------------------- | |
// Types and constants | |
//--------------------------------------------------------------------- | |
/// Constants | |
enum | |
{ | |
BLOCK_THREADS = AgentSpmvPolicyT::BLOCK_THREADS, | |
ITEMS_PER_THREAD = AgentSpmvPolicyT::ITEMS_PER_THREAD, | |
TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, | |
}; | |
/// 2D merge path coordinate type | |
typedef typename CubVector<OffsetT, 2>::Type CoordinateT; | |
/// Input iterator wrapper types (for applying cache modifiers) | |
typedef CacheModifiedInputIterator< | |
AgentSpmvPolicyT::ROW_OFFSETS_SEARCH_LOAD_MODIFIER, | |
OffsetT, | |
OffsetT> | |
RowOffsetsSearchIteratorT; | |
typedef CacheModifiedInputIterator< | |
AgentSpmvPolicyT::ROW_OFFSETS_LOAD_MODIFIER, | |
OffsetT, | |
OffsetT> | |
RowOffsetsIteratorT; | |
typedef CacheModifiedInputIterator< | |
AgentSpmvPolicyT::COLUMN_INDICES_LOAD_MODIFIER, | |
OffsetT, | |
OffsetT> | |
ColumnIndicesIteratorT; | |
typedef CacheModifiedInputIterator< | |
AgentSpmvPolicyT::VALUES_LOAD_MODIFIER, | |
ValueT, | |
OffsetT> | |
ValueIteratorT; | |
typedef CacheModifiedInputIterator< | |
AgentSpmvPolicyT::VECTOR_VALUES_LOAD_MODIFIER, | |
ValueT, | |
OffsetT> | |
VectorValueIteratorT; | |
// Tuple type for scanning (pairs accumulated segment-value with segment-index) | |
typedef KeyValuePair<OffsetT, ValueT> KeyValuePairT; | |
// Reduce-value-by-segment scan operator | |
typedef ReduceByKeyOp<cub::Sum> ReduceBySegmentOpT; | |
// BlockReduce specialization | |
typedef BlockReduce< | |
ValueT, | |
BLOCK_THREADS, | |
BLOCK_REDUCE_WARP_REDUCTIONS> | |
BlockReduceT; | |
// BlockScan specialization | |
typedef BlockScan< | |
KeyValuePairT, | |
BLOCK_THREADS, | |
AgentSpmvPolicyT::SCAN_ALGORITHM> | |
BlockScanT; | |
// BlockScan specialization | |
typedef BlockScan< | |
ValueT, | |
BLOCK_THREADS, | |
AgentSpmvPolicyT::SCAN_ALGORITHM> | |
BlockPrefixSumT; | |
// BlockExchange specialization | |
typedef BlockExchange< | |
ValueT, | |
BLOCK_THREADS, | |
ITEMS_PER_THREAD> | |
BlockExchangeT; | |
/// Merge item type (either a non-zero value or a row-end offset) | |
union MergeItem | |
{ | |
// Value type to pair with index type OffsetT (NullType if loading values directly during merge) | |
typedef typename If<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>::Type MergeValueT; | |
OffsetT row_end_offset; | |
MergeValueT nonzero; | |
}; | |
/// Shared memory type required by this thread block | |
struct _TempStorage | |
{ | |
CoordinateT tile_coords[2]; | |
union Aliasable | |
{ | |
// Smem needed for tile of merge items | |
MergeItem merge_items[ITEMS_PER_THREAD + TILE_ITEMS + 1]; | |
// Smem needed for block exchange | |
typename BlockExchangeT::TempStorage exchange; | |
// Smem needed for block-wide reduction | |
typename BlockReduceT::TempStorage reduce; | |
// Smem needed for tile scanning | |
typename BlockScanT::TempStorage scan; | |
// Smem needed for tile prefix sum | |
typename BlockPrefixSumT::TempStorage prefix_sum; | |
} aliasable; | |
}; | |
/// Temporary storage type (unionable) | |
struct TempStorage : Uninitialized<_TempStorage> {}; | |
//--------------------------------------------------------------------- | |
// Per-thread fields | |
//--------------------------------------------------------------------- | |
_TempStorage& temp_storage; /// Reference to temp_storage | |
SpmvParams<ValueT, OffsetT>& spmv_params; | |
ValueIteratorT wd_values; ///< Wrapped pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix <b>A</b>. | |
RowOffsetsIteratorT wd_row_end_offsets; ///< Wrapped Pointer to the array of \p m offsets demarcating the end of every row in \p d_column_indices and \p d_values | |
ColumnIndicesIteratorT wd_column_indices; ///< Wrapped Pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix <b>A</b>. (Indices are zero-valued.) | |
VectorValueIteratorT wd_vector_x; ///< Wrapped Pointer to the array of \p num_cols values corresponding to the dense input vector <em>x</em> | |
VectorValueIteratorT wd_vector_y; ///< Wrapped Pointer to the array of \p num_cols values corresponding to the dense input vector <em>x</em> | |
//--------------------------------------------------------------------- | |
// Interface | |
//--------------------------------------------------------------------- | |
/** | |
* Constructor | |
*/ | |
__device__ __forceinline__ AgentSpmv( | |
TempStorage& temp_storage, ///< Reference to temp_storage | |
SpmvParams<ValueT, OffsetT>& spmv_params) ///< SpMV input parameter bundle | |
: | |
temp_storage(temp_storage.Alias()), | |
spmv_params(spmv_params), | |
wd_values(spmv_params.d_values), | |
wd_row_end_offsets(spmv_params.d_row_end_offsets), | |
wd_column_indices(spmv_params.d_column_indices), | |
wd_vector_x(spmv_params.d_vector_x), | |
wd_vector_y(spmv_params.d_vector_y) | |
{} | |
/** | |
* Consume a merge tile, specialized for direct-load of nonzeros | |
*/ | |
__device__ __forceinline__ KeyValuePairT ConsumeTile( | |
int tile_idx, | |
CoordinateT tile_start_coord, | |
CoordinateT tile_end_coord, | |
Int2Type<true> is_direct_load) ///< Marker type indicating whether to load nonzeros directly during path-discovery or beforehand in batch | |
{ | |
int tile_num_rows = tile_end_coord.x - tile_start_coord.x; | |
int tile_num_nonzeros = tile_end_coord.y - tile_start_coord.y; | |
OffsetT* s_tile_row_end_offsets = &temp_storage.aliasable.merge_items[0].row_end_offset; | |
// Gather the row end-offsets for the merge tile into shared memory | |
for (int item = threadIdx.x; item <= tile_num_rows; item += BLOCK_THREADS) | |
{ | |
s_tile_row_end_offsets[item] = wd_row_end_offsets[tile_start_coord.x + item]; | |
} | |
CTA_SYNC(); | |
// Search for the thread's starting coordinate within the merge tile | |
CountingInputIterator<OffsetT> tile_nonzero_indices(tile_start_coord.y); | |
CoordinateT thread_start_coord; | |
MergePathSearch( | |
OffsetT(threadIdx.x * ITEMS_PER_THREAD), // Diagonal | |
s_tile_row_end_offsets, // List A | |
tile_nonzero_indices, // List B | |
tile_num_rows, | |
tile_num_nonzeros, | |
thread_start_coord); | |
CTA_SYNC(); // Perf-sync | |
// Compute the thread's merge path segment | |
CoordinateT thread_current_coord = thread_start_coord; | |
KeyValuePairT scan_segment[ITEMS_PER_THREAD]; | |
ValueT running_total = 0.0; | |
#pragma unroll | |
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
{ | |
OffsetT nonzero_idx = CUB_MIN(tile_nonzero_indices[thread_current_coord.y], spmv_params.num_nonzeros - 1); | |
OffsetT column_idx = wd_column_indices[nonzero_idx]; | |
ValueT value = wd_values[nonzero_idx]; | |
ValueT vector_value = spmv_params.t_vector_x[column_idx]; | |
#if (CUB_PTX_ARCH >= 350) | |
vector_value = wd_vector_x[column_idx]; | |
#endif | |
ValueT nonzero = value * vector_value; | |
OffsetT row_end_offset = s_tile_row_end_offsets[thread_current_coord.x]; | |
if (tile_nonzero_indices[thread_current_coord.y] < row_end_offset) | |
{ | |
// Move down (accumulate) | |
running_total += nonzero; | |
scan_segment[ITEM].value = running_total; | |
scan_segment[ITEM].key = tile_num_rows; | |
++thread_current_coord.y; | |
} | |
else | |
{ | |
// Move right (reset) | |
scan_segment[ITEM].value = running_total; | |
scan_segment[ITEM].key = thread_current_coord.x; | |
running_total = 0.0; | |
++thread_current_coord.x; | |
} | |
} | |
CTA_SYNC(); | |
// Block-wide reduce-value-by-segment | |
KeyValuePairT tile_carry; | |
ReduceBySegmentOpT scan_op; | |
KeyValuePairT scan_item; | |
scan_item.value = running_total; | |
scan_item.key = thread_current_coord.x; | |
BlockScanT(temp_storage.aliasable.scan).ExclusiveScan(scan_item, scan_item, scan_op, tile_carry); | |
if (tile_num_rows > 0) | |
{ | |
if (threadIdx.x == 0) | |
scan_item.key = -1; | |
// Direct scatter | |
#pragma unroll | |
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
{ | |
if (scan_segment[ITEM].key < tile_num_rows) | |
{ | |
if (scan_item.key == scan_segment[ITEM].key) | |
scan_segment[ITEM].value = scan_item.value + scan_segment[ITEM].value; | |
if (HAS_ALPHA) | |
{ | |
scan_segment[ITEM].value *= spmv_params.alpha; | |
} | |
if (HAS_BETA) | |
{ | |
// Update the output vector element | |
ValueT addend = spmv_params.beta * wd_vector_y[tile_start_coord.x + scan_segment[ITEM].key]; | |
scan_segment[ITEM].value += addend; | |
} | |
// Set the output vector element | |
spmv_params.d_vector_y[tile_start_coord.x + scan_segment[ITEM].key] = scan_segment[ITEM].value; | |
} | |
} | |
} | |
// Return the tile's running carry-out | |
return tile_carry; | |
} | |
/** | |
* Consume a merge tile, specialized for indirect load of nonzeros | |
*/ | |
__device__ __forceinline__ KeyValuePairT ConsumeTile( | |
int tile_idx, | |
CoordinateT tile_start_coord, | |
CoordinateT tile_end_coord, | |
Int2Type<false> is_direct_load) ///< Marker type indicating whether to load nonzeros directly during path-discovery or beforehand in batch | |
{ | |
int tile_num_rows = tile_end_coord.x - tile_start_coord.x; | |
int tile_num_nonzeros = tile_end_coord.y - tile_start_coord.y; | |
#if (CUB_PTX_ARCH >= 520) | |
OffsetT* s_tile_row_end_offsets = &temp_storage.aliasable.merge_items[0].row_end_offset; | |
ValueT* s_tile_nonzeros = &temp_storage.aliasable.merge_items[tile_num_rows + ITEMS_PER_THREAD].nonzero; | |
// Gather the nonzeros for the merge tile into shared memory | |
#pragma unroll | |
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
{ | |
int nonzero_idx = threadIdx.x + (ITEM * BLOCK_THREADS); | |
ValueIteratorT a = wd_values + tile_start_coord.y + nonzero_idx; | |
ColumnIndicesIteratorT ci = wd_column_indices + tile_start_coord.y + nonzero_idx; | |
ValueT* s = s_tile_nonzeros + nonzero_idx; | |
if (nonzero_idx < tile_num_nonzeros) | |
{ | |
OffsetT column_idx = *ci; | |
ValueT value = *a; | |
ValueT vector_value = spmv_params.t_vector_x[column_idx]; | |
vector_value = wd_vector_x[column_idx]; | |
ValueT nonzero = value * vector_value; | |
*s = nonzero; | |
} | |
} | |
#else | |
OffsetT* s_tile_row_end_offsets = &temp_storage.aliasable.merge_items[0].row_end_offset; | |
ValueT* s_tile_nonzeros = &temp_storage.aliasable.merge_items[tile_num_rows + ITEMS_PER_THREAD].nonzero; | |
// Gather the nonzeros for the merge tile into shared memory | |
if (tile_num_nonzeros > 0) | |
{ | |
#pragma unroll | |
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
{ | |
int nonzero_idx = threadIdx.x + (ITEM * BLOCK_THREADS); | |
nonzero_idx = CUB_MIN(nonzero_idx, tile_num_nonzeros - 1); | |
OffsetT column_idx = wd_column_indices[tile_start_coord.y + nonzero_idx]; | |
ValueT value = wd_values[tile_start_coord.y + nonzero_idx]; | |
ValueT vector_value = spmv_params.t_vector_x[column_idx]; | |
#if (CUB_PTX_ARCH >= 350) | |
vector_value = wd_vector_x[column_idx]; | |
#endif | |
ValueT nonzero = value * vector_value; | |
s_tile_nonzeros[nonzero_idx] = nonzero; | |
} | |
} | |
#endif | |
// Gather the row end-offsets for the merge tile into shared memory | |
#pragma unroll 1 | |
for (int item = threadIdx.x; item <= tile_num_rows; item += BLOCK_THREADS) | |
{ | |
s_tile_row_end_offsets[item] = wd_row_end_offsets[tile_start_coord.x + item]; | |
} | |
CTA_SYNC(); | |
// Search for the thread's starting coordinate within the merge tile | |
CountingInputIterator<OffsetT> tile_nonzero_indices(tile_start_coord.y); | |
CoordinateT thread_start_coord; | |
MergePathSearch( | |
OffsetT(threadIdx.x * ITEMS_PER_THREAD), // Diagonal | |
s_tile_row_end_offsets, // List A | |
tile_nonzero_indices, // List B | |
tile_num_rows, | |
tile_num_nonzeros, | |
thread_start_coord); | |
CTA_SYNC(); // Perf-sync | |
// Compute the thread's merge path segment | |
CoordinateT thread_current_coord = thread_start_coord; | |
KeyValuePairT scan_segment[ITEMS_PER_THREAD]; | |
ValueT running_total = 0.0; | |
OffsetT row_end_offset = s_tile_row_end_offsets[thread_current_coord.x]; | |
ValueT nonzero = s_tile_nonzeros[thread_current_coord.y]; | |
#pragma unroll | |
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
{ | |
if (tile_nonzero_indices[thread_current_coord.y] < row_end_offset) | |
{ | |
// Move down (accumulate) | |
scan_segment[ITEM].value = nonzero; | |
running_total += nonzero; | |
++thread_current_coord.y; | |
nonzero = s_tile_nonzeros[thread_current_coord.y]; | |
} | |
else | |
{ | |
// Move right (reset) | |
scan_segment[ITEM].value = 0.0; | |
running_total = 0.0; | |
++thread_current_coord.x; | |
row_end_offset = s_tile_row_end_offsets[thread_current_coord.x]; | |
} | |
scan_segment[ITEM].key = thread_current_coord.x; | |
} | |
CTA_SYNC(); | |
// Block-wide reduce-value-by-segment | |
KeyValuePairT tile_carry; | |
ReduceBySegmentOpT scan_op; | |
KeyValuePairT scan_item; | |
scan_item.value = running_total; | |
scan_item.key = thread_current_coord.x; | |
BlockScanT(temp_storage.aliasable.scan).ExclusiveScan(scan_item, scan_item, scan_op, tile_carry); | |
if (threadIdx.x == 0) | |
{ | |
scan_item.key = thread_start_coord.x; | |
scan_item.value = 0.0; | |
} | |
if (tile_num_rows > 0) | |
{ | |
CTA_SYNC(); | |
// Scan downsweep and scatter | |
ValueT* s_partials = &temp_storage.aliasable.merge_items[0].nonzero; | |
if (scan_item.key != scan_segment[0].key) | |
{ | |
s_partials[scan_item.key] = scan_item.value; | |
} | |
else | |
{ | |
scan_segment[0].value += scan_item.value; | |
} | |
#pragma unroll | |
for (int ITEM = 1; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
{ | |
if (scan_segment[ITEM - 1].key != scan_segment[ITEM].key) | |
{ | |
s_partials[scan_segment[ITEM - 1].key] = scan_segment[ITEM - 1].value; | |
} | |
else | |
{ | |
scan_segment[ITEM].value += scan_segment[ITEM - 1].value; | |
} | |
} | |
CTA_SYNC(); | |
#pragma unroll 1 | |
for (int item = threadIdx.x; item < tile_num_rows; item += BLOCK_THREADS) | |
{ | |
spmv_params.d_vector_y[tile_start_coord.x + item] = s_partials[item]; | |
} | |
} | |
// Return the tile's running carry-out | |
return tile_carry; | |
} | |
/** | |
* Consume input tile | |
*/ | |
__device__ __forceinline__ void ConsumeTile( | |
CoordinateT* d_tile_coordinates, ///< [in] Pointer to the temporary array of tile starting coordinates | |
KeyValuePairT* d_tile_carry_pairs, ///< [out] Pointer to the temporary array carry-out dot product row-ids, one per block | |
int num_merge_tiles) ///< [in] Number of merge tiles | |
{ | |
int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index | |
if (tile_idx >= num_merge_tiles) | |
return; | |
// Read our starting coordinates | |
if (threadIdx.x < 2) | |
{ | |
if (d_tile_coordinates == NULL) | |
{ | |
// Search our starting coordinates | |
OffsetT diagonal = (tile_idx + threadIdx.x) * TILE_ITEMS; | |
CoordinateT tile_coord; | |
CountingInputIterator<OffsetT> nonzero_indices(0); | |
// Search the merge path | |
MergePathSearch( | |
diagonal, | |
RowOffsetsSearchIteratorT(spmv_params.d_row_end_offsets), | |
nonzero_indices, | |
spmv_params.num_rows, | |
spmv_params.num_nonzeros, | |
tile_coord); | |
temp_storage.tile_coords[threadIdx.x] = tile_coord; | |
} | |
else | |
{ | |
temp_storage.tile_coords[threadIdx.x] = d_tile_coordinates[tile_idx + threadIdx.x]; | |
} | |
} | |
CTA_SYNC(); | |
CoordinateT tile_start_coord = temp_storage.tile_coords[0]; | |
CoordinateT tile_end_coord = temp_storage.tile_coords[1]; | |
// Consume multi-segment tile | |
KeyValuePairT tile_carry = ConsumeTile( | |
tile_idx, | |
tile_start_coord, | |
tile_end_coord, | |
Int2Type<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS>()); | |
// Output the tile's carry-out | |
if (threadIdx.x == 0) | |
{ | |
if (HAS_ALPHA) | |
tile_carry.value *= spmv_params.alpha; | |
tile_carry.key += tile_start_coord.x; | |
d_tile_carry_pairs[tile_idx] = tile_carry; | |
} | |
} | |
}; | |
} // CUB namespace | |
CUB_NS_POSTFIX // Optional outer namespace(s) | |