|
14 | 14 |
|
15 | 15 | #include <ArborX_DetailsAlgorithms.hpp> // expand
|
16 | 16 | #include <ArborX_DetailsKokkosExtMinMaxOperations.hpp>
|
| 17 | +#include <ArborX_DetailsKokkosExtStdAlgorithms.hpp> |
| 18 | +#include <ArborX_DetailsKokkosExtViewHelpers.hpp> |
| 19 | +#include <ArborX_DetailsNearestBufferProvider.hpp> |
| 20 | +#include <ArborX_DetailsPriorityQueue.hpp> |
17 | 21 | #include <ArborX_Exception.hpp>
|
| 22 | +#include <ArborX_Predicates.hpp> |
18 | 23 |
|
19 | 24 | #include <Kokkos_Core.hpp>
|
| 25 | +#include <Kokkos_Profiling_ScopedRegion.hpp> |
20 | 26 |
|
21 |
| -namespace ArborX |
22 |
| -{ |
23 |
| -namespace Details |
| 27 | +namespace ArborX::Details |
24 | 28 | {
|
25 | 29 | struct BruteForceImpl
|
26 | 30 | {
|
@@ -48,10 +52,12 @@ struct BruteForceImpl
|
48 | 52 |
|
49 | 53 | template <class ExecutionSpace, class Predicates, class Values,
|
50 | 54 | class Indexables, class Callback>
|
51 |
| - static void query(ExecutionSpace const &space, Predicates const &predicates, |
52 |
| - Values const &values, Indexables const &indexables, |
53 |
| - Callback const &callback) |
| 55 | + static void query(SpatialPredicateTag, ExecutionSpace const &space, |
| 56 | + Predicates const &predicates, Values const &values, |
| 57 | + Indexables const &indexables, Callback const &callback) |
54 | 58 | {
|
| 59 | + Kokkos::Profiling::ScopedRegion guard("ArborX::BruteForce::query::spatial"); |
| 60 | + |
55 | 61 | using TeamPolicy = Kokkos::TeamPolicy<ExecutionSpace>;
|
56 | 62 | using PredicateType = typename Predicates::value_type;
|
57 | 63 | using IndexableType = std::decay_t<decltype(indexables(0))>;
|
@@ -136,8 +142,80 @@ struct BruteForceImpl
|
136 | 142 | });
|
137 | 143 | });
|
138 | 144 | }
|
| 145 | + |
| 146 | + template <class ExecutionSpace, class Predicates, class Values, |
| 147 | + class Indexables, class Callback> |
| 148 | + static void query(NearestPredicateTag, ExecutionSpace const &space, |
| 149 | + Predicates const &predicates, Values const &values, |
| 150 | + Indexables const &indexables, Callback const &callback) |
| 151 | + { |
| 152 | + Kokkos::Profiling::ScopedRegion guard("ArborX::BruteForce::query::nearest"); |
| 153 | + |
| 154 | + using MemorySpace = typename Values::memory_space; |
| 155 | + |
| 156 | + int const n_indexables = values.size(); |
| 157 | + int const n_predicates = predicates.size(); |
| 158 | + |
| 159 | + NearestBufferProvider<MemorySpace> buffer_provider(space, predicates); |
| 160 | + |
| 161 | + Kokkos::parallel_for( |
| 162 | + "ArborX::BruteForce::query::nearest::" |
| 163 | + "check_all_predicates_against_all_indexables", |
| 164 | + Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_predicates), |
| 165 | + KOKKOS_LAMBDA(int i) { |
| 166 | + auto const &predicate = predicates(i); |
| 167 | + auto const k = getK(predicate); |
| 168 | + auto const buffer = buffer_provider(i); |
| 169 | + |
| 170 | + if (k < 1) |
| 171 | + return; |
| 172 | + |
| 173 | + using PairIndexDistance = |
| 174 | + typename NearestBufferProvider<MemorySpace>::PairIndexDistance; |
| 175 | + struct CompareDistance |
| 176 | + { |
| 177 | + KOKKOS_INLINE_FUNCTION bool |
| 178 | + operator()(PairIndexDistance const &lhs, |
| 179 | + PairIndexDistance const &rhs) const |
| 180 | + { |
| 181 | + return lhs.second < rhs.second; |
| 182 | + } |
| 183 | + }; |
| 184 | + |
| 185 | + PriorityQueue<PairIndexDistance, CompareDistance, |
| 186 | + UnmanagedStaticVector<PairIndexDistance>> |
| 187 | + heap(UnmanagedStaticVector<PairIndexDistance>(buffer.data(), |
| 188 | + buffer.size())); |
| 189 | + |
| 190 | + // Nodes with a distance that exceed that radius can safely be |
| 191 | + // discarded. Initialize the radius to infinity and tighten it once k |
| 192 | + // neighbors have been found. |
| 193 | + auto radius = KokkosExt::ArithmeticTraits::infinity<float>::value; |
| 194 | + |
| 195 | + int j = 0; |
| 196 | + for (; j < n_indexables && j < k; ++j) |
| 197 | + { |
| 198 | + auto const distance = predicate.distance(indexables(j)); |
| 199 | + heap.push(Kokkos::make_pair(j, distance)); |
| 200 | + } |
| 201 | + for (; j < n_indexables; ++j) |
| 202 | + { |
| 203 | + auto const distance = predicate.distance(indexables(j)); |
| 204 | + if (distance < radius) |
| 205 | + { |
| 206 | + heap.popPush(Kokkos::make_pair(j, distance)); |
| 207 | + radius = heap.top().second; |
| 208 | + } |
| 209 | + } |
| 210 | + |
| 211 | + // Match the logic in TreeTraversal and do the sorting |
| 212 | + sortHeap(heap.data(), heap.data() + heap.size(), heap.valueComp()); |
| 213 | + for (decltype(heap.size()) i = 0; i < heap.size(); ++i) |
| 214 | + callback(predicate, values((heap.data() + i)->first)); |
| 215 | + }); |
| 216 | + } |
139 | 217 | };
|
140 |
| -} // namespace Details |
141 |
| -} // namespace ArborX |
| 218 | + |
| 219 | +} // namespace ArborX::Details |
142 | 220 |
|
143 | 221 | #endif
|
0 commit comments