Skip to content

Commit f741289

Browse files
authored
added cache version for nuphar JIT binaries (#2646)
* Added cache version for Nuphar JIT binaries. Previously, when a user mistakenly loaded a JIT binary generated by a Nuphar version different from the one currently in use, they would get mysterious runtime failures, because no version check was performed on JIT binaries. This change adds cache versions to the Nuphar runtime and to JIT binaries. The Nuphar runtime now issues a descriptive error message that informs the user of a version mismatch. * Address CR feedback. * Include NUPHAR_CACHE_VERSION in the Python wheel.
1 parent 7c87070 commit f741289

File tree

9 files changed

+123
-23
lines changed

9 files changed

+123
-23
lines changed

cmake/onnxruntime_python.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ endif()
251251

252252
if (onnxruntime_USE_NUPHAR)
253253
file(GLOB onnxruntime_python_nuphar_python_srcs CONFIGURE_DEPENDS
254-
"${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*.*"
254+
"${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*"
255255
)
256256
add_custom_command(
257257
TARGET onnxruntime_pybind11_state POST_BUILD

onnxruntime/core/providers/nuphar/common/nuphar_settings.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ static const std::unordered_set<std::string> valid_keys = {
3434
kNupharIMatMulForceMkl,
3535
kNupharMatmulExec,
3636
kNupharCachePath,
37-
kNupharCacheVersion,
3837
kNupharCacheSoName,
3938
kNupharCacheModelChecksum,
4039
kNupharCacheForceNoJIT,

onnxruntime/core/providers/nuphar/common/nuphar_settings.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ constexpr static const char* kNupharDumpPartition = "nuphar_dump_partition";
1414
constexpr static const char* kNupharDumpFusedNodes = "nuphar_dump_fused_nodes";
1515
constexpr static const char* kNupharMatmulExec = "nuphar_matmul_exec";
1616
constexpr static const char* kNupharCachePath = "nuphar_cache_path";
17-
constexpr static const char* kNupharCacheVersion = "nuphar_cache_version";
1817
constexpr static const char* kNupharCacheSoName = "nuphar_cache_so_name";
1918
constexpr static const char* kNupharCacheModelChecksum = "nuphar_cache_model_checksum";
2019
constexpr static const char* kNupharCacheForceNoJIT = "nuphar_cache_force_no_jit";
@@ -48,13 +47,6 @@ constexpr static const char* kNupharCodeGenTarget = "nuphar_codegen_target";
4847
// Option to control nuphar code to run with parallel schedule
4948
constexpr static const char* kNupharParallelMinWorkloads = "nuphar_parallel_min_workloads";
5049

51-
// cache version number (MAJOR.MINOR.PATCH) following https://semver.org/
52-
// 1. MAJOR version when you make incompatible changes that old cache files no longer work,
53-
// 2. MINOR version when you add functionality in a backwards - compatible manner, and
54-
// 3. PATCH version when you make backwards - compatible bug fixes.
55-
// NOTE this version needs to be updated when generated code may change
56-
constexpr static const char* kNupharCacheVersion_Current = "1.0.0";
57-
5850
constexpr static const char* kNupharCacheSoName_Default = "jit.so";
5951

6052
void CreateNupharCodeGenSettings(const NupharExecutionProviderInfo& info);

onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
#include "core/common/logging/logging.h"
1212
#include "core/platform/env.h"
1313
#include "core/providers/common.h"
14+
#include "core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION"
1415
#include "gsl/gsl"
1516
#include <topi/detail/extern.h>
1617
#include <tvm/ir_pass.h>
1718
#include <experimental/filesystem>
19+
#include <atomic>
1820
#include <fstream>
1921
namespace fs = std::experimental::filesystem;
2022

@@ -27,13 +29,6 @@ static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) {
2729
if (!settings.HasOption(kNupharCachePath))
2830
return false;
2931

30-
std::string version;
31-
if (settings.HasOption(kNupharCacheVersion)) {
32-
version = settings.GetOptionValue(kNupharCacheVersion);
33-
} else {
34-
version = kNupharCacheVersion_Current;
35-
}
36-
3732
path = settings.GetOptionValue(kNupharCachePath);
3833
if (!create && !fs::is_directory(path))
3934
return false;
@@ -43,7 +38,7 @@ static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) {
4338
throw std::runtime_error("Failed to create directory " + path.string());
4439
}
4540

46-
path.append(version);
41+
path.append(__NUPHAR_CACHE_VERSION__);
4742
if (!create && !fs::is_directory(path))
4843
return false;
4944

@@ -80,6 +75,63 @@ static void* GetFuncFromLibrary(const std::string& so_path, const std::string& f
8075
return func;
8176
}
8277

78+
static void ParseVersion(const char* version, int* major, int* minor, int* patch) {
79+
std::stringstream ss(version);
80+
std::string val;
81+
82+
auto ver_num_fn = [](const std::string& val) {
83+
ORT_ENFORCE(!val.empty(), "Empty version number");
84+
if (val.length() > 1 && val[0] == '0') {
85+
ORT_THROW("Invalid version number: ", val);
86+
}
87+
ORT_ENFORCE(std::all_of(val.begin(), val.end(), [](char c) { return isdigit(c); }),
88+
"Invalid version number: ", val);
89+
return std::stoi(val);
90+
};
91+
92+
std::getline(ss, val, '.');
93+
ORT_ENFORCE(ss.good(), "Invalid version format: ", version);
94+
*major = ver_num_fn(val);
95+
96+
std::getline(ss, val, '.');
97+
*minor = ver_num_fn(val);
98+
99+
std::getline(ss, val);
100+
*patch = ver_num_fn(val);
101+
}
102+
103+
static void VerifyCacheVersion(const std::string& so_path) {
104+
static std::atomic<bool> cache_version_checked{false};
105+
static std::mutex cache_version_mutex;
106+
107+
// make sure we only check cache version once
108+
if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) {
109+
std::lock_guard<std::mutex> lock(cache_version_mutex);
110+
if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) {
111+
cache_version_checked.store(true, std::memory_order::memory_order_release);
112+
// ensure we have _ORTInternal_GetCacheVersion_ function
113+
void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true);
114+
ORT_ENFORCE(f, "NULL library function pointer!");
115+
116+
typedef const char* (*GetVersionFunc)();
117+
GetVersionFunc func = reinterpret_cast<GetVersionFunc>(f);
118+
const char* cache_version = func();
119+
ORT_ENFORCE(cache_version, "Null cache version string!");
120+
int cur_major, cur_minor, cur_patch;
121+
ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch);
122+
int cache_major, cache_minor, cache_patch;
123+
ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch);
124+
125+
// make version check strict until we have thorough design for compatibility
126+
ORT_ENFORCE((cur_major == cache_major) && (cur_minor == cache_minor),
127+
"Current nuphar runtime version (", __NUPHAR_CACHE_VERSION__,
128+
") doesn't match cached dll version (", cache_version, ")");
129+
130+
cache_version_checked = true;
131+
}
132+
}
133+
}
134+
83135
static bool disable_caching_due_to_checksum_failure = false;
84136

85137
static bool VerifyTVMModuleChecksum(const std::string& so_path) {
@@ -131,6 +183,8 @@ tvm::runtime::PackedFunc LoadTVMPackedFuncFromCache(const std::string& func_name
131183
if (!GetCacheSoFilePath(so_path))
132184
return nullptr;
133185

186+
VerifyCacheVersion(so_path);
187+
134188
if (!VerifyTVMModuleChecksum(so_path))
135189
return nullptr;
136190

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// cache version number (MAJOR.MINOR.PATCH) following https://semver.org/
2+
// 1. MAJOR version when you make incompatible changes such that old cache files no longer work,
3+
// 2. MINOR version when you add functionality in a backwards-compatible manner, and
4+
// 3. PATCH version when you make backwards-compatible bug fixes.
5+
// NOTE this version needs to be updated when generated code may change
6+
7+
#ifndef __NUPHAR_CACHE_VERSION__
8+
#define __NUPHAR_CACHE_VERSION__ "2.3.0"
9+
#endif

onnxruntime/core/providers/nuphar/scripts/create_shared.cmd

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ setlocal EnableDelayedExpansion
66

77
if "%1"=="" goto Usage
88

9+
set SCRIPT_DIR=%~dp0
910
set CACHE_DIR=%~f1
1011
set MODEL_FILE=%~f2
1112

@@ -46,6 +47,16 @@ echo __declspec(dllexport) >>%CHECKSUM_CC%
4647
echo void _ORTInternal_GetCheckSum(const char*^& cs, size_t^& len) { >> %CHECKSUM_CC%
4748
echo cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;} >>%CHECKSUM_CC%
4849

50+
REM generate cache version
51+
set CACHE_VERSION_CC=%CACHE_DIR%\cache_version.cc
52+
set VERSION_FILE=%SCRIPT_DIR%NUPHAR_CACHE_VERSION
53+
echo Generating %CACHE_VERSION_CC%...
54+
echo #include "%VERSION_FILE%" >%CACHE_VERSION_CC%
55+
echo extern "C" >>%CACHE_VERSION_CC%
56+
echo __declspec(dllexport) >>%CACHE_VERSION_CC%
57+
echo const char* _ORTInternal_GetCacheVersion() { >> %CACHE_VERSION_CC%
58+
echo return __NUPHAR_CACHE_VERSION__;} >>%CACHE_VERSION_CC%
59+
4960
:Compile
5061
cd /d %CACHE_DIR%
5162
for /f %%i in ('dir /b *.cc') do (
@@ -61,4 +72,4 @@ exit /b
6172
:Usage
6273
echo Usage: %0 cache_dir [model_file] [output_dll]
6374
echo The generated file would be cache_dir\output_dll
64-
exit /b
75+
exit /b

onnxruntime/core/providers/nuphar/scripts/create_shared.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ def gen_checksum(file_checksum, input_dir):
3838
print(' cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;', file=checksum_cc)
3939
print('}', file=checksum_cc)
4040

41+
def gen_cache_version(input_dir):
    """Generate ORTInternal_cache_version.cc inside input_dir.

    The generated C++ source includes the NUPHAR_CACHE_VERSION file that sits
    next to this script and defines an extern "C" function,
    _ORTInternal_GetCacheVersion, returning __NUPHAR_CACHE_VERSION__. On
    Windows the function is additionally marked __declspec(dllexport).
    """
    cc_path = os.path.join(input_dir, 'ORTInternal_cache_version' + '.cc')
    with open(cc_path, 'w') as cache_version_cc:
        script_dir = os.path.dirname(os.path.realpath(__file__))
        version_header = os.path.join(script_dir, 'NUPHAR_CACHE_VERSION')

        source_lines = ['#include "{}"'.format(version_header), 'extern "C"']
        if is_windows():
            source_lines.append('__declspec(dllexport)')
        source_lines += [
            'const char* _ORTInternal_GetCacheVersion() {',
            ' return __NUPHAR_CACHE_VERSION__;',
            '}',
        ]

        # One print per line, so every line ends with a newline as before.
        for text in source_lines:
            print(text, file=cache_version_cc)
52+
4153
def compile_all_cc(path):
4254
for f in os.listdir(path):
4355
name, ext = os.path.splitext(f)
@@ -65,6 +77,8 @@ def parse_arguments():
6577
input_checksum = gen_md5(args.input_model)
6678
gen_checksum(input_checksum, args.input_dir)
6779

80+
gen_cache_version(args.input_dir)
81+
6882
if is_windows():
6983
# create dllmain
7084
name = 'ORTInternal_dllmain'
@@ -85,4 +99,4 @@ def parse_arguments():
8599

86100
if not args.keep_input:
87101
for f in objs:
88-
os.remove(os.path.join(args.input_dir, f))
102+
os.remove(os.path.join(args.input_dir, f))

onnxruntime/core/providers/nuphar/scripts/create_shared.sh

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
set -x -e -o pipefail
66

7+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
8+
79
function usage {
810
echo Usage: create_shared.sh -c cache_dir -m input_model_file -o output_so_file
911
echo The generated file would be cache_dir/output_so_file
@@ -34,6 +36,8 @@ if ! [ -x "$(command -v g++)" ]; then
3436
exit 1
3537
fi
3638

39+
declare -a all_cc_files
40+
3741
cd $CACHE_DIR
3842
if [ -x "$MODEL_FILE" ]; then
3943
# generate checksum.cc
@@ -46,10 +50,25 @@ void _ORTInternal_GetCheckSum(const char*& cs, size_t& len) {
4650
cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;
4751
}
4852
__EOF__
49-
g++ -std=c++14 -fPIC -o checksum.o -c checksum.cc
50-
rm checksum.cc
53+
all_cc_files+=(checksum)
5154
fi
5255

56+
# generate cache_version.cc
57+
VERSION_FILE="${SCRIPT_DIR}/NUPHAR_CACHE_VERSION"
58+
cat > $CACHE_DIR/cache_version.cc <<__EOF__
59+
#include "$VERSION_FILE"
60+
extern "C"
61+
const char* _ORTInternal_GetCacheVersion() {
62+
return __NUPHAR_CACHE_VERSION__;
63+
}
64+
__EOF__
65+
all_cc_files+=(cache_version)
66+
67+
for cc_file in "${all_cc_files[@]}"; do
68+
g++ -std=c++14 -fPIC -o "$cc_file".o -c "$cc_file".cc
69+
rm "$cc_file".cc
70+
done
71+
5372
# link
5473
if ls *.o 1> /dev/null 2>&1; then
5574
OBJS=""
@@ -61,4 +80,4 @@ if ls *.o 1> /dev/null 2>&1; then
6180
g++ -shared -fPIC -o $CACHE_DIR/$OUTPUT_SO_FILE $OBJS
6281
fi
6382
rm *.o
64-
fi
83+
fi

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ def run(self):
165165

166166
# Extra files such as EULA and ThirdPartyNotices
167167
extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md"]
168+
if package_name == 'onnxruntime-nuphar':
169+
extra.extend([path.join('nuphar', 'NUPHAR_CACHE_VERSION')])
168170

169171
# Description
170172
README = path.join(getcwd(), "docs/python/README.rst")

0 commit comments

Comments
 (0)