Skip to content

Commit 233f49d

Browse files
committed
Expose GetSessionOptions in pybind logic and add unit test for python
1 parent 14ee2f7 commit 233f49d

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,9 @@ including arg name, arg type (contains both type and shape).)pbdoc")
779779
.def("get_providers", [](InferenceSession* sess) -> const std::vector<std::string>& {
780780
return sess->GetRegisteredProviderTypes();
781781
})
782+
.def_property_readonly("session_options", [](InferenceSession* sess) -> const SessionOptions& {
783+
return sess->GetSessionOptions();
784+
})
782785
.def_property_readonly("inputs_meta", [](const InferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
783786
auto res = sess->GetModelInputs();
784787
OrtPybindThrowIfError(res.first);

onnxruntime/python/session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _load_model(self, providers=[]):
4040

4141
self._sess.load_model(providers)
4242

43+
self._session_options = self._sess.session_options
4344
self._inputs_meta = self._sess.inputs_meta
4445
self._outputs_meta = self._sess.outputs_meta
4546
self._overridable_initializers = self._sess.overridable_initializers
@@ -63,6 +64,10 @@ def _reset_session(self):
6364
self._providers = None
6465
self._sess = None
6566

67+
def get_session_options(self):
68+
"Return the session options. See :class:`onnxruntime.SessionOptions`."
69+
return self._session_options
70+
6671
def get_inputs(self):
6772
"Return the inputs metadata as a list of :class:`onnxruntime.NodeArg`."
6873
return self._inputs_meta

onnxruntime/test/python/onnxruntime_test_python.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,5 +594,28 @@ def testOrtExecutionMode(self):
594594
opt.execution_mode = onnxrt.ExecutionMode.ORT_PARALLEL
595595
self.assertEqual(opt.execution_mode, onnxrt.ExecutionMode.ORT_PARALLEL)
596596

597+
def testLoadingSessionOptionsFromModel(self):
598+
try:
599+
os.environ['ORT_LOAD_CONFIG_FROM_MODEL'] = str(1)
600+
sess = onnxrt.InferenceSession(self.get_name("model_with_valid_ort_config_json.onnx"))
601+
session_options = sess.get_session_options()
602+
603+
self.assertEqual(session_options.inter_op_num_threads, 5)
604+
605+
self.assertEqual(session_options.intra_op_num_threads, 2)
606+
607+
self.assertEqual(session_options.execution_mode, onnxrt.ExecutionMode.ORT_PARALLEL)
608+
609+
self.assertEqual(session_options.graph_optimization_level, 3)
610+
611+
self.assertEqual(session_options.enable_profiling, True)
612+
613+
except Exception:
614+
raise
615+
616+
finally:
617+
# Make sure the usage of the feature is disabled after this test
618+
os.environ['ORT_LOAD_CONFIG_FROM_MODEL'] = str(0)
619+
597620
if __name__ == '__main__':
598621
unittest.main()

0 commit comments

Comments
 (0)