2
2
#include < thread>
3
3
4
4
5
/**
 * Minimal stopwatch built on std::chrono::steady_clock.
 *
 * Timing starts when the object is constructed and can be restarted with
 * reset(). Elapsed time is reported in microseconds.
 */
class StopW {
    // Instant at which the current measurement started.
    std::chrono::steady_clock::time_point time_begin;

 public:
    /// Starts the measurement at construction time.
    StopW() : time_begin(std::chrono::steady_clock::now()) {}

    /**
     * @return Whole microseconds elapsed since construction or the most
     *         recent reset(), converted to float.
     *
     * The return type stays float because callers divide and scale the
     * result as a float; the conversion is spelled out explicitly.
     */
    float getElapsedTimeMicro() const {
        const std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now();
        const auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(time_end - time_begin);
        return static_cast<float>(elapsed.count());
    }

    /// Restarts the measurement from the current instant.
    void reset() {
        time_begin = std::chrono::steady_clock::now();
    }
};
@@ -88,16 +84,14 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
88
84
89
85
90
86
template <typename datatype>
91
- std::vector<datatype> load_batch (std::string path, int size)
92
- {
87
+ std::vector<datatype> load_batch (std::string path, int size) {
93
88
std::cout << " Loading " << path << " ..." ;
94
89
// float or int32 (python)
95
90
assert (sizeof (datatype) == 4 );
96
91
97
92
std::ifstream file;
98
93
file.open (path, std::ios::binary);
99
- if (!file.is_open ())
100
- {
94
+ if (!file.is_open ()) {
101
95
std::cout << " Cannot open " << path << " \n " ;
102
96
exit (1 );
103
97
}
@@ -112,26 +106,17 @@ std::vector<datatype> load_batch(std::string path, int size)
112
106
template <typename d_type>
113
107
static float
114
108
test_approx (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
115
- std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K)
116
- {
109
+ std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K) {
117
110
size_t correct = 0 ;
118
111
size_t total = 0 ;
119
- // uncomment to test in parallel mode:
120
-
121
-
122
- for (int i = 0 ; i < qsize; i++)
123
- {
124
112
113
+ for (int i = 0 ; i < qsize; i++) {
125
114
std::priority_queue<std::pair<d_type, hnswlib::labeltype>> result = appr_alg.searchKnn ((char *)(queries.data () + vecdim * i), K);
126
115
total += K;
127
- while (result.size ())
128
- {
129
- if (answers[i].find (result.top ().second ) != answers[i].end ())
130
- {
116
+ while (result.size ()) {
117
+ if (answers[i].find (result.top ().second ) != answers[i].end ()) {
131
118
correct++;
132
- }
133
- else
134
- {
119
+ } else {
135
120
}
136
121
result.pop ();
137
122
}
@@ -141,76 +126,70 @@ test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<
141
126
142
127
143
128
static void
144
- test_vs_recall (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<float > &appr_alg, size_t vecdim,
145
- std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
146
- {
129
+ test_vs_recall (
130
+ std::vector<float > &queries,
131
+ size_t qsize,
132
+ hnswlib::HierarchicalNSW<float > &appr_alg,
133
+ size_t vecdim,
134
+ std::vector<std::unordered_set<hnswlib::labeltype>> &answers,
135
+ size_t k) {
136
+
147
137
std::vector<size_t > efs = {1 };
148
- for (int i = k; i < 30 ; i++)
149
- {
138
+ for (int i = k; i < 30 ; i++) {
150
139
efs.push_back (i);
151
140
}
152
- for (int i = 30 ; i < 400 ; i+=10 )
153
- {
141
+ for (int i = 30 ; i < 400 ; i+=10 ) {
154
142
efs.push_back (i);
155
143
}
156
- for (int i = 1000 ; i < 100000 ; i += 5000 )
157
- {
144
+ for (int i = 1000 ; i < 100000 ; i += 5000 ) {
158
145
efs.push_back (i);
159
146
}
160
147
std::cout << " ef\t recall\t time\t hops\t distcomp\n " ;
161
148
162
149
bool test_passed = false ;
163
- for (size_t ef : efs)
164
- {
150
+ for (size_t ef : efs) {
165
151
appr_alg.setEf (ef);
166
152
167
- appr_alg.metric_hops = 0 ;
168
- appr_alg.metric_distance_computations = 0 ;
153
+ appr_alg.metric_hops = 0 ;
154
+ appr_alg.metric_distance_computations = 0 ;
169
155
StopW stopw = StopW ();
170
156
171
157
float recall = test_approx<float >(queries, qsize, appr_alg, vecdim, answers, k);
172
158
float time_us_per_query = stopw.getElapsedTimeMicro () / qsize;
173
159
float distance_comp_per_query = appr_alg.metric_distance_computations / (1 .0f * qsize);
174
160
float hops_per_query = appr_alg.metric_hops / (1 .0f * qsize);
175
161
176
- std::cout << ef << " \t " << recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
177
- if (recall > 0.99 )
178
- {
162
+ std::cout << ef << " \t " << recall << " \t " << time_us_per_query << " us \t " << hops_per_query << " \t " << distance_comp_per_query << " \n " ;
163
+ if (recall > 0.99 ) {
179
164
test_passed = true ;
180
- std::cout << " Recall is over 0.99! " << recall << " \t " << time_us_per_query << " us \t " << hops_per_query<< " \t " << distance_comp_per_query << " \n " ;
165
+ std::cout << " Recall is over 0.99! " << recall << " \t " << time_us_per_query << " us \t " << hops_per_query << " \t " << distance_comp_per_query << " \n " ;
181
166
break ;
182
167
}
183
168
}
184
- if (!test_passed)
185
- {
169
+ if (!test_passed) {
186
170
std::cerr << " Test failed\n " ;
187
171
exit (1 );
188
172
}
189
173
}
190
174
191
175
192
- int main (int argc, char **argv)
193
- {
176
+ int main (int argc, char **argv) {
194
177
int M = 16 ;
195
178
int efConstruction = 200 ;
196
179
int num_threads = std::thread::hardware_concurrency ();
197
180
198
181
bool update = false ;
199
182
200
- if (argc == 2 )
201
- {
202
- if (std::string (argv[1 ]) == " update" )
203
- {
183
+ if (argc == 2 ) {
184
+ if (std::string (argv[1 ]) == " update" ) {
204
185
update = true ;
205
186
std::cout << " Updates are on\n " ;
206
- }
207
- else {
208
- std::cout<<" Usage ./test_updates [update]\n " ;
187
+ } else {
188
+ std::cout << " Usage ./test_updates [update]\n " ;
209
189
exit (1 );
210
190
}
211
- }
212
- else if (argc>2 ){
213
- std::cout<<" Usage ./test_updates [update]\n " ;
191
+ } else if (argc > 2 ) {
192
+ std::cout << " Usage ./test_updates [update]\n " ;
214
193
exit (1 );
215
194
}
216
195
@@ -224,8 +203,7 @@ int main(int argc, char **argv)
224
203
{
225
204
std::ifstream configfile;
226
205
configfile.open (path + " /config.txt" );
227
- if (!configfile.is_open ())
228
- {
206
+ if (!configfile.is_open ()) {
229
207
std::cout << " Cannot open config.txt\n " ;
230
208
return 1 ;
231
209
}
@@ -245,10 +223,9 @@ int main(int argc, char **argv)
245
223
246
224
StopW stopw = StopW ();
247
225
248
- if (update)
249
- {
226
+ if (update) {
250
227
std::cout << " Update iteration 0\n " ;
251
-
228
+
252
229
ParallelFor (1 , N, num_threads, [&](size_t i, size_t threadId) {
253
230
appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
254
231
});
@@ -259,14 +236,13 @@ int main(int argc, char **argv)
259
236
});
260
237
appr_alg.checkIntegrity ();
261
238
262
- for (int b = 1 ; b < dummy_data_multiplier; b++)
263
- {
239
+ for (int b = 1 ; b < dummy_data_multiplier; b++) {
264
240
std::cout << " Update iteration " << b << " \n " ;
265
241
char cpath[1024 ];
266
242
sprintf (cpath, " batch_dummy_%02d.bin" , b);
267
243
std::vector<float > dummy_batchb = load_batch<float >(path + cpath, N * d);
268
-
269
- ParallelFor (0 , N, num_threads, [&](size_t i, size_t threadId) {
244
+
245
+ ParallelFor (0 , N, num_threads, [&](size_t i, size_t threadId) {
270
246
appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
271
247
});
272
248
appr_alg.checkIntegrity ();
@@ -275,31 +251,28 @@ int main(int argc, char **argv)
275
251
276
252
std::cout << " Inserting final elements\n " ;
277
253
std::vector<float > final_batch = load_batch<float >(path + " batch_final.bin" , N * d);
278
-
254
+
279
255
stopw.reset ();
280
256
ParallelFor (0 , N, num_threads, [&](size_t i, size_t threadId) {
281
257
appr_alg.addPoint ((void *)(final_batch.data () + i * d), i);
282
258
});
283
- std::cout<< " Finished. Time taken:" << stopw.getElapsedTimeMicro ()*1e-6 << " s\n " ;
259
+ std::cout << " Finished. Time taken:" << stopw.getElapsedTimeMicro ()*1e-6 << " s\n " ;
284
260
std::cout << " Running tests\n " ;
285
261
std::vector<float > queries_batch = load_batch<float >(path + " queries.bin" , N_queries * d);
286
262
287
263
std::vector<int > gt = load_batch<int >(path + " gt.bin" , N_queries * K);
288
264
289
265
std::vector<std::unordered_set<hnswlib::labeltype>> answers (N_queries);
290
- for (int i = 0 ; i < N_queries; i++)
291
- {
292
- for (int j = 0 ; j < K; j++)
293
- {
266
+ for (int i = 0 ; i < N_queries; i++) {
267
+ for (int j = 0 ; j < K; j++) {
294
268
answers[i].insert (gt[i * K + j]);
295
269
}
296
270
}
297
271
298
- for (int i = 0 ; i < 3 ; i++)
299
- {
272
+ for (int i = 0 ; i < 3 ; i++) {
300
273
std::cout << " Test iteration " << i << " \n " ;
301
274
test_vs_recall (queries_batch, N_queries, appr_alg, d, answers, K);
302
275
}
303
276
304
277
return 0 ;
305
- };
278
+ }
0 commit comments