@@ -116,12 +116,12 @@ Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchema
      // TODO: Check if we can upgrade all the current opset 6 models that are being tested
      // in CI to opset 7 or above
      LOGS(logger, WARNING) << "ONNX Runtime only *guarantees* support for models stamped "
-                               "with opset version 7 or above for opset domain 'ai.onnx'. "
-                               "Please upgrade your model to opset 7 or higher. "
-                               "For now, this opset "
-                            << version
-                            << " model may run depending upon legacy support "
-                               "of some older opset version operators.";
+                               "with opset version 7 or above for opset domain 'ai.onnx'. "
+                               "Please upgrade your model to opset 7 or higher. "
+                               "For now, this opset "
+                            << version
+                            << " model may run depending upon legacy support "
+                               "of some older opset version operators.";
    }

    // We need to overwrite the domain here with ("") or else the loop below will try to find ("")
    // in the map and if not found (when domain == kOnnxDomainAlias), adds an entry for ("", 11).
@@ -284,10 +284,8 @@ Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Mo
  return Status::OK();
}

-template <typename T>
-static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
-                        const IOnnxRuntimeOpSchemaRegistryList* local_registries,
-                        const logging::Logger& logger) {
+template <typename T, typename Loader>
+static Status LoadModelHelper(const T& file_path, Loader loader) {
  int fd;
  Status status = Env::Default().FileOpenRd(file_path, fd);
  if (!status.IsOK()) {
@@ -304,8 +302,8 @@ static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
    }
  }
  try {
-    status = Model::Load(fd, p_model, local_registries, logger);
-  } catch (std::exception& ex) {
+    status = loader(fd);
+  } catch (const std::exception& ex) {
    GSL_SUPPRESS(es.84)
    ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
    return Status(ONNXRUNTIME, FAIL, ex.what());
@@ -318,14 +316,34 @@ static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
  return Env::Default().FileClose(fd);
}

+template <typename T>
+static Status LoadModel(const T& file_path, ONNX_NAMESPACE::ModelProto& model_proto) {
+  const auto loader = [&model_proto](int fd) {
+    return Model::Load(fd, model_proto);
+  };
+
+  return LoadModelHelper(file_path, loader);
+}
+
+template <typename T>
+static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
+                        const IOnnxRuntimeOpSchemaRegistryList* local_registries,
+                        const logging::Logger& logger) {
+  const auto loader = [&p_model, local_registries, &logger](int fd) {
+    return Model::Load(fd, p_model, local_registries, logger);
+  };
+
+  return LoadModelHelper(file_path, loader);
+}
+
template <typename T>
static Status SaveModel(Model& model, const T& file_path) {
  int fd;
  Status status = Env::Default().FileOpenWr(file_path, fd);
  ORT_RETURN_IF_ERROR(status);
  try {
    status = Model::Save(model, fd);
-  } catch (std::exception& ex) {
+  } catch (const std::exception& ex) {
    GSL_SUPPRESS(es.84)
    ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
    return Status(ONNXRUNTIME, FAIL, ex.what());
@@ -344,6 +362,11 @@ Status Model::Save(Model& model, const std::wstring& file_path) {
}
#endif

+Status Model::Load(const std::basic_string<ORTCHAR_T>& file_path,
+                   ONNX_NAMESPACE::ModelProto& model_proto) {
+  return LoadModel(file_path, model_proto);
+}
+
GSL_SUPPRESS(r.30)  // spurious warnings. p_model is potentially reset in the internal call to Load
GSL_SUPPRESS(r.35)
Status Model::Load(const std::basic_string<ORTCHAR_T>& file_path, std::shared_ptr<Model>& p_model,
@@ -356,15 +379,25 @@ Status Model::Save(Model& model, const std::string& file_path) {
  return SaveModel(model, file_path);
}

-Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model,
-                            const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) {
-  std::unique_ptr<ModelProto> modelProto = onnxruntime::make_unique<ModelProto>();
-  const bool result = modelProto->ParseFromArray(p_bytes, count);
+Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
+  const bool result = model_proto.ParseFromArray(p_bytes, count);
  if (!result) {
    return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
  }

-  p_model = std::make_shared<Model>(std::move(modelProto), local_registries, logger);
+  return Status::OK();
+}
+
+Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model,
+                            const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) {
+  ModelProto model_proto;
+
+  auto status = LoadFromBytes(count, p_bytes, model_proto);
+  if (!status.IsOK()) {
+    return status;
+  }
+
+  p_model = std::make_shared<Model>(model_proto, local_registries, logger);

  ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true));
@@ -375,16 +408,14 @@ using ::google::protobuf::io::CodedInputStream;
using ::google::protobuf::io::FileInputStream;
using ::google::protobuf::io::ZeroCopyInputStream;

-Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
-                   const logging::Logger& logger) {
+Status Model::Load(int fd, ONNX_NAMESPACE::ModelProto& model_proto) {
  if (fd < 0) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "<p_fd> less than 0.");
  }

-  std::unique_ptr<ModelProto> model_proto = onnxruntime::make_unique<ModelProto>();
#if GOOGLE_PROTOBUF_VERSION >= 3002000
  FileInputStream fs(fd);
-  const bool result = model_proto->ParseFromZeroCopyStream(&fs) && fs.GetErrno() == 0;
+  const bool result = model_proto.ParseFromZeroCopyStream(&fs) && fs.GetErrno() == 0;
  if (!result) {
    return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
  }
@@ -402,7 +433,16 @@ Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOp
    return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
  }
#endif
-  p_model = std::make_shared<Model>(std::move(model_proto), local_registries, logger);
+  return Status::OK();
+}
+
+Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
+                   const logging::Logger& logger) {
+  ModelProto model_proto;
+
+  ORT_RETURN_IF_ERROR(Load(fd, model_proto));
+
+  p_model = std::make_shared<Model>(model_proto, local_registries, logger);

  ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true));
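
A minimal usage sketch (not part of the diff above) of the two entry points this change separates: the new ModelProto-only overloads parse the serialized model without constructing a Graph, while the existing overloads still build and resolve a full Model. The include path, the InspectOpsets/LoadFullModel names, and the printf output are illustrative assumptions, not code from this PR.

#include <cstdio>
#include "core/graph/model.h"  // assumed include path for onnxruntime::Model

using namespace onnxruntime;

// Sketch: parse only the ModelProto, e.g. to inspect opset imports cheaply.
common::Status InspectOpsets(const std::basic_string<ORTCHAR_T>& path) {
  ONNX_NAMESPACE::ModelProto model_proto;
  ORT_RETURN_IF_ERROR(Model::Load(path, model_proto));  // new overload added above
  for (const auto& opset : model_proto.opset_import()) {
    // domain() is empty for the default 'ai.onnx' domain
    std::printf("domain '%s' version %lld\n", opset.domain().c_str(),
                static_cast<long long>(opset.version()));
  }
  return common::Status::OK();
}

// Sketch: load the full Model (parses the proto, builds the Graph, and resolves it).
common::Status LoadFullModel(const std::basic_string<ORTCHAR_T>& path,
                             const logging::Logger& logger,
                             std::shared_ptr<Model>& p_model) {
  return Model::Load(path, p_model, /*local_registries*/ nullptr, logger);
}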