Skip to content

Commit abe3a08

Browse files
committed
Merge branch 'master' into revert-2485-fs-eire/force-no-thread-pool-if-built-with-openmp
2 parents a3f6fe1 + 5c2e474 commit abe3a08

21 files changed

+1101
-192
lines changed

.gitmodules

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,6 @@
4646
[submodule "cmake/external/wil"]
4747
path = cmake/external/wil
4848
url = https://github.com/microsoft/wil
49-
49+
[submodule "cmake/external/json"]
50+
path = cmake/external/json
51+
url = https://github.com/nlohmann/json

ThirdPartyNotices.txt

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3794,4 +3794,30 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
37943794
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
37953795
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37963796
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3797-
SOFTWARE
3797+
SOFTWARE
3798+
3799+
-----
3800+
3801+
nlohmann/json
3802+
3803+
MIT License
3804+
3805+
Copyright (c) 2013-2019 Niels Lohmann
3806+
3807+
Permission is hereby granted, free of charge, to any person obtaining a copy
3808+
of this software and associated documentation files (the "Software"), to deal
3809+
in the Software without restriction, including without limitation the rights
3810+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
3811+
copies of the Software, and to permit persons to whom the Software is
3812+
furnished to do so, subject to the following conditions:
3813+
3814+
The above copyright notice and this permission notice shall be included in all
3815+
copies or substantial portions of the Software.
3816+
3817+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
3818+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
3819+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
3820+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
3821+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
3822+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3823+
SOFTWARE.

cgmanifest.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,15 @@
437437
},
438438
"type": "git"
439439
}
440+
},
441+
{
442+
"component": {
443+
"git": {
444+
"commitHash": "d98bf0278d6f59a58271425963a8422ff48fe249",
445+
"repositoryUrl": "https://github.com/nlohmann/json.git"
446+
},
447+
"type": "git"
448+
}
440449
}
441450
],
442451
"Version": 1

cmake/external/json

Submodule json added at d98bf02

cmake/onnxruntime_session.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxrun
1515
if(onnxruntime_ENABLE_INSTRUMENT)
1616
target_compile_definitions(onnxruntime_session PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
1717
endif()
18-
target_include_directories(onnxruntime_session PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS})
18+
target_include_directories(onnxruntime_session PRIVATE ${ONNXRUNTIME_ROOT} ${PROJECT_SOURCE_DIR}/external/json ${eigen_INCLUDE_DIRS})
1919
add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES})
2020
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
2121
if (onnxruntime_USE_CUDA)

onnxruntime/core/graph/model.cc

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,12 @@ Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchema
116116
// TODO: Check if we can upgrade all the current opset 6 models that are being tested
117117
// in CI to opset 7 or above
118118
LOGS(logger, WARNING) << "ONNX Runtime only *guarantees* support for models stamped "
119-
"with opset version 7 or above for opset domain 'ai.onnx'. "
120-
"Please upgrade your model to opset 7 or higher. "
121-
"For now, this opset "
122-
<< version
123-
<< " model may run depending upon legacy support "
124-
"of some older opset version operators.";
119+
"with opset version 7 or above for opset domain 'ai.onnx'. "
120+
"Please upgrade your model to opset 7 or higher. "
121+
"For now, this opset "
122+
<< version
123+
<< " model may run depending upon legacy support "
124+
"of some older opset version operators.";
125125
}
126126
// We need to overwrite the domain here with ("") or else the loop below will try to find ("")
127127
// in the map and if not found (when domain == kOnnxDomainAlias), adds an entry for ("", 11).
@@ -284,10 +284,8 @@ Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Mo
284284
return Status::OK();
285285
}
286286

287-
template <typename T>
288-
static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
289-
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
290-
const logging::Logger& logger) {
287+
template <typename T, typename Loader>
288+
static Status LoadModelHelper(const T& file_path, Loader loader) {
291289
int fd;
292290
Status status = Env::Default().FileOpenRd(file_path, fd);
293291
if (!status.IsOK()) {
@@ -304,8 +302,8 @@ static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
304302
}
305303
}
306304
try {
307-
status = Model::Load(fd, p_model, local_registries, logger);
308-
} catch (std::exception& ex) {
305+
status = loader(fd);
306+
} catch (const std::exception& ex) {
309307
GSL_SUPPRESS(es .84)
310308
ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
311309
return Status(ONNXRUNTIME, FAIL, ex.what());
@@ -318,14 +316,34 @@ static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
318316
return Env::Default().FileClose(fd);
319317
}
320318

319+
template <typename T>
320+
static Status LoadModel(const T& file_path, ONNX_NAMESPACE::ModelProto& model_proto) {
321+
const auto loader = [&model_proto](int fd) {
322+
return Model::Load(fd, model_proto);
323+
};
324+
325+
return LoadModelHelper(file_path, loader);
326+
}
327+
328+
template <typename T>
329+
static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
330+
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
331+
const logging::Logger& logger) {
332+
const auto loader = [&p_model, local_registries, &logger](int fd) {
333+
return Model::Load(fd, p_model, local_registries, logger);
334+
};
335+
336+
return LoadModelHelper(file_path, loader);
337+
}
338+
321339
template <typename T>
322340
static Status SaveModel(Model& model, const T& file_path) {
323341
int fd;
324342
Status status = Env::Default().FileOpenWr(file_path, fd);
325343
ORT_RETURN_IF_ERROR(status);
326344
try {
327345
status = Model::Save(model, fd);
328-
} catch (std::exception& ex) {
346+
} catch (const std::exception& ex) {
329347
GSL_SUPPRESS(es .84)
330348
ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
331349
return Status(ONNXRUNTIME, FAIL, ex.what());
@@ -344,6 +362,11 @@ Status Model::Save(Model& model, const std::wstring& file_path) {
344362
}
345363
#endif
346364

365+
Status Model::Load(const std::basic_string<ORTCHAR_T>& file_path,
366+
ONNX_NAMESPACE::ModelProto& model_proto) {
367+
return LoadModel(file_path, model_proto);
368+
}
369+
347370
GSL_SUPPRESS(r .30) // spurious warnings. p_model is potentially reset in the internal call to Load
348371
GSL_SUPPRESS(r .35)
349372
Status Model::Load(const std::basic_string<ORTCHAR_T>& file_path, std::shared_ptr<Model>& p_model,
@@ -356,15 +379,25 @@ Status Model::Save(Model& model, const std::string& file_path) {
356379
return SaveModel(model, file_path);
357380
}
358381

359-
Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model,
360-
const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) {
361-
std::unique_ptr<ModelProto> modelProto = onnxruntime::make_unique<ModelProto>();
362-
const bool result = modelProto->ParseFromArray(p_bytes, count);
382+
Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
383+
const bool result = model_proto.ParseFromArray(p_bytes, count);
363384
if (!result) {
364385
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
365386
}
366387

367-
p_model = std::make_shared<Model>(std::move(modelProto), local_registries, logger);
388+
return Status::OK();
389+
}
390+
391+
Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model,
392+
const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) {
393+
ModelProto model_proto;
394+
395+
auto status = LoadFromBytes(count, p_bytes, model_proto);
396+
if (!status.IsOK()) {
397+
return status;
398+
}
399+
400+
p_model = std::make_shared<Model>(model_proto, local_registries, logger);
368401

369402
ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true));
370403

@@ -375,16 +408,14 @@ using ::google::protobuf::io::CodedInputStream;
375408
using ::google::protobuf::io::FileInputStream;
376409
using ::google::protobuf::io::ZeroCopyInputStream;
377410

378-
Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
379-
const logging::Logger& logger) {
411+
Status Model::Load(int fd, ONNX_NAMESPACE::ModelProto& model_proto) {
380412
if (fd < 0) {
381413
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "<p_fd> less than 0.");
382414
}
383415

384-
std::unique_ptr<ModelProto> model_proto = onnxruntime::make_unique<ModelProto>();
385416
#if GOOGLE_PROTOBUF_VERSION >= 3002000
386417
FileInputStream fs(fd);
387-
const bool result = model_proto->ParseFromZeroCopyStream(&fs) && fs.GetErrno() == 0;
418+
const bool result = model_proto.ParseFromZeroCopyStream(&fs) && fs.GetErrno() == 0;
388419
if (!result) {
389420
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
390421
}
@@ -402,7 +433,16 @@ Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOp
402433
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
403434
}
404435
#endif
405-
p_model = std::make_shared<Model>(std::move(model_proto), local_registries, logger);
436+
return Status::OK();
437+
}
438+
439+
Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
440+
const logging::Logger& logger) {
441+
ModelProto model_proto;
442+
443+
ORT_RETURN_IF_ERROR(Load(fd, model_proto));
444+
445+
p_model = std::make_shared<Model>(model_proto, local_registries, logger);
406446

407447
ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true));
408448

onnxruntime/core/graph/model.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class Model {
2626
explicit Model(const std::string& graph_name,
2727
bool is_onnx_domain_only,
2828
const logging::Logger& logger)
29-
:Model(graph_name,is_onnx_domain_only, ModelMetaData(),IOnnxRuntimeOpSchemaRegistryList(),{},{},
30-
logger){}
29+
: Model(graph_name, is_onnx_domain_only, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {}, {},
30+
logger) {}
3131

3232
// Construct model from scratch.
3333
explicit Model(const std::string& graph_name,
@@ -105,16 +105,25 @@ class Model {
105105

106106
static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto);
107107

108+
static common::Status Load(const std::basic_string<ORTCHAR_T>& file_path,
109+
/*out*/ ONNX_NAMESPACE::ModelProto& model_proto);
110+
108111
// TODO(Task:132) Use of shared_ptr<X>* in Load/Save methods is confusing.
109112
static common::Status Load(const std::basic_string<ORTCHAR_T>& file_path,
110113
/*out*/ std::shared_ptr<Model>& p_model,
111114
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
112115
const logging::Logger& logger);
113116

117+
static common::Status Load(int fd, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto);
118+
114119
static common::Status Load(int fd, /*out*/ std::shared_ptr<Model>& p_model,
115120
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
116121
const logging::Logger& logger);
117122

123+
// 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
124+
static common::Status LoadFromBytes(int count, void* pBytes,
125+
/*out*/ ONNX_NAMESPACE::ModelProto& model_proto);
126+
118127
// 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
119128
static common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr<Model>& p_model,
120129
const IOnnxRuntimeOpSchemaRegistryList* local_registries,

onnxruntime/core/platform/env.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,12 @@ class Env {
157157
// \brief returns a provider that will handle telemetry on the current platform
158158
virtual const Telemetry& GetTelemetryProvider() const = 0;
159159

160+
// \brief returns a value for the queried variable name (var_name)
161+
//
162+
// Returns the corresponding value stored in the environment variable if available
163+
// Returns empty string if there is no such environment variable available
164+
virtual std::string GetEnvironmentVar(const std::string& var_name) const = 0;
165+
160166
protected:
161167
Env();
162168

onnxruntime/core/platform/posix/env.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ limitations under the License.
2727
#include <dlfcn.h>
2828
#include <string.h>
2929
#include <thread>
30-
#include <utility> // for std::forward
30+
#include <utility> // for std::forward
3131
#include <vector>
3232
#include <assert.h>
3333

@@ -74,7 +74,7 @@ using ScopedFileDescriptor = ScopedResource<FileDescriptorTraits>;
7474

7575
// non-macro equivalent of TEMP_FAILURE_RETRY, described here:
7676
// https://www.gnu.org/software/libc/manual/html_node/Interrupted-Primitives.html
77-
template<typename TFunc, typename... TFuncArgs>
77+
template <typename TFunc, typename... TFuncArgs>
7878
long int TempFailureRetry(TFunc retriable_operation, TFuncArgs&&... args) {
7979
long int result;
8080
do {
@@ -216,8 +216,8 @@ class PosixEnv : public Env {
216216
}
217217

218218
mapped_memory = MappedMemoryPtr{
219-
reinterpret_cast<char*>(mapped_base) + offset_to_page,
220-
OrtCallbackInvoker{OrtCallback{UnmapFile, new UnmapFileParam{mapped_base, mapped_length}}}};
219+
reinterpret_cast<char*>(mapped_base) + offset_to_page,
220+
OrtCallbackInvoker{OrtCallback{UnmapFile, new UnmapFileParam{mapped_base, mapped_length}}}};
221221

222222
return Status::OK();
223223
}
@@ -318,6 +318,12 @@ class PosixEnv : public Env {
318318
return telemetry_provider_;
319319
}
320320

321+
// \brief returns a value for the queried variable name (var_name)
322+
std::string GetEnvironmentVar(const std::string& var_name) const override {
323+
char* val = getenv(var_name.c_str());
324+
return val == NULL ? std::string() : std::string(val);
325+
}
326+
321327
private:
322328
PosixEnv() = default;
323329
Telemetry telemetry_provider_;

onnxruntime/core/platform/windows/env.cc

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class WindowsEnv : public Env {
129129

130130
size_t total_bytes_read = 0;
131131
while (total_bytes_read < length) {
132-
constexpr DWORD k_max_bytes_to_read = 1 << 30; // read at most 1GB each time
132+
constexpr DWORD k_max_bytes_to_read = 1 << 30; // read at most 1GB each time
133133
const size_t bytes_remaining = length - total_bytes_read;
134134
const DWORD bytes_to_read = static_cast<DWORD>(std::min<size_t>(bytes_remaining, k_max_bytes_to_read));
135135
DWORD bytes_read;
@@ -227,6 +227,33 @@ class WindowsEnv : public Env {
227227
return telemetry_provider_;
228228
}
229229

230+
// \brief returns a value for the queried variable name (var_name)
231+
std::string GetEnvironmentVar(const std::string& var_name) const override {
232+
// Why getenv() should be avoided on Windows:
233+
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv
234+
// Instead use the Win32 API: GetEnvironmentVariableA()
235+
236+
// Max limit of an environment variable on Windows including the null-terminating character
237+
constexpr DWORD kBufferSize = 32767;
238+
239+
// Create buffer to hold the result
240+
char buffer[kBufferSize];
241+
242+
auto char_count = GetEnvironmentVariableA(var_name.c_str(), buffer, kBufferSize);
243+
244+
// Will be > 0 if the API call was successful
245+
if (char_count) {
246+
return std::string(buffer, buffer + char_count);
247+
}
248+
249+
// TODO: Understand the reason for failure by calling GetLastError().
250+
// If it is due to the specified environment variable being found in the environment block,
251+
// GetLastError() returns ERROR_ENVVAR_NOT_FOUND.
252+
// For now, we assume that the environment variable is not found.
253+
254+
return std::string();
255+
}
256+
230257
private:
231258
WindowsEnv()
232259
: GetSystemTimePreciseAsFileTime_(nullptr) {

0 commit comments

Comments
 (0)