Skip to content

Commit f8a76a6

Browse files
authored
Fix Pipeline reference leak in PythonFunction. (#5668)
* Use a stub pipeline as PythonFunction's "current pipeline" to avoid pipeline self-referencing and self-deleting from within its ThreadPool. --------- Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent 988265a commit f8a76a6

File tree

5 files changed

+53
-31
lines changed

5 files changed

+53
-31
lines changed

dali/operators/python_function/python_function.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -20,7 +20,11 @@ namespace dali {
2020

2121
DALI_SCHEMA(PythonFunctionBase)
2222
.AddArg("function",
23-
"Function object.",
23+
R"code(A callable object that defines the function of the operator.
24+
25+
.. warning::
26+
The function must not hold a reference to the pipeline in which it is used. If it does,
27+
a circular reference to the pipeline will form and the pipeline will never be freed.)code",
2428
DALI_PYTHON_OBJECT)
2529
.AddOptionalArg("num_outputs", R"code(Number of outputs.)code", 1)
2630
.AddOptionalArg<std::vector<TensorLayout>>("output_layouts",
@@ -41,7 +45,7 @@ a more universal data format, see :meth:`nvidia.dali.fn.dl_tensor_python_functio
4145
The function should not modify input tensors.
4246
4347
.. warning::
44-
This operator is not compatible with TensorFlow integration.
48+
This operator is not compatible with TensorFlow integration.
4549
4650
.. warning::
4751
When the pipeline has conditional execution enabled, additional steps must be taken to

dali/python/nvidia/dali/ops/_operators/python_function.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ def __init__(self, function, num_outputs=1, **kwargs):
5151

5252
def __call__(self, *inputs, **kwargs):
5353
inputs = ops._preprocess_inputs(inputs, impl_name, self._device, None)
54-
self.pipeline = _Pipeline.current()
55-
if self.pipeline is None:
54+
curr_pipe = _Pipeline.current()
55+
if curr_pipe is None:
5656
_Pipeline._raise_pipeline_required("PythonFunction operator")
57+
self.pipeline = curr_pipe._stub()
5758

5859
for inp in inputs:
5960
if not isinstance(inp, _DataNode):

dali/python/nvidia/dali/pipeline.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from threading import local as tls
2727
from . import data_node as _data_node
2828
import atexit
29+
import copy
2930
import ctypes
3031
import functools
3132
import inspect
@@ -1764,6 +1765,33 @@ def _generate_build_args(self):
17641765
for (name, dev), dtype, ndim in zip(self._names_and_devices, dtypes, ndims)
17651766
]
17661767

1768+
def _stub(self):
1769+
"""Produce a stub by shallow-copying the pipeline, removing the backend and forbidding
1770+
operations that require the backend.
1771+
1772+
Stub pipelines are necessary in contexts where passing the actual pipeline would cause
1773+
circular reference - notably, PythonFunction operator.
1774+
"""
1775+
stub = copy.copy(self)
1776+
stub._pipe = None
1777+
1778+
def short_circuit(self, *args, **kwargs):
1779+
raise RuntimeError("This method is forbidden in current context")
1780+
1781+
stub.start_py_workers = short_circuit
1782+
stub.build = short_circuit
1783+
stub.run = short_circuit
1784+
stub.schedule_run = short_circuit
1785+
stub.outputs = short_circuit
1786+
stub.share_outputs = short_circuit
1787+
stub.release_outputs = short_circuit
1788+
stub.add_sink = short_circuit
1789+
stub.checkpoint = short_circuit
1790+
stub.set_outputs = short_circuit
1791+
stub.executor_statistics = short_circuit
1792+
stub.external_source_shm_statistics = short_circuit
1793+
return stub
1794+
17671795

17681796
def _shutdown_pipelines():
17691797
for weak in list(Pipeline._pipes):

dali/python/nvidia/dali/plugin/pytorch/_torch_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def torch_wrapper(self, batch_processing, function, device, *args):
5656
)
5757

5858
def __call__(self, *inputs, **kwargs):
59-
pipeline = Pipeline.current()
59+
pipeline = Pipeline.current()._stub()
6060
if pipeline is None:
6161
Pipeline._raise_pipeline_required("TorchPythonFunction")
6262
if self.stream is None:

dali/test/python/operator_2/test_python_function.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -649,30 +649,6 @@ def py_fun_pipeline():
649649
pipe.run()
650650

651651

652-
def verify_pipeline(pipeline, input):
653-
assert pipeline is Pipeline.current()
654-
return input
655-
656-
657-
def test_current_pipeline():
658-
pipe1 = Pipeline(13, 4, 0)
659-
with pipe1:
660-
dummy = types.Constant(numpy.ones((1)))
661-
output = fn.python_function(dummy, function=lambda inp: verify_pipeline(pipe1, inp))
662-
pipe1.set_outputs(output)
663-
664-
pipe2 = Pipeline(6, 2, 0)
665-
with pipe2:
666-
dummy = types.Constant(numpy.ones((1)))
667-
output = fn.python_function(dummy, function=lambda inp: verify_pipeline(pipe2, inp))
668-
pipe2.set_outputs(output)
669-
670-
pipe1.build()
671-
pipe2.build()
672-
pipe1.run()
673-
pipe2.run()
674-
675-
676652
@params(
677653
numpy.bool_,
678654
numpy.int_,
@@ -716,3 +692,16 @@ def test_pipe():
716692
pipe.build()
717693

718694
_ = pipe.run()
695+
696+
697+
def test_delete_pipe_while_function_running():
698+
def func(x):
699+
time.sleep(0.02)
700+
return x
701+
702+
for i in range(5):
703+
with Pipeline(batch_size=1, num_threads=1, device_id=None) as pipe:
704+
pipe.set_outputs(fn.python_function(types.Constant(0), function=func))
705+
pipe.build()
706+
pipe.run()
707+
del pipe

0 commit comments

Comments (0)