Spaces:
Runtime error
Runtime error
| /****************************************************************************** | |
| * Copyright (c) 2011, Duane Merrill. All rights reserved. | |
| * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. | |
| * | |
| * Redistribution and use in source and binary forms, with or without | |
| * modification, are permitted provided that the following conditions are met: | |
| * * Redistributions of source code must retain the above copyright | |
| * notice, this list of conditions and the following disclaimer. | |
| * * Redistributions in binary form must reproduce the above copyright | |
| * notice, this list of conditions and the following disclaimer in the | |
| * documentation and/or other materials provided with the distribution. | |
| * * Neither the name of the NVIDIA CORPORATION nor the | |
| * names of its contributors may be used to endorse or promote products | |
| * derived from this software without specific prior written permission. | |
| * | |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |
| * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | |
| * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | |
| * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | |
| * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | |
| * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | |
| * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
| * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| * | |
| ******************************************************************************/ | |
| /** | |
| * \file | |
| * cub::AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide select. | |
| */ | |
| #pragma once | |
| #include <iterator> | |
| #include "single_pass_scan_operators.cuh" | |
| #include "../block/block_load.cuh" | |
| #include "../block/block_store.cuh" | |
| #include "../block/block_scan.cuh" | |
| #include "../block/block_exchange.cuh" | |
| #include "../block/block_discontinuity.cuh" | |
| #include "../config.cuh" | |
| #include "../grid/grid_queue.cuh" | |
| #include "../iterator/cache_modified_input_iterator.cuh" | |
| /// Optional outer namespace(s) | |
| CUB_NS_PREFIX | |
| /// CUB namespace | |
| namespace cub { | |
| /****************************************************************************** | |
| * Tuning policy types | |
| ******************************************************************************/ | |
| /** | |
| * Parameterizable tuning policy type for AgentSelectIf | |
| */ | |
| template < | |
| int _BLOCK_THREADS, ///< Threads per thread block | |
| int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) | |
| BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use | |
| CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements | |
| BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use | |
| struct AgentSelectIfPolicy | |
| { | |
| enum | |
| { | |
| BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block | |
| ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) | |
| }; | |
| static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use | |
| static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements | |
| static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use | |
| }; | |
| /****************************************************************************** | |
| * Thread block abstractions | |
| ******************************************************************************/ | |
| /** | |
| * \brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide selection | |
| * | |
| * Performs functor-based selection if SelectOpT functor type != NullType | |
| * Otherwise performs flag-based selection if FlagsInputIterator's value type != NullType | |
| * Otherwise performs discontinuity selection (keep unique) | |
| */ | |
| template < | |
| typename AgentSelectIfPolicyT, ///< Parameterized AgentSelectIfPolicy tuning policy type | |
| typename InputIteratorT, ///< Random-access input iterator type for selection items | |
| typename FlagsInputIteratorT, ///< Random-access input iterator type for selections (NullType* if a selection functor or discontinuity flagging is to be used for selection) | |
| typename SelectedOutputIteratorT, ///< Random-access input iterator type for selection_flags items | |
| typename SelectOpT, ///< Selection operator type (NullType if selections or discontinuity flagging is to be used for selection) | |
| typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selections is to be used for selection) | |
| typename OffsetT, ///< Signed integer type for global offsets | |
| bool KEEP_REJECTS> ///< Whether or not we push rejected items to the back of the output | |
| struct AgentSelectIf | |
| { | |
| //--------------------------------------------------------------------- | |
| // Types and constants | |
| //--------------------------------------------------------------------- | |
| // The input value type | |
| typedef typename std::iterator_traits<InputIteratorT>::value_type InputT; | |
| // The output value type | |
| typedef typename If<(Equals<typename std::iterator_traits<SelectedOutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? | |
| typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type, | |
| typename std::iterator_traits<SelectedOutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type | |
| // The flag value type | |
| typedef typename std::iterator_traits<FlagsInputIteratorT>::value_type FlagT; | |
| // Tile status descriptor interface type | |
| typedef ScanTileState<OffsetT> ScanTileStateT; | |
| // Constants | |
| enum | |
| { | |
| USE_SELECT_OP, | |
| USE_SELECT_FLAGS, | |
| USE_DISCONTINUITY, | |
| BLOCK_THREADS = AgentSelectIfPolicyT::BLOCK_THREADS, | |
| ITEMS_PER_THREAD = AgentSelectIfPolicyT::ITEMS_PER_THREAD, | |
| TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, | |
| TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1), | |
| SELECT_METHOD = (!Equals<SelectOpT, NullType>::VALUE) ? | |
| USE_SELECT_OP : | |
| (!Equals<FlagT, NullType>::VALUE) ? | |
| USE_SELECT_FLAGS : | |
| USE_DISCONTINUITY | |
| }; | |
| // Cache-modified Input iterator wrapper type (for applying cache modifier) for items | |
| typedef typename If<IsPointer<InputIteratorT>::VALUE, | |
| CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator | |
| InputIteratorT>::Type // Directly use the supplied input iterator type | |
| WrappedInputIteratorT; | |
| // Cache-modified Input iterator wrapper type (for applying cache modifier) for values | |
| typedef typename If<IsPointer<FlagsInputIteratorT>::VALUE, | |
| CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator | |
| FlagsInputIteratorT>::Type // Directly use the supplied input iterator type | |
| WrappedFlagsInputIteratorT; | |
| // Parameterized BlockLoad type for input data | |
| typedef BlockLoad< | |
| OutputT, | |
| BLOCK_THREADS, | |
| ITEMS_PER_THREAD, | |
| AgentSelectIfPolicyT::LOAD_ALGORITHM> | |
| BlockLoadT; | |
| // Parameterized BlockLoad type for flags | |
| typedef BlockLoad< | |
| FlagT, | |
| BLOCK_THREADS, | |
| ITEMS_PER_THREAD, | |
| AgentSelectIfPolicyT::LOAD_ALGORITHM> | |
| BlockLoadFlags; | |
| // Parameterized BlockDiscontinuity type for items | |
| typedef BlockDiscontinuity< | |
| OutputT, | |
| BLOCK_THREADS> | |
| BlockDiscontinuityT; | |
| // Parameterized BlockScan type | |
| typedef BlockScan< | |
| OffsetT, | |
| BLOCK_THREADS, | |
| AgentSelectIfPolicyT::SCAN_ALGORITHM> | |
| BlockScanT; | |
| // Callback type for obtaining tile prefix during block scan | |
| typedef TilePrefixCallbackOp< | |
| OffsetT, | |
| cub::Sum, | |
| ScanTileStateT> | |
| TilePrefixCallbackOpT; | |
| // Item exchange type | |
| typedef OutputT ItemExchangeT[TILE_ITEMS]; | |
| // Shared memory type for this thread block | |
| union _TempStorage | |
| { | |
| struct | |
| { | |
| typename BlockScanT::TempStorage scan; // Smem needed for tile scanning | |
| typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback | |
| typename BlockDiscontinuityT::TempStorage discontinuity; // Smem needed for discontinuity detection | |
| }; | |
| // Smem needed for loading items | |
| typename BlockLoadT::TempStorage load_items; | |
| // Smem needed for loading values | |
| typename BlockLoadFlags::TempStorage load_flags; | |
| // Smem needed for compacting items (allows non POD items in this union) | |
| Uninitialized<ItemExchangeT> raw_exchange; | |
| }; | |
| // Alias wrapper allowing storage to be unioned | |
| struct TempStorage : Uninitialized<_TempStorage> {}; | |
| //--------------------------------------------------------------------- | |
| // Per-thread fields | |
| //--------------------------------------------------------------------- | |
| _TempStorage& temp_storage; ///< Reference to temp_storage | |
| WrappedInputIteratorT d_in; ///< Input items | |
| SelectedOutputIteratorT d_selected_out; ///< Unique output items | |
| WrappedFlagsInputIteratorT d_flags_in; ///< Input selection flags (if applicable) | |
| InequalityWrapper<EqualityOpT> inequality_op; ///< T inequality operator | |
| SelectOpT select_op; ///< Selection operator | |
| OffsetT num_items; ///< Total number of input items | |
| //--------------------------------------------------------------------- | |
| // Constructor | |
| //--------------------------------------------------------------------- | |
| // Constructor | |
| __device__ __forceinline__ | |
| AgentSelectIf( | |
| TempStorage &temp_storage, ///< Reference to temp_storage | |
| InputIteratorT d_in, ///< Input data | |
| FlagsInputIteratorT d_flags_in, ///< Input selection flags (if applicable) | |
| SelectedOutputIteratorT d_selected_out, ///< Output data | |
| SelectOpT select_op, ///< Selection operator | |
| EqualityOpT equality_op, ///< Equality operator | |
| OffsetT num_items) ///< Total number of input items | |
| : | |
| temp_storage(temp_storage.Alias()), | |
| d_in(d_in), | |
| d_flags_in(d_flags_in), | |
| d_selected_out(d_selected_out), | |
| select_op(select_op), | |
| inequality_op(equality_op), | |
| num_items(num_items) | |
| {} | |
| //--------------------------------------------------------------------- | |
| // Utility methods for initializing the selections | |
| //--------------------------------------------------------------------- | |
| /** | |
| * Initialize selections (specialized for selection operator) | |
| */ | |
| template <bool IS_FIRST_TILE, bool IS_LAST_TILE> | |
| __device__ __forceinline__ void InitializeSelections( | |
| OffsetT /*tile_offset*/, | |
| OffsetT num_tile_items, | |
| OutputT (&items)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], | |
| Int2Type<USE_SELECT_OP> /*select_method*/) | |
| { | |
| #pragma unroll | |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
| { | |
| // Out-of-bounds items are selection_flags | |
| selection_flags[ITEM] = 1; | |
| if (!IS_LAST_TILE || (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM < num_tile_items)) | |
| selection_flags[ITEM] = select_op(items[ITEM]); | |
| } | |
| } | |
| /** | |
| * Initialize selections (specialized for valid flags) | |
| */ | |
| template <bool IS_FIRST_TILE, bool IS_LAST_TILE> | |
| __device__ __forceinline__ void InitializeSelections( | |
| OffsetT tile_offset, | |
| OffsetT num_tile_items, | |
| OutputT (&/*items*/)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], | |
| Int2Type<USE_SELECT_FLAGS> /*select_method*/) | |
| { | |
| CTA_SYNC(); | |
| FlagT flags[ITEMS_PER_THREAD]; | |
| if (IS_LAST_TILE) | |
| { | |
| // Out-of-bounds items are selection_flags | |
| BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags, num_tile_items, 1); | |
| } | |
| else | |
| { | |
| BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags); | |
| } | |
| // Convert flag type to selection_flags type | |
| #pragma unroll | |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
| { | |
| selection_flags[ITEM] = flags[ITEM]; | |
| } | |
| } | |
| /** | |
| * Initialize selections (specialized for discontinuity detection) | |
| */ | |
| template <bool IS_FIRST_TILE, bool IS_LAST_TILE> | |
| __device__ __forceinline__ void InitializeSelections( | |
| OffsetT tile_offset, | |
| OffsetT num_tile_items, | |
| OutputT (&items)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], | |
| Int2Type<USE_DISCONTINUITY> /*select_method*/) | |
| { | |
| if (IS_FIRST_TILE) | |
| { | |
| CTA_SYNC(); | |
| // Set head selection_flags. First tile sets the first flag for the first item | |
| BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op); | |
| } | |
| else | |
| { | |
| OutputT tile_predecessor; | |
| if (threadIdx.x == 0) | |
| tile_predecessor = d_in[tile_offset - 1]; | |
| CTA_SYNC(); | |
| BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op, tile_predecessor); | |
| } | |
| // Set selection flags for out-of-bounds items | |
| #pragma unroll | |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
| { | |
| // Set selection_flags for out-of-bounds items | |
| if ((IS_LAST_TILE) && (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM >= num_tile_items)) | |
| selection_flags[ITEM] = 1; | |
| } | |
| } | |
| //--------------------------------------------------------------------- | |
| // Scatter utility methods | |
| //--------------------------------------------------------------------- | |
| /** | |
| * Scatter flagged items to output offsets (specialized for direct scattering) | |
| */ | |
| template <bool IS_LAST_TILE, bool IS_FIRST_TILE> | |
| __device__ __forceinline__ void ScatterDirect( | |
| OutputT (&items)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_indices)[ITEMS_PER_THREAD], | |
| OffsetT num_selections) | |
| { | |
| // Scatter flagged items | |
| #pragma unroll | |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
| { | |
| if (selection_flags[ITEM]) | |
| { | |
| if ((!IS_LAST_TILE) || selection_indices[ITEM] < num_selections) | |
| { | |
| d_selected_out[selection_indices[ITEM]] = items[ITEM]; | |
| } | |
| } | |
| } | |
| } | |
| /** | |
| * Scatter flagged items to output offsets (specialized for two-phase scattering) | |
| */ | |
| template <bool IS_LAST_TILE, bool IS_FIRST_TILE> | |
| __device__ __forceinline__ void ScatterTwoPhase( | |
| OutputT (&items)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_indices)[ITEMS_PER_THREAD], | |
| int /*num_tile_items*/, ///< Number of valid items in this tile | |
| int num_tile_selections, ///< Number of selections in this tile | |
| OffsetT num_selections_prefix, ///< Total number of selections prior to this tile | |
| OffsetT /*num_rejected_prefix*/, ///< Total number of rejections prior to this tile | |
| Int2Type<false> /*is_keep_rejects*/) ///< Marker type indicating whether to keep rejected items in the second partition | |
| { | |
| CTA_SYNC(); | |
| // Compact and scatter items | |
| #pragma unroll | |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
| { | |
| int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix; | |
| if (selection_flags[ITEM]) | |
| { | |
| temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM]; | |
| } | |
| } | |
| CTA_SYNC(); | |
| for (int item = threadIdx.x; item < num_tile_selections; item += BLOCK_THREADS) | |
| { | |
| d_selected_out[num_selections_prefix + item] = temp_storage.raw_exchange.Alias()[item]; | |
| } | |
| } | |
| /** | |
| * Scatter flagged items to output offsets (specialized for two-phase scattering) | |
| */ | |
| template <bool IS_LAST_TILE, bool IS_FIRST_TILE> | |
| __device__ __forceinline__ void ScatterTwoPhase( | |
| OutputT (&items)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_indices)[ITEMS_PER_THREAD], | |
| int num_tile_items, ///< Number of valid items in this tile | |
| int num_tile_selections, ///< Number of selections in this tile | |
| OffsetT num_selections_prefix, ///< Total number of selections prior to this tile | |
| OffsetT num_rejected_prefix, ///< Total number of rejections prior to this tile | |
| Int2Type<true> /*is_keep_rejects*/) ///< Marker type indicating whether to keep rejected items in the second partition | |
| { | |
| CTA_SYNC(); | |
| int tile_num_rejections = num_tile_items - num_tile_selections; | |
| // Scatter items to shared memory (rejections first) | |
| #pragma unroll | |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
| { | |
| int item_idx = (threadIdx.x * ITEMS_PER_THREAD) + ITEM; | |
| int local_selection_idx = selection_indices[ITEM] - num_selections_prefix; | |
| int local_rejection_idx = item_idx - local_selection_idx; | |
| int local_scatter_offset = (selection_flags[ITEM]) ? | |
| tile_num_rejections + local_selection_idx : | |
| local_rejection_idx; | |
| temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM]; | |
| } | |
| CTA_SYNC(); | |
| // Gather items from shared memory and scatter to global | |
| #pragma unroll | |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) | |
| { | |
| int item_idx = (ITEM * BLOCK_THREADS) + threadIdx.x; | |
| int rejection_idx = item_idx; | |
| int selection_idx = item_idx - tile_num_rejections; | |
| OffsetT scatter_offset = (item_idx < tile_num_rejections) ? | |
| num_items - num_rejected_prefix - rejection_idx - 1 : | |
| num_selections_prefix + selection_idx; | |
| OutputT item = temp_storage.raw_exchange.Alias()[item_idx]; | |
| if (!IS_LAST_TILE || (item_idx < num_tile_items)) | |
| { | |
| d_selected_out[scatter_offset] = item; | |
| } | |
| } | |
| } | |
| /** | |
| * Scatter flagged items | |
| */ | |
| template <bool IS_LAST_TILE, bool IS_FIRST_TILE> | |
| __device__ __forceinline__ void Scatter( | |
| OutputT (&items)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], | |
| OffsetT (&selection_indices)[ITEMS_PER_THREAD], | |
| int num_tile_items, ///< Number of valid items in this tile | |
| int num_tile_selections, ///< Number of selections in this tile | |
| OffsetT num_selections_prefix, ///< Total number of selections prior to this tile | |
| OffsetT num_rejected_prefix, ///< Total number of rejections prior to this tile | |
| OffsetT num_selections) ///< Total number of selections including this tile | |
| { | |
| // Do a two-phase scatter if (a) keeping both partitions or (b) two-phase is enabled and the average number of selection_flags items per thread is greater than one | |
| if (KEEP_REJECTS || (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS))) | |
| { | |
| ScatterTwoPhase<IS_LAST_TILE, IS_FIRST_TILE>( | |
| items, | |
| selection_flags, | |
| selection_indices, | |
| num_tile_items, | |
| num_tile_selections, | |
| num_selections_prefix, | |
| num_rejected_prefix, | |
| Int2Type<KEEP_REJECTS>()); | |
| } | |
| else | |
| { | |
| ScatterDirect<IS_LAST_TILE, IS_FIRST_TILE>( | |
| items, | |
| selection_flags, | |
| selection_indices, | |
| num_selections); | |
| } | |
| } | |
| //--------------------------------------------------------------------- | |
| // Cooperatively scan a device-wide sequence of tiles with other CTAs | |
| //--------------------------------------------------------------------- | |
| /** | |
| * Process first tile of input (dynamic chained scan). Returns the running count of selections (including this tile) | |
| */ | |
| template <bool IS_LAST_TILE> | |
| __device__ __forceinline__ OffsetT ConsumeFirstTile( | |
| int num_tile_items, ///< Number of input items comprising this tile | |
| OffsetT tile_offset, ///< Tile offset | |
| ScanTileStateT& tile_state) ///< Global tile state descriptor | |
| { | |
| OutputT items[ITEMS_PER_THREAD]; | |
| OffsetT selection_flags[ITEMS_PER_THREAD]; | |
| OffsetT selection_indices[ITEMS_PER_THREAD]; | |
| // Load items | |
| if (IS_LAST_TILE) | |
| BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items); | |
| else | |
| BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items); | |
| // Initialize selection_flags | |
| InitializeSelections<true, IS_LAST_TILE>( | |
| tile_offset, | |
| num_tile_items, | |
| items, | |
| selection_flags, | |
| Int2Type<SELECT_METHOD>()); | |
| CTA_SYNC(); | |
| // Exclusive scan of selection_flags | |
| OffsetT num_tile_selections; | |
| BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, num_tile_selections); | |
| if (threadIdx.x == 0) | |
| { | |
| // Update tile status if this is not the last tile | |
| if (!IS_LAST_TILE) | |
| tile_state.SetInclusive(0, num_tile_selections); | |
| } | |
| // Discount any out-of-bounds selections | |
| if (IS_LAST_TILE) | |
| num_tile_selections -= (TILE_ITEMS - num_tile_items); | |
| // Scatter flagged items | |
| Scatter<IS_LAST_TILE, true>( | |
| items, | |
| selection_flags, | |
| selection_indices, | |
| num_tile_items, | |
| num_tile_selections, | |
| 0, | |
| 0, | |
| num_tile_selections); | |
| return num_tile_selections; | |
| } | |
| /** | |
| * Process subsequent tile of input (dynamic chained scan). Returns the running count of selections (including this tile) | |
| */ | |
| template <bool IS_LAST_TILE> | |
| __device__ __forceinline__ OffsetT ConsumeSubsequentTile( | |
| int num_tile_items, ///< Number of input items comprising this tile | |
| int tile_idx, ///< Tile index | |
| OffsetT tile_offset, ///< Tile offset | |
| ScanTileStateT& tile_state) ///< Global tile state descriptor | |
| { | |
| OutputT items[ITEMS_PER_THREAD]; | |
| OffsetT selection_flags[ITEMS_PER_THREAD]; | |
| OffsetT selection_indices[ITEMS_PER_THREAD]; | |
| // Load items | |
| if (IS_LAST_TILE) | |
| BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items); | |
| else | |
| BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items); | |
| // Initialize selection_flags | |
| InitializeSelections<false, IS_LAST_TILE>( | |
| tile_offset, | |
| num_tile_items, | |
| items, | |
| selection_flags, | |
| Int2Type<SELECT_METHOD>()); | |
| CTA_SYNC(); | |
| // Exclusive scan of values and selection_flags | |
| TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, cub::Sum(), tile_idx); | |
| BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, prefix_op); | |
| OffsetT num_tile_selections = prefix_op.GetBlockAggregate(); | |
| OffsetT num_selections = prefix_op.GetInclusivePrefix(); | |
| OffsetT num_selections_prefix = prefix_op.GetExclusivePrefix(); | |
| OffsetT num_rejected_prefix = (tile_idx * TILE_ITEMS) - num_selections_prefix; | |
| // Discount any out-of-bounds selections | |
| if (IS_LAST_TILE) | |
| { | |
| int num_discount = TILE_ITEMS - num_tile_items; | |
| num_selections -= num_discount; | |
| num_tile_selections -= num_discount; | |
| } | |
| // Scatter flagged items | |
| Scatter<IS_LAST_TILE, false>( | |
| items, | |
| selection_flags, | |
| selection_indices, | |
| num_tile_items, | |
| num_tile_selections, | |
| num_selections_prefix, | |
| num_rejected_prefix, | |
| num_selections); | |
| return num_selections; | |
| } | |
| /** | |
| * Process a tile of input | |
| */ | |
| template <bool IS_LAST_TILE> | |
| __device__ __forceinline__ OffsetT ConsumeTile( | |
| int num_tile_items, ///< Number of input items comprising this tile | |
| int tile_idx, ///< Tile index | |
| OffsetT tile_offset, ///< Tile offset | |
| ScanTileStateT& tile_state) ///< Global tile state descriptor | |
| { | |
| OffsetT num_selections; | |
| if (tile_idx == 0) | |
| { | |
| num_selections = ConsumeFirstTile<IS_LAST_TILE>(num_tile_items, tile_offset, tile_state); | |
| } | |
| else | |
| { | |
| num_selections = ConsumeSubsequentTile<IS_LAST_TILE>(num_tile_items, tile_idx, tile_offset, tile_state); | |
| } | |
| return num_selections; | |
| } | |
| /** | |
| * Scan tiles of items as part of a dynamic chained scan | |
| */ | |
| template <typename NumSelectedIteratorT> ///< Output iterator type for recording number of items selection_flags | |
| __device__ __forceinline__ void ConsumeRange( | |
| int num_tiles, ///< Total number of input tiles | |
| ScanTileStateT& tile_state, ///< Global tile state descriptor | |
| NumSelectedIteratorT d_num_selected_out) ///< Output total number selection_flags | |
| { | |
| // Blocks are launched in increasing order, so just assign one tile per block | |
| int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index | |
| OffsetT tile_offset = tile_idx * TILE_ITEMS; // Global offset for the current tile | |
| if (tile_idx < num_tiles - 1) | |
| { | |
| // Not the last tile (full) | |
| ConsumeTile<false>(TILE_ITEMS, tile_idx, tile_offset, tile_state); | |
| } | |
| else | |
| { | |
| // The last tile (possibly partially-full) | |
| OffsetT num_remaining = num_items - tile_offset; | |
| OffsetT num_selections = ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state); | |
| if (threadIdx.x == 0) | |
| { | |
| // Output the total number of items selection_flags | |
| *d_num_selected_out = num_selections; | |
| } | |
| } | |
| } | |
| }; | |
| } // CUB namespace | |
| CUB_NS_POSTFIX // Optional outer namespace(s) | |