Skip to content

Commit

Permalink
Merge pull request #1195 from aprokop/fix_hardcoded_floats_tree_trave…
Browse files Browse the repository at this point in the history
…rsal
  • Loading branch information
aprokop authored Jan 6, 2025
2 parents 12d87bb + 6ccf582 commit aefa0f5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 19 deletions.
6 changes: 4 additions & 2 deletions src/spatial/detail/ArborX_BruteForceImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ struct BruteForceImpl
int const n_indexables = values.size();
int const n_predicates = predicates.size();

NearestBufferProvider<MemorySpace> buffer_provider(space, predicates);
using Coordinate = decltype(predicates(0).distance(indexables(0)));
NearestBufferProvider<MemorySpace, Coordinate> buffer_provider(space,
predicates);

Kokkos::parallel_for(
"ArborX::BruteForce::query::nearest::"
Expand All @@ -168,7 +170,7 @@ struct BruteForceImpl
return;

using PairIndexDistance =
typename NearestBufferProvider<MemorySpace>::PairIndexDistance;
typename decltype(buffer_provider)::PairIndexDistance;
struct CompareDistance
{
KOKKOS_INLINE_FUNCTION bool
Expand Down
14 changes: 6 additions & 8 deletions src/spatial/detail/ArborX_NearestBufferProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@
namespace ArborX::Details
{

template <typename MemorySpace>
template <typename MemorySpace, typename Coordinate>
struct NearestBufferProvider
{
static_assert(Kokkos::is_memory_space_v<MemorySpace>);

using PairIndexDistance = Kokkos::pair<int, float>;
using PairIndexDistance = Kokkos::pair<int, Coordinate>;

Kokkos::View<PairIndexDistance *, MemorySpace> _buffer;
Kokkos::View<int *, MemorySpace> _offset;

NearestBufferProvider() = default;
NearestBufferProvider()
: _buffer("ArborX::NearestBufferProvider::buffer", 0)
, _offset("ArborX::NearestBufferProvider::offset", 0)
{}

template <typename ExecutionSpace, typename Predicates>
NearestBufferProvider(ExecutionSpace const &space,
Expand All @@ -46,11 +49,6 @@ struct NearestBufferProvider
Kokkos::make_pair(_offset(i), _offset(i + 1)));
}

// Enclosing function for an extended __host__ __device__ lambda cannot have
// private or protected access within its class
#ifndef KOKKOS_COMPILER_NVCC
private:
#endif
template <typename ExecutionSpace, typename Predicates>
void allocateBuffer(ExecutionSpace const &space, Predicates const &predicates)
{
Expand Down
20 changes: 11 additions & 9 deletions src/spatial/detail/ArborX_TreeTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
Predicates _predicates;
Callback _callback;

NearestBufferProvider<MemorySpace> _buffer;
using Coordinate = decltype(std::declval<Predicates>()(0).distance(
HappyTreeFriends::getIndexable(_bvh, 0)));

NearestBufferProvider<MemorySpace, Coordinate> _buffer;

template <typename ExecutionSpace>
TreeTraversal(ExecutionSpace const &space, BVH const &bvh,
Expand All @@ -151,7 +154,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
}
else
{
_buffer = NearestBufferProvider<MemorySpace>(space, predicates);
_buffer.allocateBuffer(space, predicates);

Kokkos::parallel_for("ArborX::TreeTraversal::nearest",
Kokkos::RangePolicy(space, 0, predicates.size()),
Expand Down Expand Up @@ -184,8 +187,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
if (k < 1)
return;

using PairIndexDistance =
typename NearestBufferProvider<MemorySpace>::PairIndexDistance;
using PairIndexDistance = typename decltype(_buffer)::PairIndexDistance;
struct CompareDistance
{
KOKKOS_INLINE_FUNCTION bool operator()(PairIndexDistance const &lhs,
Expand Down Expand Up @@ -217,7 +219,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
auto *stack_ptr = stack;
*stack_ptr++ = SENTINEL;
#if !defined(__CUDA_ARCH__)
float stack_distance[64];
Coordinate stack_distance[64];
auto *stack_distance_ptr = stack_distance;
*stack_distance_ptr++ = 0.f;
#endif
Expand All @@ -226,14 +228,14 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
int left_child;
int right_child;

float distance_left = 0.f;
float distance_right = 0.f;
float distance_node = 0.f;
Coordinate distance_left = 0;
Coordinate distance_right = 0;
Coordinate distance_node = 0;

// Nodes with a distance that exceed that radius can safely be
// discarded. Initialize the radius to infinity and tighten it once k
// neighbors have been found.
auto radius = KokkosExt::ArithmeticTraits::infinity<float>::value;
auto radius = KokkosExt::ArithmeticTraits::infinity<Coordinate>::value;

do
{
Expand Down

0 comments on commit aefa0f5

Please sign in to comment.