Skip to content

Commit 4a0f7eb

Browse files
committed
Modify CUDA backend to use ::cub instead of ::thrust::cuda_cub::cub.
1 parent 26836e2 commit 4a0f7eb

File tree

15 files changed

+200
-155
lines changed

15 files changed

+200
-155
lines changed

.dependencies/cub

Submodule cub updated from 2b5c0cd to 464a90b

thrust/system/cuda/detail/adjacent_difference.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
#include <thrust/detail/cstdint.h>
3434
#include <thrust/detail/temporary_array.h>
3535
#include <thrust/system/cuda/detail/util.h>
36-
#include <thrust/system/cuda/detail/cub/device/device_select.cuh>
37-
#include <thrust/system/cuda/detail/cub/block/block_adjacent_difference.cuh>
36+
#include <cub/device/device_select.cuh>
37+
#include <cub/block/block_adjacent_difference.cuh>
3838
#include <thrust/system/cuda/detail/core/agent_launcher.h>
3939
#include <thrust/system/cuda/detail/par_to_seq.h>
4040
#include <thrust/functional.h>
@@ -100,7 +100,7 @@ namespace __adjacent_difference {
100100

101101
template<class Arch, class T>
102102
struct Tuning;
103-
103+
104104
template <class T>
105105
struct Tuning<sm30, T>
106106
{
@@ -520,7 +520,7 @@ adjacent_difference(execution_policy<Derived> &policy,
520520
}
521521

522522
return ret;
523-
}
523+
}
524524

525525
template <class Derived,
526526
class InputIt,

thrust/system/cuda/detail/async/reduce.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ auto async_reduce_n(
8282

8383
size_t tmp_size = 0;
8484
thrust::cuda_cub::throw_on_error(
85-
thrust::cuda_cub::cub::DeviceReduce::Reduce(
85+
cub::DeviceReduce::Reduce(
8686
nullptr
8787
, tmp_size
8888
, first
@@ -164,7 +164,7 @@ auto async_reduce_n(
164164
// Run reduction.
165165

166166
thrust::cuda_cub::throw_on_error(
167-
thrust::cuda_cub::cub::DeviceReduce::Reduce(
167+
cub::DeviceReduce::Reduce(
168168
tmp_ptr
169169
, tmp_size
170170
, first
@@ -237,7 +237,7 @@ auto async_reduce_into_n(
237237

238238
size_t tmp_size = 0;
239239
thrust::cuda_cub::throw_on_error(
240-
thrust::cuda_cub::cub::DeviceReduce::Reduce(
240+
cub::DeviceReduce::Reduce(
241241
nullptr
242242
, tmp_size
243243
, first
@@ -301,7 +301,7 @@ auto async_reduce_into_n(
301301
// Run reduction.
302302

303303
thrust::cuda_cub::throw_on_error(
304-
thrust::cuda_cub::cub::DeviceReduce::Reduce(
304+
cub::DeviceReduce::Reduce(
305305
tmp_ptr
306306
, tmp_size
307307
, first
@@ -350,5 +350,5 @@ THRUST_END_NS
350350

351351
#endif // THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
352352

353-
#endif
353+
#endif
354354

thrust/system/cuda/detail/async/sort.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ auto async_stable_sort_n(
8787

8888
auto const device_buffer_ptr = device_buffer.get();
8989

90-
// Synthesize a suitable new execution policy, because we don't want to
90+
// Synthesize a suitable new execution policy, because we don't want to
9191
// try and extract twice from the one we were passed.
9292
typename remove_cvref_t<decltype(policy)>::tag_type tag_policy{};
9393

@@ -294,15 +294,15 @@ typename std::enable_if<
294294
, cudaError_t
295295
>::type
296296
invoke_radix_sort(
297-
cudaStream_t stream
298-
, void* tmp_ptr
299-
, std::size_t& tmp_size
300-
, thrust::cuda_cub::cub::DoubleBuffer<T>& keys
301-
, Size& n
297+
cudaStream_t stream
298+
, void* tmp_ptr
299+
, std::size_t& tmp_size
300+
, cub::DoubleBuffer<T>& keys
301+
, Size& n
302302
, StrictWeakOrdering
303303
)
304304
{
305-
return thrust::cuda_cub::cub::DeviceRadixSort::SortKeys(
305+
return cub::DeviceRadixSort::SortKeys(
306306
tmp_ptr
307307
, tmp_size
308308
, keys
@@ -321,15 +321,15 @@ typename std::enable_if<
321321
, cudaError_t
322322
>::type
323323
invoke_radix_sort(
324-
cudaStream_t stream
325-
, void* tmp_ptr
326-
, std::size_t& tmp_size
327-
, thrust::cuda_cub::cub::DoubleBuffer<T>& keys
328-
, Size& n
324+
cudaStream_t stream
325+
, void* tmp_ptr
326+
, std::size_t& tmp_size
327+
, cub::DoubleBuffer<T>& keys
328+
, Size& n
329329
, StrictWeakOrdering
330330
)
331331
{
332-
return thrust::cuda_cub::cub::DeviceRadixSort::SortKeysDescending(
332+
return cub::DeviceRadixSort::SortKeysDescending(
333333
tmp_ptr
334334
, tmp_size
335335
, keys
@@ -372,7 +372,7 @@ auto async_stable_sort_n(
372372

373373
unique_eager_event e;
374374

375-
thrust::cuda_cub::cub::DoubleBuffer<T> keys(
375+
cub::DoubleBuffer<T> keys(
376376
raw_pointer_cast(&*first), nullptr
377377
);
378378

@@ -476,7 +476,7 @@ auto async_stable_sort_n(
476476
)>::value
477477
));
478478

479-
// Synthesize a suitable new execution policy, because we don't want to
479+
// Synthesize a suitable new execution policy, because we don't want to
480480
// try and extract twice from the one we were passed.
481481
typename remove_cvref_t<decltype(policy)>::tag_type tag_policy{};
482482

thrust/system/cuda/detail/copy_if.h

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include <thrust/detail/cstdint.h>
3434
#include <thrust/detail/temporary_array.h>
3535
#include <thrust/system/cuda/detail/util.h>
36-
#include <thrust/system/cuda/detail/cub/device/device_select.cuh>
36+
#include <cub/device/device_select.cuh>
3737
#include <thrust/system/cuda/detail/core/agent_launcher.h>
3838
#include <thrust/system/cuda/detail/core/util.h>
3939
#include <thrust/system/cuda/detail/par_to_seq.h>
@@ -89,7 +89,7 @@ namespace __copy_if {
8989

9090
template<class, class>
9191
struct Tuning;
92-
92+
9393
template<class T>
9494
struct Tuning<sm52, T>
9595
{
@@ -109,7 +109,7 @@ namespace __copy_if {
109109
cub::BLOCK_SCAN_WARP_SCANS>
110110
type;
111111
}; // Tuning<350>
112-
112+
113113

114114
template<class T>
115115
struct Tuning<sm35, T>
@@ -130,7 +130,7 @@ namespace __copy_if {
130130
cub::BLOCK_SCAN_WARP_SCANS>
131131
type;
132132
}; // Tuning<350>
133-
133+
134134
template<class T>
135135
struct Tuning<sm30, T>
136136
{
@@ -150,7 +150,7 @@ namespace __copy_if {
150150
cub::BLOCK_SCAN_WARP_SCANS>
151151
type;
152152
}; // Tuning<300>
153-
153+
154154
struct no_stencil_tag_ {};
155155
typedef no_stencil_tag_* no_stencil_tag;
156156
template <class ItemsIt,
@@ -206,7 +206,7 @@ namespace __copy_if {
206206
core::uninitialized_array<item_type, PtxPlan::ITEMS_PER_TILE> raw_exchange;
207207
}; // union TempStorage
208208
}; // struct PtxPlan
209-
209+
210210
typedef typename core::specialize_plan_msvc10_war<PtxPlan>::type::type ptx_plan;
211211

212212
typedef typename ptx_plan::ItemsLoadIt ItemsLoadIt;
@@ -224,7 +224,7 @@ namespace __copy_if {
224224
ITEMS_PER_THREAD = ptx_plan::ITEMS_PER_THREAD,
225225
ITEMS_PER_TILE = ptx_plan::ITEMS_PER_TILE
226226
};
227-
227+
228228
struct impl
229229
{
230230
//---------------------------------------------------------------------
@@ -238,7 +238,7 @@ namespace __copy_if {
238238
OutputIt output_it;
239239
Predicate predicate;
240240
Size num_items;
241-
241+
242242
//------------------------------------------
243243
// scatter results to memory
244244
//------------------------------------------
@@ -272,7 +272,7 @@ namespace __copy_if {
272272
output_it[num_selections_prefix + item] = storage.raw_exchange[item];
273273
}
274274
} // func scatter
275-
275+
276276
//------------------------------------------
277277
// specialize predicate on different types
278278
//------------------------------------------
@@ -357,11 +357,11 @@ namespace __copy_if {
357357
}
358358
}
359359
}
360-
360+
361361
//------------------------------------------
362362
// consume tiles
363363
//------------------------------------------
364-
364+
365365
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
366366
Size THRUST_DEVICE_FUNCTION
367367
consume_tile_impl(int num_tile_items,
@@ -501,7 +501,7 @@ namespace __copy_if {
501501
//---------------------------------------------------------------------
502502
// Constructor
503503
//---------------------------------------------------------------------
504-
504+
505505
THRUST_DEVICE_FUNCTION impl(TempStorage & storage_,
506506
ScanTileState & tile_state_,
507507
ItemsIt items_it,
@@ -578,7 +578,7 @@ namespace __copy_if {
578578
template <class Arch>
579579
struct PtxPlan : PtxPolicy<128> {};
580580
typedef core::specialize_plan<PtxPlan> ptx_plan;
581-
581+
582582
//---------------------------------------------------------------------
583583
// Agent entry point
584584
//---------------------------------------------------------------------
@@ -648,19 +648,19 @@ namespace __copy_if {
648648
cudaError_t status = cudaSuccess;
649649
if (num_items == 0)
650650
return status;
651-
651+
652652
size_t allocation_sizes[2] = {0, vshmem_size};
653653
status = ScanTileState::AllocationSize(static_cast<int>(num_tiles), allocation_sizes[0]);
654654
CUDA_CUB_RET_IF_FAIL(status);
655-
655+
656656

657657
void* allocations[2] = {NULL, NULL};
658658
status = cub::AliasTemporaries(d_temp_storage,
659659
temp_storage_bytes,
660660
allocations,
661661
allocation_sizes);
662662
CUDA_CUB_RET_IF_FAIL(status);
663-
663+
664664

665665
if (d_temp_storage == NULL)
666666
{

thrust/system/cuda/detail/core/util.h

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
#include <thrust/type_traits/is_contiguous_iterator.h>
3333
#include <thrust/detail/raw_pointer_cast.h>
3434
#include <thrust/system/cuda/detail/util.h>
35-
#include <thrust/system/cuda/detail/cub/block/block_load.cuh>
36-
#include <thrust/system/cuda/detail/cub/block/block_store.cuh>
37-
#include <thrust/system/cuda/detail/cub/block/block_scan.cuh>
35+
#include <cub/block/block_load.cuh>
36+
#include <cub/block/block_store.cuh>
37+
#include <cub/block/block_scan.cuh>
3838

3939
THRUST_BEGIN_NS
4040

@@ -491,6 +491,51 @@ namespace core {
491491
return 0;
492492
}
493493

494+
template <class Kernel>
495+
int CUB_RUNTIME_FUNCTION
496+
get_max_block_size(Kernel k)
497+
{
498+
int devId;
499+
cuda_cub::throw_on_error(cudaGetDevice(&devId),
500+
"get_max_block_size :"
501+
"failed to cudaGetDevice");
502+
503+
cudaOccDeviceProp occ_prop;
504+
cuda_cub::throw_on_error(get_occ_device_properties(occ_prop, devId),
505+
"get_max_block_size: "
506+
"failed to cudaGetDeviceProperties");
507+
508+
509+
cudaFuncAttributes attribs;
510+
cuda_cub::throw_on_error(cudaFuncGetAttributes(&attribs, reinterpret_cast<void *>(k)),
511+
"get_max_block_size: "
512+
"failed to cudaFuncGetAttributes");
513+
cudaOccFuncAttributes occ_attrib(attribs);
514+
515+
516+
cudaFuncCache cacheConfig;
517+
cuda_cub::throw_on_error(cudaDeviceGetCacheConfig(&cacheConfig),
518+
"get_max_block_size: "
519+
"failed to cudaDeviceGetCacheConfig");
520+
521+
cudaOccDeviceState occ_state;
522+
occ_state.cacheConfig = (cudaOccCacheConfig)cacheConfig;
523+
int block_size = 0;
524+
int min_grid_size = 0;
525+
cudaOccError occ_status = cudaOccMaxPotentialOccupancyBlockSize(&min_grid_size,
526+
&block_size,
527+
&occ_prop,
528+
&occ_attrib,
529+
&occ_state,
530+
0);
531+
if (CUDA_OCC_SUCCESS != occ_status || block_size <= 0)
532+
cuda_cub::throw_on_error(cudaErrorInvalidConfiguration,
533+
"get_max_block_size: "
534+
"failed to cudaOccMaxPotentialOccupancyBlockSize");
535+
536+
return block_size;
537+
}
538+
494539
// LoadIterator
495540
// ------------
496541
// if trivial iterator is passed, wrap loads into LDG
@@ -623,7 +668,7 @@ namespace core {
623668
}
624669

625670
#define CUDA_CUB_RET_IF_FAIL(e) \
626-
if (thrust::cuda_cub::cub::Debug((e), __FILE__, __LINE__)) return e;
671+
if (cub::Debug((e), __FILE__, __LINE__)) return e;
627672

628673
// uninitialized
629674
// -------

thrust/system/cuda/detail/malloc_and_free.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include <thrust/memory.h>
2424
#include <thrust/system/cuda/config.h>
2525
#ifdef THRUST_CACHING_DEVICE_MALLOC
26-
#include <thrust/system/cuda/detail/cub/util_allocator.cuh>
26+
#include <cub/util_allocator.cuh>
2727
#endif
2828
#include <thrust/system/cuda/detail/util.h>
2929
#include <thrust/system/detail/bad_alloc.h>

0 commit comments

Comments
 (0)