Skip to content

Commit d48fe3d

Browse files
authored
Merge pull request #1053 from aprokop/brute_force_nearest
Implement nearest query for BruteForce
2 parents a850ae8 + 7e95982 commit d48fe3d

7 files changed

+177
-70
lines changed

src/ArborX_BruteForce.hpp

+1-7
Original file line numberDiff line numberDiff line change
@@ -237,18 +237,12 @@ void BruteForce<MemorySpace, Value, IndexableGetter, BoundingVolume>::query(
237237
Predicates predicates{user_predicates}; // NOLINT
238238

239239
using Tag = typename Predicates::value_type::Tag;
240-
static_assert(std::is_same<Tag, Details::SpatialPredicateTag>{},
241-
"nearest query not implemented yet");
242-
243-
Kokkos::Profiling::pushRegion("ArborX::BruteForce::query::spatial");
244240

245241
Details::BruteForceImpl::query(
246-
space, predicates, _values,
242+
Tag{}, space, predicates, _values,
247243
Details::Indexables<decltype(_values), IndexableGetter>{
248244
_values, _indexable_getter},
249245
callback);
250-
251-
Kokkos::Profiling::popRegion();
252246
}
253247

254248
} // namespace ArborX

src/details/ArborX_DetailsBruteForceImpl.hpp

+86-8
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@
1414

1515
#include <ArborX_DetailsAlgorithms.hpp> // expand
1616
#include <ArborX_DetailsKokkosExtMinMaxOperations.hpp>
17+
#include <ArborX_DetailsKokkosExtStdAlgorithms.hpp>
18+
#include <ArborX_DetailsKokkosExtViewHelpers.hpp>
19+
#include <ArborX_DetailsNearestBufferProvider.hpp>
20+
#include <ArborX_DetailsPriorityQueue.hpp>
1721
#include <ArborX_Exception.hpp>
22+
#include <ArborX_Predicates.hpp>
1823

1924
#include <Kokkos_Core.hpp>
25+
#include <Kokkos_Profiling_ScopedRegion.hpp>
2026

21-
namespace ArborX
22-
{
23-
namespace Details
27+
namespace ArborX::Details
2428
{
2529
struct BruteForceImpl
2630
{
@@ -48,10 +52,12 @@ struct BruteForceImpl
4852

4953
template <class ExecutionSpace, class Predicates, class Values,
5054
class Indexables, class Callback>
51-
static void query(ExecutionSpace const &space, Predicates const &predicates,
52-
Values const &values, Indexables const &indexables,
53-
Callback const &callback)
55+
static void query(SpatialPredicateTag, ExecutionSpace const &space,
56+
Predicates const &predicates, Values const &values,
57+
Indexables const &indexables, Callback const &callback)
5458
{
59+
Kokkos::Profiling::ScopedRegion guard("ArborX::BruteForce::query::spatial");
60+
5561
using TeamPolicy = Kokkos::TeamPolicy<ExecutionSpace>;
5662
using PredicateType = typename Predicates::value_type;
5763
using IndexableType = std::decay_t<decltype(indexables(0))>;
@@ -136,8 +142,80 @@ struct BruteForceImpl
136142
});
137143
});
138144
}
145+
146+
template <class ExecutionSpace, class Predicates, class Values,
147+
class Indexables, class Callback>
148+
static void query(NearestPredicateTag, ExecutionSpace const &space,
149+
Predicates const &predicates, Values const &values,
150+
Indexables const &indexables, Callback const &callback)
151+
{
152+
Kokkos::Profiling::ScopedRegion guard("ArborX::BruteForce::query::nearest");
153+
154+
using MemorySpace = typename Values::memory_space;
155+
156+
int const n_indexables = values.size();
157+
int const n_predicates = predicates.size();
158+
159+
NearestBufferProvider<MemorySpace> buffer_provider(space, predicates);
160+
161+
Kokkos::parallel_for(
162+
"ArborX::BruteForce::query::nearest::"
163+
"check_all_predicates_against_all_indexables",
164+
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_predicates),
165+
KOKKOS_LAMBDA(int i) {
166+
auto const &predicate = predicates(i);
167+
auto const k = getK(predicate);
168+
auto const buffer = buffer_provider(i);
169+
170+
if (k < 1)
171+
return;
172+
173+
using PairIndexDistance =
174+
typename NearestBufferProvider<MemorySpace>::PairIndexDistance;
175+
struct CompareDistance
176+
{
177+
KOKKOS_INLINE_FUNCTION bool
178+
operator()(PairIndexDistance const &lhs,
179+
PairIndexDistance const &rhs) const
180+
{
181+
return lhs.second < rhs.second;
182+
}
183+
};
184+
185+
PriorityQueue<PairIndexDistance, CompareDistance,
186+
UnmanagedStaticVector<PairIndexDistance>>
187+
heap(UnmanagedStaticVector<PairIndexDistance>(buffer.data(),
188+
buffer.size()));
189+
190+
// Nodes with a distance that exceed that radius can safely be
191+
// discarded. Initialize the radius to infinity and tighten it once k
192+
// neighbors have been found.
193+
auto radius = KokkosExt::ArithmeticTraits::infinity<float>::value;
194+
195+
int j = 0;
196+
for (; j < n_indexables && j < k; ++j)
197+
{
198+
auto const distance = predicate.distance(indexables(j));
199+
heap.push(Kokkos::make_pair(j, distance));
200+
}
201+
for (; j < n_indexables; ++j)
202+
{
203+
auto const distance = predicate.distance(indexables(j));
204+
if (distance < radius)
205+
{
206+
heap.popPush(Kokkos::make_pair(j, distance));
207+
radius = heap.top().second;
208+
}
209+
}
210+
211+
// Match the logic in TreeTraversal and do the sorting
212+
sortHeap(heap.data(), heap.data() + heap.size(), heap.valueComp());
213+
for (decltype(heap.size()) i = 0; i < heap.size(); ++i)
214+
callback(predicate, values((heap.data() + i)->first));
215+
});
216+
}
139217
};
140-
} // namespace Details
141-
} // namespace ArborX
218+
219+
} // namespace ArborX::Details
142220

143221
#endif
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/****************************************************************************
2+
* Copyright (c) 2017-2023 by the ArborX authors *
3+
* All rights reserved. *
4+
* *
5+
* This file is part of the ArborX library. ArborX is *
6+
* distributed under a BSD 3-clause license. For the licensing terms see *
7+
* the LICENSE file in the top-level directory. *
8+
* *
9+
* SPDX-License-Identifier: BSD-3-Clause *
10+
****************************************************************************/
11+
#ifndef ARBORX_DETAILS_NEAREST_BUFFER_PROVIDER_HPP
12+
#define ARBORX_DETAILS_NEAREST_BUFFER_PROVIDER_HPP
13+
14+
#include <ArborX_DetailsKokkosExtStdAlgorithms.hpp>
15+
#include <ArborX_DetailsKokkosExtViewHelpers.hpp>
16+
17+
#include <Kokkos_Core.hpp>
18+
19+
namespace ArborX::Details
20+
{
21+
22+
template <typename MemorySpace>
23+
struct NearestBufferProvider
24+
{
25+
static_assert(Kokkos::is_memory_space_v<MemorySpace>);
26+
27+
using PairIndexDistance = Kokkos::pair<int, float>;
28+
29+
Kokkos::View<PairIndexDistance *, MemorySpace> _buffer;
30+
Kokkos::View<int *, MemorySpace> _offset;
31+
32+
NearestBufferProvider() = default;
33+
34+
template <typename ExecutionSpace, typename Predicates>
35+
NearestBufferProvider(ExecutionSpace const &space,
36+
Predicates const &predicates)
37+
: _buffer("ArborX::NearestBufferProvider::buffer", 0)
38+
, _offset("ArborX::NearestBufferProvider::offset", 0)
39+
{
40+
allocateBuffer(space, predicates);
41+
}
42+
43+
KOKKOS_FUNCTION auto operator()(int i) const
44+
{
45+
return Kokkos::subview(_buffer,
46+
Kokkos::make_pair(_offset(i), _offset(i + 1)));
47+
}
48+
49+
// Enclosing function for an extended __host__ __device__ lambda cannot have
50+
// private or protected access within its class
51+
#ifndef KOKKOS_COMPILER_NVCC
52+
private:
53+
#endif
54+
template <typename ExecutionSpace, typename Predicates>
55+
void allocateBuffer(ExecutionSpace const &space, Predicates const &predicates)
56+
{
57+
auto const n_queries = predicates.size();
58+
59+
KokkosExt::reallocWithoutInitializing(space, _offset, n_queries + 1);
60+
61+
Kokkos::parallel_for(
62+
"ArborX::NearestBufferProvider::scan_queries_for_numbers_of_neighbors",
63+
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_queries),
64+
KOKKOS_CLASS_LAMBDA(int i) { _offset(i) = getK(predicates(i)); });
65+
KokkosExt::exclusive_scan(space, _offset, _offset, 0);
66+
int const buffer_size = KokkosExt::lastElement(space, _offset);
67+
// Allocate buffer over which to perform heap operations in the nearest
68+
// query to store nearest nodes found so far.
69+
// It is not possible to anticipate how much memory to allocate since the
70+
// number of nearest neighbors k is only known at runtime.
71+
72+
KokkosExt::reallocWithoutInitializing(space, _buffer, buffer_size);
73+
}
74+
};
75+
76+
} // namespace ArborX::Details
77+
78+
#endif

src/details/ArborX_DetailsTreeTraversal.hpp

+10-54
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <ArborX_DetailsKokkosExtArithmeticTraits.hpp>
1717
#include <ArborX_DetailsKokkosExtStdAlgorithms.hpp>
1818
#include <ArborX_DetailsKokkosExtViewHelpers.hpp>
19+
#include <ArborX_DetailsNearestBufferProvider.hpp>
1920
#include <ArborX_DetailsNode.hpp> // ROPE_SENTINEL
2021
#include <ArborX_DetailsPriorityQueue.hpp>
2122
#include <ArborX_DetailsStack.hpp>
@@ -128,48 +129,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
128129
Predicates _predicates;
129130
Callback _callback;
130131

131-
using Buffer = Kokkos::View<Kokkos::pair<int, float> *, MemorySpace>;
132-
using Offset = Kokkos::View<int *, MemorySpace>;
133-
struct BufferProvider
134-
{
135-
Buffer _buffer;
136-
Offset _offset;
137-
138-
KOKKOS_FUNCTION auto operator()(int i) const
139-
{
140-
auto const *offset_ptr = &_offset(i);
141-
return Kokkos::subview(_buffer,
142-
Kokkos::make_pair(*offset_ptr, *(offset_ptr + 1)));
143-
}
144-
};
145-
146-
BufferProvider _buffer;
147-
148-
template <typename ExecutionSpace>
149-
void allocateBuffer(ExecutionSpace const &space)
150-
{
151-
auto const n_queries = _predicates.size();
152-
153-
Offset offset(Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
154-
"ArborX::TreeTraversal::nearest::offset"),
155-
n_queries + 1);
156-
Kokkos::parallel_for(
157-
"ArborX::TreeTraversal::nearest::"
158-
"scan_queries_for_numbers_of_neighbors",
159-
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_queries),
160-
KOKKOS_CLASS_LAMBDA(int i) { offset(i) = getK(_predicates(i)); });
161-
KokkosExt::exclusive_scan(space, offset, offset, 0);
162-
int const buffer_size = KokkosExt::lastElement(space, offset);
163-
// Allocate buffer over which to perform heap operations in
164-
// TreeTraversal::nearestQuery() to store nearest leaf nodes found so far.
165-
// It is not possible to anticipate how much memory to allocate since the
166-
// number of nearest neighbors k is only known at runtime.
167-
168-
Buffer buffer(Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
169-
"ArborX::TreeTraversal::nearest::buffer"),
170-
buffer_size);
171-
_buffer = BufferProvider{buffer, offset};
172-
}
132+
NearestBufferProvider<MemorySpace> _buffer;
173133

174134
template <typename ExecutionSpace>
175135
TreeTraversal(ExecutionSpace const &space, BVH const &bvh,
@@ -192,7 +152,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
192152
}
193153
else
194154
{
195-
allocateBuffer(space);
155+
_buffer = NearestBufferProvider<MemorySpace>(space, predicates);
196156

197157
Kokkos::parallel_for(
198158
"ArborX::TreeTraversal::nearest",
@@ -226,17 +186,8 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
226186
if (k < 1)
227187
return;
228188

229-
// Nodes with a distance that exceed that radius can safely be
230-
// discarded. Initialize the radius to infinity and tighten it once k
231-
// neighbors have been found.
232-
auto radius = KokkosExt::ArithmeticTraits::infinity<float>::value;
233-
234-
using PairIndexDistance = Kokkos::pair<int, float>;
235-
static_assert(
236-
std::is_same<typename decltype(buffer)::value_type,
237-
PairIndexDistance>::value,
238-
"Type of the elements stored in the buffer passed as argument to "
239-
"TreeTraversal::nearestQuery is not right");
189+
using PairIndexDistance =
190+
typename NearestBufferProvider<MemorySpace>::PairIndexDistance;
240191
struct CompareDistance
241192
{
242193
KOKKOS_INLINE_FUNCTION bool operator()(PairIndexDistance const &lhs,
@@ -281,6 +232,11 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
281232
float distance_right = 0.f;
282233
float distance_node = 0.f;
283234

235+
// Nodes with a distance that exceed that radius can safely be
236+
// discarded. Initialize the radius to infinity and tighten it once k
237+
// neighbors have been found.
238+
auto radius = KokkosExt::ArithmeticTraits::infinity<float>::value;
239+
284240
do
285241
{
286242
bool traverse_left = false;

test/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ foreach(_test Callbacks Degenerate ManufacturedSolution ComparisonWithBoost)
106106
" ArborX::Details::DefaultIndexableGetter, ArborX::Box>>;\n"
107107
"#define ARBORX_TEST_TREE_TYPES Tuple<ArborX_BruteForce_Box, ArborX_Legacy_BruteForce_Box>\n"
108108
"#define ARBORX_TEST_DEVICE_TYPES std::tuple<${ARBORX_DEVICE_TYPES}>\n"
109-
"#define ARBORX_TEST_DISABLE_NEAREST_QUERY\n"
110109
"#define ARBORX_TEST_DISABLE_CALLBACK_EARLY_EXIT\n"
111110
"#include <tstQueryTree${_test}.cpp>\n"
112111
)

test/tstKokkosToolsAnnotations.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(bvh_query_allocations_prefixed, DeviceType,
8989
void const * /*ptr*/, uint64_t /*size*/) {
9090
std::regex re("^(Testing::"
9191
"|ArborX::BVH::query::"
92+
"|ArborX::NearestBufferProvider::"
9293
"|ArborX::TreeTraversal::spatial::"
9394
"|ArborX::TreeTraversal::nearest::"
9495
"|ArborX::CrsGraphWrapper::"

test/tstKokkosToolsDistributedAnnotations.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(
7373
"|ArborX::DistributedTree::query::"
7474
"|ArborX::Distributor::"
7575
"|ArborX::BVH::query::"
76+
"|ArborX::NearestBufferProvider::"
7677
"|ArborX::TreeTraversal::spatial::"
7778
"|ArborX::TreeTraversal::nearest::"
7879
"|ArborX::CrsGraphWrapper::"

0 commit comments

Comments
 (0)