1
1
#include " ../hnswlib/hnswlib.h"
2
2
#include < thread>
3
+
4
+
3
5
class StopW
4
6
{
5
7
std::chrono::steady_clock::time_point time_begin;
@@ -22,6 +24,7 @@ class StopW
22
24
}
23
25
};
24
26
27
+
25
28
/*
26
29
* replacement for the openmp '#pragma omp parallel for' directive
27
30
* only handles a subset of functionality (no reductions etc)
@@ -81,8 +84,6 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
81
84
std::rethrow_exception (lastException);
82
85
}
83
86
}
84
-
85
-
86
87
}
87
88
88
89
@@ -94,7 +95,7 @@ std::vector<datatype> load_batch(std::string path, int size)
94
95
assert (sizeof (datatype) == 4 );
95
96
96
97
std::ifstream file;
97
- file.open (path);
98
+ file.open (path, std::ios::binary );
98
99
if (!file.is_open ())
99
100
{
100
101
std::cout << " Cannot open " << path << " \n " ;
@@ -107,15 +108,14 @@ std::vector<datatype> load_batch(std::string path, int size)
107
108
return batch;
108
109
}
109
110
111
+
110
112
template <typename d_type>
111
113
static float
112
114
test_approx (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
113
115
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K)
114
116
{
115
117
size_t correct = 0 ;
116
118
size_t total = 0 ;
117
- // uncomment to test in parallel mode:
118
-
119
119
120
120
for (int i = 0 ; i < qsize; i++)
121
121
{
@@ -137,10 +137,16 @@ test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<
137
137
return 1 .0f * correct / total;
138
138
}
139
139
140
+
140
141
static void
141
- test_vs_recall (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<float > &appr_alg, size_t vecdim,
142
- std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
143
- {
142
+ test_vs_recall (
143
+ std::vector<float > &queries,
144
+ size_t qsize,
145
+ hnswlib::HierarchicalNSW<float > &appr_alg,
146
+ size_t vecdim,
147
+ std::vector<std::unordered_set<hnswlib::labeltype>> &answers,
148
+ size_t k) {
149
+
144
150
std::vector<size_t > efs = {1 };
145
151
for (int i = k; i < 30 ; i++)
146
152
{
@@ -155,6 +161,8 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
155
161
efs.push_back (i);
156
162
}
157
163
std::cout << " ef\t recall\t time\t hops\t distcomp\n " ;
164
+
165
+ bool test_passed = false ;
158
166
for (size_t ef : efs)
159
167
{
160
168
appr_alg.setEf (ef);
@@ -171,20 +179,24 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
171
179
std::cout << ef << " \t " << recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
172
180
if (recall > 0.99 )
173
181
{
182
+ test_passed = true ;
174
183
std::cout << " Recall is over 0.99! " <<recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
175
184
break ;
176
185
}
177
186
}
187
+ if (!test_passed)
188
+ {
189
+ std::cerr << " Test failed\n " ;
190
+ exit (1 );
191
+ }
178
192
}
179
193
194
+
180
195
int main (int argc, char **argv)
181
196
{
182
-
183
197
int M = 16 ;
184
198
int efConstruction = 200 ;
185
199
int num_threads = std::thread::hardware_concurrency ();
186
-
187
-
188
200
189
201
bool update = false ;
190
202
@@ -207,7 +219,6 @@ int main(int argc, char **argv)
207
219
208
220
std::string path = " ../examples/data/" ;
209
221
210
-
211
222
int N;
212
223
int dummy_data_multiplier;
213
224
int N_queries;
@@ -216,8 +227,7 @@ int main(int argc, char **argv)
216
227
{
217
228
std::ifstream configfile;
218
229
configfile.open (path + " /config.txt" );
219
- if (!configfile.is_open ())
220
- {
230
+ if (!configfile.is_open ()) {
221
231
std::cout << " Cannot open config.txt\n " ;
222
232
return 1 ;
223
233
}
@@ -237,11 +247,9 @@ int main(int argc, char **argv)
237
247
238
248
StopW stopw = StopW ();
239
249
240
- if (update)
241
- {
250
+ if (update) {
242
251
std::cout << " Update iteration 0\n " ;
243
252
244
-
245
253
ParallelFor (1 , N, num_threads, [&](size_t i, size_t threadId) {
246
254
appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
247
255
});
@@ -252,13 +260,12 @@ int main(int argc, char **argv)
252
260
});
253
261
appr_alg.checkIntegrity ();
254
262
255
- for (int b = 1 ; b < dummy_data_multiplier; b++)
256
- {
263
+ for (int b = 1 ; b < dummy_data_multiplier; b++) {
257
264
std::cout << " Update iteration " << b << " \n " ;
258
265
char cpath[1024 ];
259
266
sprintf (cpath, " batch_dummy_%02d.bin" , b);
260
267
std::vector<float > dummy_batchb = load_batch<float >(path + cpath, N * d);
261
-
268
+
262
269
ParallelFor (0 , N, num_threads, [&](size_t i, size_t threadId) {
263
270
appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
264
271
});
@@ -268,7 +275,7 @@ int main(int argc, char **argv)
268
275
269
276
std::cout << " Inserting final elements\n " ;
270
277
std::vector<float > final_batch = load_batch<float >(path + " batch_final.bin" , N * d);
271
-
278
+
272
279
stopw.reset ();
273
280
ParallelFor (0 , N, num_threads, [&](size_t i, size_t threadId) {
274
281
appr_alg.addPoint ((void *)(final_batch.data () + i * d), i);
@@ -280,19 +287,16 @@ int main(int argc, char **argv)
280
287
std::vector<int > gt = load_batch<int >(path + " gt.bin" , N_queries * K);
281
288
282
289
std::vector<std::unordered_set<hnswlib::labeltype>> answers (N_queries);
283
- for (int i = 0 ; i < N_queries; i++)
284
- {
285
- for (int j = 0 ; j < K; j++)
286
- {
290
+ for (int i = 0 ; i < N_queries; i++) {
291
+ for (int j = 0 ; j < K; j++) {
287
292
answers[i].insert (gt[i * K + j]);
288
293
}
289
294
}
290
295
291
- for (int i = 0 ; i < 3 ; i++)
292
- {
296
+ for (int i = 0 ; i < 3 ; i++) {
293
297
std::cout << " Test iteration " << i << " \n " ;
294
298
test_vs_recall (queries_batch, N_queries, appr_alg, d, answers, K);
295
299
}
296
300
297
301
return 0 ;
298
- };
302
+ };
0 commit comments