Skip to content

Commit d285d2a

Browse files
committed
Add a general distributed DBSCAN implementation
1 parent 2784f30 commit d285d2a

File tree

2 files changed

+57
-11
lines changed

2 files changed

+57
-11
lines changed

src/cluster/ArborX_DistributedDBSCAN.hpp

+34-6
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ void dbscan(MPI_Comm comm, ExecutionSpace const &space,
4646

4747
ARBORX_ASSERT(eps > 0);
4848
ARBORX_ASSERT(core_min_size >= 2);
49-
if (core_min_size > 2)
50-
Kokkos::abort("minPts > 2 is not supported yet");
5149

5250
using Point = typename Points::value_type;
5351
static_assert(GeometryTraits::is_point_v<Point>);
5452
static_assert(
5553
std::is_same_v<typename GeometryTraits::coordinate_type<Point>::type,
5654
Coordinate>);
5755

56+
bool const is_special_case = (core_min_size == 2);
57+
5858
Points points{primitives}; // NOLINT
5959
int const n_local = points.size();
6060

@@ -65,8 +65,13 @@ void dbscan(MPI_Comm comm, ExecutionSpace const &space,
6565
Kokkos::View<int *, MemorySpace> ghost_ids(prefix + "ghost_ids", 0);
6666
Kokkos::View<Point *, MemorySpace> ghost_points(prefix + "ghost_points", 0);
6767
Kokkos::View<int *, MemorySpace> ghost_ranks(prefix + "ghost_ranks", 0);
68-
Details::forwardNeighbors(comm, space, points, eps, ghost_points, ghost_ids,
69-
ghost_ranks);
68+
// For minPts=2 case, we only need to fetch the points within eps distance.
69+
// For minPts>2, we need points within 2*eps distance so that the
70+
// local DBSCAN algorithm can be used to determine core points.
71+
Details::forwardNeighbors(
72+
comm, space, points,
73+
(is_special_case ? eps : std::nextafter(2 * eps, 10 * eps)), ghost_points,
74+
ghost_ids, ghost_ranks);
7075
int const n_ghost = ghost_points.size();
7176

7277
// Step 2: do local DBSCAN
@@ -142,13 +147,36 @@ void dbscan(MPI_Comm comm, ExecutionSpace const &space,
142147
Kokkos::resize(space, ghost_labels, num_compressed);
143148
Details::communicateNeighborDataBack(comm, space, ghost_ranks, ghost_ids,
144149
ghost_labels);
150+
Kokkos::resize(ghost_ranks, 0); // free space
145151

146152
// Step 5: process multi-labeled indices
147153
Kokkos::View<Details::MergePair *, MemorySpace> local_merge_pairs(
148154
prefix + "local_merge_pairs", 0);
149-
Details::computeMergePairs(space, labels, ghost_ids, ghost_labels,
150-
local_merge_pairs);
155+
if (is_special_case)
156+
{
157+
Details::computeMergePairs(space, Details::CCSCorePoints{}, labels,
158+
ghost_ids, ghost_labels, local_merge_pairs);
159+
}
160+
else
161+
{
162+
BoundingVolumeHierarchy bvh(
163+
space, Details::UnifiedPoints<Points, decltype(ghost_points)>{
164+
points, ghost_points});
165+
166+
Kokkos::View<int *, MemorySpace> num_neigh(prefix + "num_neighbors",
167+
n_local);
168+
bvh.query(space,
169+
Experimental::attach_indices(
170+
Experimental::make_intersects(points, eps)),
171+
Details::CountUpToN<MemorySpace>{num_neigh, core_min_size});
172+
173+
Details::computeMergePairs(
174+
space, Details::DBSCANCorePoints<MemorySpace>{num_neigh, core_min_size},
175+
labels, ghost_ids, ghost_labels, local_merge_pairs);
176+
}
151177
sortAndFilterMergePairs(space, local_merge_pairs);
178+
Kokkos::resize(ghost_ids, 0); // free space
179+
Kokkos::resize(ghost_labels, 0); // free space
152180

153181
// Step 6: communicate merge pairs (all-to-all)
154182
Kokkos::View<Details::MergePair *, MemorySpace> global_merge_pairs(

src/cluster/detail/ArborX_DistributedDBSCANHelpers.hpp

+23-5
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,12 @@ struct MergePair
311311
}
312312
};
313313

314-
template <typename ExecutionSpace, typename Labels, typename ImportedIds,
315-
typename ImportedLabels, typename MergePairs>
316-
void computeMergePairs(ExecutionSpace const &space, Labels &local_labels,
317-
ImportedIds &imported_ids,
318-
ImportedLabels &imported_labels, MergePairs &merge_pairs)
314+
template <typename ExecutionSpace, typename CorePoints, typename Labels,
315+
typename ImportedIds, typename ImportedLabels, typename MergePairs>
316+
void computeMergePairs(ExecutionSpace const &space, CorePoints const &is_core,
317+
Labels &local_labels, ImportedIds const &imported_ids,
318+
ImportedLabels const &imported_labels,
319+
MergePairs &merge_pairs)
319320
{
320321
std::string prefix = "ArborX::DistributedDBSCAN::computeMergePairs";
321322
Kokkos::Profiling::ScopedRegion guard(prefix);
@@ -364,8 +365,25 @@ void computeMergePairs(ExecutionSpace const &space, Labels &local_labels,
364365

365366
int num_valid = (end - begin) + (is_local_valid);
366367
if (num_valid < 2)
368+
{
369+
// A noise point or a point with a single label
367370
return;
371+
}
372+
373+
if (!is_core(id))
374+
{
375+
// A border point with multiple labels
376+
if (is_final && !is_local_valid)
377+
{
378+
// Update local label if it is invalid (all imported labels are
379+
// valid as we filter out noise before communicating)
380+
local_labels(id) = imported_labels(begin);
381+
}
382+
383+
return;
384+
}
368385

386+
// A core point with multiple labels
369387
auto min_label = (is_local_valid ? local_label : LLONG_MAX);
370388
for (int j = begin; j < end; ++j)
371389
{

0 commit comments

Comments
 (0)