Skip to content

Commit 6d28ec0

Browse files
authored
Refactoring (#410)
1 parent 1f49ffe commit 6d28ec0

14 files changed

+2320
-2352
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ hnswlib.cpython*.so
88
var/
99
.idea/
1010
.vscode/
11-
11+
.vs/

examples/searchKnnCloserFirst_test.cpp

+4-6
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
#include <vector>
1111
#include <iostream>
1212

13-
namespace
14-
{
13+
namespace {
1514

1615
using idx_t = hnswlib::labeltype;
1716

@@ -20,7 +19,7 @@ void test() {
2019
idx_t n = 100;
2120
idx_t nq = 10;
2221
size_t k = 10;
23-
22+
2423
std::vector<float> data(n * d);
2524
std::vector<float> query(nq * d);
2625

@@ -34,7 +33,6 @@ void test() {
3433
for (idx_t i = 0; i < nq * d; ++i) {
3534
query[i] = distrib(rng);
3635
}
37-
3836

3937
hnswlib::L2Space space(d);
4038
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
@@ -68,12 +66,12 @@ void test() {
6866
gd.pop();
6967
}
7068
}
71-
69+
7270
delete alg_brute;
7371
delete alg_hnsw;
7472
}
7573

76-
} // namespace
74+
} // namespace
7775

7876
int main() {
7977
std::cout << "Testing ..." << std::endl;

examples/searchKnnWithFilter_test.cpp

+9-10
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
#include <vector>
88
#include <iostream>
99

10-
namespace
11-
{
10+
namespace {
1211

1312
using idx_t = hnswlib::labeltype;
1413

@@ -30,7 +29,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
3029
idx_t n = 100;
3130
idx_t nq = 10;
3231
size_t k = 10;
33-
32+
3433
std::vector<float> data(n * d);
3534
std::vector<float> query(nq * d);
3635

@@ -46,8 +45,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
4645
}
4746

4847
hnswlib::L2Space space(d);
49-
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float,filter_func_t>(&space, 2 * n);
50-
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float,filter_func_t>(&space, 2 * n);
48+
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float, filter_func_t>(&space, 2 * n);
49+
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float, filter_func_t>(&space, 2 * n);
5150

5251
for (size_t i = 0; i < n; ++i) {
5352
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
@@ -82,7 +81,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
8281
gd.pop();
8382
}
8483
}
85-
84+
8685
delete alg_brute;
8786
delete alg_hnsw;
8887
}
@@ -109,8 +108,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
109108
}
110109

111110
hnswlib::L2Space space(d);
112-
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float,filter_func_t>(&space, 2 * n);
113-
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float,filter_func_t>(&space, 2 * n);
111+
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float, filter_func_t>(&space, 2 * n);
112+
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float, filter_func_t>(&space, 2 * n);
114113

115114
for (size_t i = 0; i < n; ++i) {
116115
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
@@ -140,12 +139,12 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
140139
delete alg_hnsw;
141140
}
142141

143-
} // namespace
142+
} // namespace
144143

145144
class CustomFilterFunctor: public hnswlib::FilterFunctor {
146145
std::unordered_set<unsigned int> allowed_values;
147146

148-
public:
147+
public:
149148
explicit CustomFilterFunctor(const std::unordered_set<unsigned int>& values) : allowed_values(values) {}
150149

151150
bool operator()(unsigned int id) {

examples/updates_test.cpp

+49-76
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,20 @@
22
#include <thread>
33

44

5-
class StopW
6-
{
5+
class StopW {
76
std::chrono::steady_clock::time_point time_begin;
87

9-
public:
10-
StopW()
11-
{
8+
public:
9+
StopW() {
1210
time_begin = std::chrono::steady_clock::now();
1311
}
1412

15-
float getElapsedTimeMicro()
16-
{
13+
float getElapsedTimeMicro() {
1714
std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now();
1815
return (std::chrono::duration_cast<std::chrono::microseconds>(time_end - time_begin).count());
1916
}
2017

21-
void reset()
22-
{
18+
void reset() {
2319
time_begin = std::chrono::steady_clock::now();
2420
}
2521
};
@@ -88,16 +84,14 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
8884

8985

9086
template <typename datatype>
91-
std::vector<datatype> load_batch(std::string path, int size)
92-
{
87+
std::vector<datatype> load_batch(std::string path, int size) {
9388
std::cout << "Loading " << path << "...";
9489
// float or int32 (python)
9590
assert(sizeof(datatype) == 4);
9691

9792
std::ifstream file;
9893
file.open(path, std::ios::binary);
99-
if (!file.is_open())
100-
{
94+
if (!file.is_open()) {
10195
std::cout << "Cannot open " << path << "\n";
10296
exit(1);
10397
}
@@ -112,26 +106,17 @@ std::vector<datatype> load_batch(std::string path, int size)
112106
template <typename d_type>
113107
static float
114108
test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
115-
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K)
116-
{
109+
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K) {
117110
size_t correct = 0;
118111
size_t total = 0;
119-
//uncomment to test in parallel mode:
120-
121-
122-
for (int i = 0; i < qsize; i++)
123-
{
124112

113+
for (int i = 0; i < qsize; i++) {
125114
std::priority_queue<std::pair<d_type, hnswlib::labeltype>> result = appr_alg.searchKnn((char *)(queries.data() + vecdim * i), K);
126115
total += K;
127-
while (result.size())
128-
{
129-
if (answers[i].find(result.top().second) != answers[i].end())
130-
{
116+
while (result.size()) {
117+
if (answers[i].find(result.top().second) != answers[i].end()) {
131118
correct++;
132-
}
133-
else
134-
{
119+
} else {
135120
}
136121
result.pop();
137122
}
@@ -141,76 +126,70 @@ test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<
141126

142127

143128
static void
144-
test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<float> &appr_alg, size_t vecdim,
145-
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
146-
{
129+
test_vs_recall(
130+
std::vector<float> &queries,
131+
size_t qsize,
132+
hnswlib::HierarchicalNSW<float> &appr_alg,
133+
size_t vecdim,
134+
std::vector<std::unordered_set<hnswlib::labeltype>> &answers,
135+
size_t k) {
136+
147137
std::vector<size_t> efs = {1};
148-
for (int i = k; i < 30; i++)
149-
{
138+
for (int i = k; i < 30; i++) {
150139
efs.push_back(i);
151140
}
152-
for (int i = 30; i < 400; i+=10)
153-
{
141+
for (int i = 30; i < 400; i+=10) {
154142
efs.push_back(i);
155143
}
156-
for (int i = 1000; i < 100000; i += 5000)
157-
{
144+
for (int i = 1000; i < 100000; i += 5000) {
158145
efs.push_back(i);
159146
}
160147
std::cout << "ef\trecall\ttime\thops\tdistcomp\n";
161148

162149
bool test_passed = false;
163-
for (size_t ef : efs)
164-
{
150+
for (size_t ef : efs) {
165151
appr_alg.setEf(ef);
166152

167-
appr_alg.metric_hops=0;
168-
appr_alg.metric_distance_computations=0;
153+
appr_alg.metric_hops = 0;
154+
appr_alg.metric_distance_computations = 0;
169155
StopW stopw = StopW();
170156

171157
float recall = test_approx<float>(queries, qsize, appr_alg, vecdim, answers, k);
172158
float time_us_per_query = stopw.getElapsedTimeMicro() / qsize;
173159
float distance_comp_per_query = appr_alg.metric_distance_computations / (1.0f * qsize);
174160
float hops_per_query = appr_alg.metric_hops / (1.0f * qsize);
175161

176-
std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
177-
if (recall > 0.99)
178-
{
162+
std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t" << hops_per_query << "\t" << distance_comp_per_query << "\n";
163+
if (recall > 0.99) {
179164
test_passed = true;
180-
std::cout << "Recall is over 0.99! "<<recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
165+
std::cout << "Recall is over 0.99! " << recall << "\t" << time_us_per_query << "us \t" << hops_per_query << "\t" << distance_comp_per_query << "\n";
181166
break;
182167
}
183168
}
184-
if (!test_passed)
185-
{
169+
if (!test_passed) {
186170
std::cerr << "Test failed\n";
187171
exit(1);
188172
}
189173
}
190174

191175

192-
int main(int argc, char **argv)
193-
{
176+
int main(int argc, char **argv) {
194177
int M = 16;
195178
int efConstruction = 200;
196179
int num_threads = std::thread::hardware_concurrency();
197180

198181
bool update = false;
199182

200-
if (argc == 2)
201-
{
202-
if (std::string(argv[1]) == "update")
203-
{
183+
if (argc == 2) {
184+
if (std::string(argv[1]) == "update") {
204185
update = true;
205186
std::cout << "Updates are on\n";
206-
}
207-
else {
208-
std::cout<<"Usage ./test_updates [update]\n";
187+
} else {
188+
std::cout << "Usage ./test_updates [update]\n";
209189
exit(1);
210190
}
211-
}
212-
else if (argc>2){
213-
std::cout<<"Usage ./test_updates [update]\n";
191+
} else if (argc > 2) {
192+
std::cout << "Usage ./test_updates [update]\n";
214193
exit(1);
215194
}
216195

@@ -224,8 +203,7 @@ int main(int argc, char **argv)
224203
{
225204
std::ifstream configfile;
226205
configfile.open(path + "/config.txt");
227-
if (!configfile.is_open())
228-
{
206+
if (!configfile.is_open()) {
229207
std::cout << "Cannot open config.txt\n";
230208
return 1;
231209
}
@@ -245,10 +223,9 @@ int main(int argc, char **argv)
245223

246224
StopW stopw = StopW();
247225

248-
if (update)
249-
{
226+
if (update) {
250227
std::cout << "Update iteration 0\n";
251-
228+
252229
ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
253230
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
254231
});
@@ -259,14 +236,13 @@ int main(int argc, char **argv)
259236
});
260237
appr_alg.checkIntegrity();
261238

262-
for (int b = 1; b < dummy_data_multiplier; b++)
263-
{
239+
for (int b = 1; b < dummy_data_multiplier; b++) {
264240
std::cout << "Update iteration " << b << "\n";
265241
char cpath[1024];
266242
sprintf(cpath, "batch_dummy_%02d.bin", b);
267243
std::vector<float> dummy_batchb = load_batch<float>(path + cpath, N * d);
268-
269-
ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
244+
245+
ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
270246
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
271247
});
272248
appr_alg.checkIntegrity();
@@ -275,31 +251,28 @@ int main(int argc, char **argv)
275251

276252
std::cout << "Inserting final elements\n";
277253
std::vector<float> final_batch = load_batch<float>(path + "batch_final.bin", N * d);
278-
254+
279255
stopw.reset();
280256
ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
281257
appr_alg.addPoint((void *)(final_batch.data() + i * d), i);
282258
});
283-
std::cout<<"Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n";
259+
std::cout << "Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n";
284260
std::cout << "Running tests\n";
285261
std::vector<float> queries_batch = load_batch<float>(path + "queries.bin", N_queries * d);
286262

287263
std::vector<int> gt = load_batch<int>(path + "gt.bin", N_queries * K);
288264

289265
std::vector<std::unordered_set<hnswlib::labeltype>> answers(N_queries);
290-
for (int i = 0; i < N_queries; i++)
291-
{
292-
for (int j = 0; j < K; j++)
293-
{
266+
for (int i = 0; i < N_queries; i++) {
267+
for (int j = 0; j < K; j++) {
294268
answers[i].insert(gt[i * K + j]);
295269
}
296270
}
297271

298-
for (int i = 0; i < 3; i++)
299-
{
272+
for (int i = 0; i < 3; i++) {
300273
std::cout << "Test iteration " << i << "\n";
301274
test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K);
302275
}
303276

304277
return 0;
305-
};
278+
}

0 commit comments

Comments
 (0)