Skip to content

Commit dab9e99

Browse files
committed
Refactoring, test updates
1 parent e97b37c commit dab9e99

File tree

11 files changed

+825
-846
lines changed

11 files changed

+825
-846
lines changed

.github/workflows/build.yml

+16-3
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@ jobs:
1919
run: python -m pip install .
2020

2121
- name: Test
22-
run: python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py"
22+
run: python -m unittest discover -v --start-directory python_bindings/tests --pattern "*_test*.py"
2323

2424
test_cpp:
25-
runs-on: ubuntu-latest
25+
runs-on: ${{matrix.os}}
26+
strategy:
27+
matrix:
28+
os: [ubuntu-latest, windows-latest]
2629
steps:
2730
- uses: actions/checkout@v3
2831
- uses: actions/setup-python@v4
@@ -34,17 +37,27 @@ jobs:
3437
mkdir build
3538
cd build
3639
cmake ..
37-
make
40+
if [ "$RUNNER_OS" == "Linux" ]; then
41+
make
42+
elif [ "$RUNNER_OS" == "Windows" ]; then
43+
cmake --build ./ --config Release
44+
fi
45+
shell: bash
3846

3947
- name: Prepare test data
4048
run: |
4149
pip install numpy
4250
cd examples
4351
python update_gen_data.py
52+
shell: bash
4453

4554
- name: Test
4655
run: |
4756
cd build
57+
if [ "$RUNNER_OS" == "Windows" ]; then
58+
cp ./Release/* ./
59+
fi
4860
./searchKnnCloserFirst_test
4961
./test_updates
5062
./test_updates update
63+
shell: bash

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ hnswlib.cpython*.so
88
var/
99
.idea/
1010
.vscode/
11-
11+
.vs/

examples/git_tester.py

+33-27
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,40 @@
1+
import os
2+
import shutil
3+
4+
from sys import platform
15
from pydriller import Repository
2-
import os
3-
import datetime
4-
os.system("cp examples/speedtest.py examples/speedtest2.py") # the file has to be outside of git
5-
for idx, commit in enumerate(Repository('.', from_tag="v0.6.0").traverse_commits()):
6-
name=commit.msg.replace('\n', ' ').replace('\r', ' ')
7-
print(idx, commit.hash, name)
86

97

8+
speedtest_src_path = os.path.join("examples", "speedtest.py")
9+
speedtest_copy_path = os.path.join("examples", "speedtest2.py")
10+
shutil.copyfile(speedtest_src_path, speedtest_copy_path) # the file has to be outside of git
11+
12+
commits = list(Repository('.', from_tag="v0.6.0").traverse_commits())
13+
print("Found commits:")
14+
for idx, commit in enumerate(commits):
15+
name = commit.msg.replace('\n', ' ').replace('\r', ' ')
16+
print(idx, commit.hash, name)
1017

11-
for commit in Repository('.', from_tag="v0.6.0").traverse_commits():
12-
13-
name=commit.msg.replace('\n', ' ').replace('\r', ' ')
14-
print(commit.hash, name)
15-
16-
os.system(f"git checkout {commit.hash}; rm -rf build; ")
18+
for commit in commits:
19+
name = commit.msg.replace('\n', ' ').replace('\r', ' ')
20+
print("\nProcessing", commit.hash, name)
21+
22+
if os.path.exists("build"):
23+
shutil.rmtree("build")
24+
os.system(f"git checkout {commit.hash}")
1725
print("\n\n--------------------\n\n")
18-
ret=os.system("python -m pip install .")
19-
print(ret)
20-
21-
if ret != 0:
22-
print ("build failed!!!!")
23-
print ("build failed!!!!")
24-
print ("build failed!!!!")
25-
print ("build failed!!!!")
26-
continue
27-
28-
os.system(f'python examples/speedtest2.py -n "{name}" -d 4 -t 1')
29-
os.system(f'python examples/speedtest2.py -n "{name}" -d 64 -t 1')
30-
os.system(f'python examples/speedtest2.py -n "{name}" -d 128 -t 1')
31-
os.system(f'python examples/speedtest2.py -n "{name}" -d 4 -t 24')
32-
os.system(f'python examples/speedtest2.py -n "{name}" -d 128 -t 24')
26+
ret = os.system("python -m pip install .")
27+
print("Install result:", ret)
3328

29+
if ret != 0:
30+
print("build failed!!!!")
31+
print("build failed!!!!")
32+
print("build failed!!!!")
33+
print("build failed!!!!")
34+
continue
3435

36+
os.system(f'python {speedtest_copy_path} -n "{name}" -d 4 -t 1')
37+
os.system(f'python {speedtest_copy_path} -n "{name}" -d 64 -t 1')
38+
os.system(f'python {speedtest_copy_path} -n "{name}" -d 128 -t 1')
39+
os.system(f'python {speedtest_copy_path} -n "{name}" -d 4 -t 24')
40+
os.system(f'python {speedtest_copy_path} -n "{name}" -d 128 -t 24')

examples/updates_test.cpp

+32-28
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "../hnswlib/hnswlib.h"
22
#include <thread>
3+
4+
35
class StopW
46
{
57
std::chrono::steady_clock::time_point time_begin;
@@ -22,6 +24,7 @@ class StopW
2224
}
2325
};
2426

27+
2528
/*
2629
* replacement for the openmp '#pragma omp parallel for' directive
2730
* only handles a subset of functionality (no reductions etc)
@@ -81,8 +84,6 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
8184
std::rethrow_exception(lastException);
8285
}
8386
}
84-
85-
8687
}
8788

8889

@@ -94,7 +95,7 @@ std::vector<datatype> load_batch(std::string path, int size)
9495
assert(sizeof(datatype) == 4);
9596

9697
std::ifstream file;
97-
file.open(path);
98+
file.open(path, std::ios::binary);
9899
if (!file.is_open())
99100
{
100101
std::cout << "Cannot open " << path << "\n";
@@ -107,15 +108,14 @@ std::vector<datatype> load_batch(std::string path, int size)
107108
return batch;
108109
}
109110

111+
110112
template <typename d_type>
111113
static float
112114
test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
113115
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K)
114116
{
115117
size_t correct = 0;
116118
size_t total = 0;
117-
//uncomment to test in parallel mode:
118-
119119

120120
for (int i = 0; i < qsize; i++)
121121
{
@@ -137,10 +137,16 @@ test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<
137137
return 1.0f * correct / total;
138138
}
139139

140+
140141
static void
141-
test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<float> &appr_alg, size_t vecdim,
142-
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
143-
{
142+
test_vs_recall(
143+
std::vector<float> &queries,
144+
size_t qsize,
145+
hnswlib::HierarchicalNSW<float> &appr_alg,
146+
size_t vecdim,
147+
std::vector<std::unordered_set<hnswlib::labeltype>> &answers,
148+
size_t k) {
149+
144150
std::vector<size_t> efs = {1};
145151
for (int i = k; i < 30; i++)
146152
{
@@ -155,6 +161,8 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
155161
efs.push_back(i);
156162
}
157163
std::cout << "ef\trecall\ttime\thops\tdistcomp\n";
164+
165+
bool test_passed = false;
158166
for (size_t ef : efs)
159167
{
160168
appr_alg.setEf(ef);
@@ -171,20 +179,24 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
171179
std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
172180
if (recall > 0.99)
173181
{
182+
test_passed = true;
174183
std::cout << "Recall is over 0.99! "<<recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
175184
break;
176185
}
177186
}
187+
if (!test_passed)
188+
{
189+
std::cerr << "Test failed\n";
190+
exit(1);
191+
}
178192
}
179193

194+
180195
int main(int argc, char **argv)
181196
{
182-
183197
int M = 16;
184198
int efConstruction = 200;
185199
int num_threads = std::thread::hardware_concurrency();
186-
187-
188200

189201
bool update = false;
190202

@@ -207,7 +219,6 @@ int main(int argc, char **argv)
207219

208220
std::string path = "../examples/data/";
209221

210-
211222
int N;
212223
int dummy_data_multiplier;
213224
int N_queries;
@@ -216,8 +227,7 @@ int main(int argc, char **argv)
216227
{
217228
std::ifstream configfile;
218229
configfile.open(path + "/config.txt");
219-
if (!configfile.is_open())
220-
{
230+
if (!configfile.is_open()) {
221231
std::cout << "Cannot open config.txt\n";
222232
return 1;
223233
}
@@ -237,11 +247,9 @@ int main(int argc, char **argv)
237247

238248
StopW stopw = StopW();
239249

240-
if (update)
241-
{
250+
if (update) {
242251
std::cout << "Update iteration 0\n";
243252

244-
245253
ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
246254
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
247255
});
@@ -252,13 +260,12 @@ int main(int argc, char **argv)
252260
});
253261
appr_alg.checkIntegrity();
254262

255-
for (int b = 1; b < dummy_data_multiplier; b++)
256-
{
263+
for (int b = 1; b < dummy_data_multiplier; b++) {
257264
std::cout << "Update iteration " << b << "\n";
258265
char cpath[1024];
259266
sprintf(cpath, "batch_dummy_%02d.bin", b);
260267
std::vector<float> dummy_batchb = load_batch<float>(path + cpath, N * d);
261-
268+
262269
ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
263270
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
264271
});
@@ -268,7 +275,7 @@ int main(int argc, char **argv)
268275

269276
std::cout << "Inserting final elements\n";
270277
std::vector<float> final_batch = load_batch<float>(path + "batch_final.bin", N * d);
271-
278+
272279
stopw.reset();
273280
ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
274281
appr_alg.addPoint((void *)(final_batch.data() + i * d), i);
@@ -280,19 +287,16 @@ int main(int argc, char **argv)
280287
std::vector<int> gt = load_batch<int>(path + "gt.bin", N_queries * K);
281288

282289
std::vector<std::unordered_set<hnswlib::labeltype>> answers(N_queries);
283-
for (int i = 0; i < N_queries; i++)
284-
{
285-
for (int j = 0; j < K; j++)
286-
{
290+
for (int i = 0; i < N_queries; i++) {
291+
for (int j = 0; j < K; j++) {
287292
answers[i].insert(gt[i * K + j]);
288293
}
289294
}
290295

291-
for (int i = 0; i < 3; i++)
292-
{
296+
for (int i = 0; i < 3; i++) {
293297
std::cout << "Test iteration " << i << "\n";
294298
test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K);
295299
}
296300

297301
return 0;
298-
};
302+
};

0 commit comments

Comments
 (0)