Commit 2d07f81

JAX zero copy (#5703)
* Add explicit copy argument to `_to_jax_array`.
* Don't copy when using `exec_dynamic`.
* Fix JAX array device handling for various versions of JAX.

---------

Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent a2d1bb1 commit 2d07f81
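
For orientation, a minimal usage sketch of what this commit enables (the constant pipeline body, the shapes, and `size=16` are illustrative assumptions, not part of the commit): with `exec_dynamic=True`, the JAX iterator requests a zero-copy conversion, so the returned `jax.Array` may wrap DALI's GPU buffer through DLPack instead of deep-copying it.

    import numpy as np
    from nvidia.dali import pipeline_def, types
    from nvidia.dali.plugin.jax import DALIGenericIterator


    @pipeline_def(batch_size=4, num_threads=2, device_id=0, exec_dynamic=True)
    def pipeline():
        # Constant GPU output; a real pipeline would decode and augment data here.
        return types.Constant(value=np.full((2, 2), 42, np.int32), device="gpu")


    # With exec_dynamic=True the iterator passes copy=False down to _to_jax_array
    # (see iterator.py below); with the legacy executor it still copies.
    it = DALIGenericIterator([pipeline()], ["data"], size=16)
    for batch in it:
        jax_array = batch["data"]  # jax.Array backed by the pipeline's output buffer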

10 files changed: +99 / -61 lines


dali/python/nvidia/dali/plugin/jax/clu.py

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ class DALIGenericPeekableIterator(DALIGenericIterator):
         is called internally automatically.
     last_batch_policy: optional, default = LastBatchPolicy.FILL
         What to do with the last batch when there are not enough samples in the epoch
-        to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`
+        to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`.
         JAX iterator does not support LastBatchPolicy.PARTIAL
     last_batch_padded : bool, optional, default = False
         Whether the last batch provided by DALI is padded with the last sample

dali/python/nvidia/dali/plugin/jax/integration.py

Lines changed: 32 additions & 17 deletions

@@ -16,31 +16,46 @@
 import jax.dlpack
 
 from nvidia.dali.backend import TensorGPU
+from packaging.version import Version
 
 
-_jax_version_pre_0_4_16 = None
+_jax_has_old_dlpack = Version(jax.__version__) < Version("0.4.16")
 
 
-def _jax_has_old_dlpack():
-    global _jax_version_pre_0_4_16
-    if _jax_version_pre_0_4_16 is not None:
-        return _jax_version_pre_0_4_16
+if Version(jax.__version__) >= Version("0.4.31"):
 
-    from packaging.version import Version
+    def _jax_device(jax_array):
+        return jax_array.device
 
-    _jax_version_pre_0_4_16 = Version(jax.__version__) < Version("0.4.16")
-    return _jax_version_pre_0_4_16
+elif Version(jax.__version__) >= Version("0.4.27"):
 
+    def _jax_device(jax_array):
+        devs = jax_array.devices()
+        if len(devs) != 1:
+            raise RuntimeError("The array must be associated with exactly one device")
+        for d in devs:
+            return d
 
-def _to_jax_array(dali_tensor: TensorGPU) -> jax.Array:
+else:
+
+    def _jax_device(jax_array):
+        return jax_array.device()
+
+
+def _to_jax_array(dali_tensor: TensorGPU, copy: bool) -> jax.Array:
     """Converts input DALI tensor to JAX array.
 
     Args:
-        dali_tensor (TensorGPU): DALI GPU tensor to be converted to JAX array.
+        dali_tensor (TensorGPU):
+            DALI GPU tensor to be converted to JAX array.
+
+        copy (bool):
+            If True, the output is copied;
+            if False, the output may wrap DLPack capsule obtained from `dali_tensor`.
 
     Note:
-        This function performs deep copy of the underlying data. That will change in
-        future releases.
+        This function may perform a copy of the data even if `copy==False` when JAX version is
+        insufficient (<0.4.16)
 
     Warning:
         As private this API may change without notice.
@@ -49,12 +64,12 @@ def _to_jax_array(dali_tensor: TensorGPU) -> jax.Array:
     jax.Array: JAX array with the same values and backing device as
         input DALI tensor.
     """
-    if _jax_has_old_dlpack():
+    if _jax_has_old_dlpack:
+        copy = True
         jax_array = jax.dlpack.from_dlpack(dali_tensor.__dlpack__(stream=None))
     else:
         jax_array = jax.dlpack.from_dlpack(dali_tensor)
 
-    # For now we need this copy to make sure that underlying memory is available.
-    # One solution is to implement full DLPack contract in DALI.
-    # TODO(awolant): Remove this copy.
-    return jax_array.copy()
+    if copy:
+        jax_array = jax_array.copy()
+    return jax_array
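
As a side note, the three device-query spellings that the new `_jax_device` helper dispatches over can be exercised directly; a runnable sketch (CPU suffices, no DALI or GPU assumed):

    import jax
    import jax.numpy as jnp
    from packaging.version import Version

    x = jnp.zeros((2, 2))
    if Version(jax.__version__) >= Version("0.4.31"):
        dev = x.device        # plain attribute in recent JAX
    elif Version(jax.__version__) >= Version("0.4.27"):
        (dev,) = x.devices()  # a set with exactly one element for a committed array
    else:
        dev = x.device()      # a method in older JAX
    assert dev == jax.devices()[0]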

dali/python/nvidia/dali/plugin/jax/iterator.py

Lines changed: 5 additions & 2 deletions

@@ -66,7 +66,7 @@ class DALIGenericIterator(_DaliBaseIterator):
         is called internally automatically.
     last_batch_policy: optional, default = LastBatchPolicy.FILL
         What to do with the last batch when there are not enough samples in the epoch
-        to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`
+        to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`.
         JAX iterator does not support LastBatchPolicy.PARTIAL
     last_batch_padded : bool, optional, default = False
         Whether the last batch provided by DALI is padded with the last sample
@@ -193,7 +193,10 @@ def _gather_outputs_for_category(self, pipelines_outputs, category_id):
 
         for pipeline_id in range(self._num_gpus):
            category_outputs.append(
-                _to_jax_array(pipelines_outputs[pipeline_id][category_id].as_tensor())
+                _to_jax_array(
+                    pipelines_outputs[pipeline_id][category_id].as_tensor(),
+                    not self._pipes[pipeline_id].exec_dynamic,
+                )
            )
 
        return category_outputs
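
The second argument threaded through here is the new `copy` flag: pipelines running the dynamic executor skip the copy, while legacy-executor outputs still get one (per the commit message and the removed TODO in integration.py, which implies legacy output buffers may be recycled by a later run). A condensed paraphrase of the call above, as a hypothetical free-standing helper rather than DALI API:

    def gather_category(pipes, outputs, category_id):
        # Zero-copy only for pipelines using the dynamic executor; otherwise
        # the output buffer may be reused by the next run(), so request a copy.
        return [
            _to_jax_array(out[category_id].as_tensor(), not pipe.exec_dynamic)
            for pipe, out in zip(pipes, outputs)
        ]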

dali/test/python/jax_plugin/jax_server.py

Lines changed: 18 additions & 10 deletions

@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -47,10 +47,10 @@ def print_devices_details(devices_list, process_id):
 
 
 def test_lax_workflow(process_id):
-    array_from_dali = dax.integration._to_jax_array(get_dali_tensor_gpu(1, (1), np.int32))
+    array_from_dali = dax.integration._to_jax_array(get_dali_tensor_gpu(1, (1), np.int32), False)
 
     assert (
-        array_from_dali.device() == jax.local_devices()[0]
+        dax.integration._jax_device(array_from_dali) == jax.local_devices()[0]
     ), "Array should be backed by the device local to current process."
 
     sum_across_devices = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(array_from_dali)
@@ -64,7 +64,7 @@ def test_lax_workflow(process_id):
 
 def run_distributed_sharing_test(sharding, process_id):
     dali_local_shard = dax.integration._to_jax_array(
-        get_dali_tensor_gpu(process_id, (1), np.int32, 0)
+        get_dali_tensor_gpu(process_id, (1), np.int32, 0), False
     )
 
     # Note: we pass only one local shard but the array virtually
@@ -73,12 +73,20 @@ def run_distributed_sharing_test(sharding, process_id):
         shape=(2,), sharding=sharding, arrays=[dali_local_shard]
     )
 
-    # This array should be backed only by one device buffer that holds
-    # local part of the data. This buffer should be on the local device.
-    assert len(dali_sharded_array.device_buffers) == 1
-    assert dali_sharded_array.device_buffer == jnp.array([process_id])
-    assert dali_sharded_array.device_buffer.device() == jax.local_devices()[0]
-    assert dali_sharded_array.device_buffer.device() == jax.devices()[process_id]
+    # device_buffers has been removed
+    if hasattr(dali_sharded_array, "device_buffers"):
+        # This array should be backed only by one device buffer that holds
+        # local part of the data. This buffer should be on the local device.
+        assert len(dali_sharded_array.device_buffers) == 1
+    assert dali_sharded_array.addressable_data(0) == jnp.array([process_id])
+    assert (
+        dax.integration._jax_device(dali_sharded_array.addressable_data(0))
+        == jax.local_devices()[0]
+    )
+    assert (
+        dax.integration._jax_device(dali_sharded_array.addressable_data(0))
+        == jax.devices()[process_id]
+    )
 
 
 def test_positional_sharding_workflow(process_id):
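
The shard-inspection pattern adopted here, `addressable_data(i)` in place of the removed `device_buffers`, can be checked on any backend; a minimal single-process sketch, where shard 0 is the whole array:

    import jax.numpy as jnp

    arr = jnp.arange(4)
    # addressable_data(i) returns the i-th shard local to this process; for a
    # single-device array there is exactly one, equal to the array itself.
    assert (arr.addressable_data(0) == arr).all()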

dali/test/python/jax_plugin/test_integration.py

Lines changed: 5 additions & 5 deletions

@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -35,13 +35,13 @@ def test_dali_tensor_gpu_to_jax_array(dtype, shape, value):
     dali_tensor_gpu = get_dali_tensor_gpu(value=value, shape=shape, dtype=dtype)
 
     # when
-    jax_array = dax.integration._to_jax_array(dali_tensor_gpu)
+    jax_array = dax.integration._to_jax_array(dali_tensor_gpu, False)
 
     # then
     assert jax.numpy.array_equal(jax_array, jax.numpy.full(shape, value, dtype))
 
     # Make sure JAX array is backed by the GPU
-    assert jax_array.device() == jax.devices()[0]
+    assert dax.integration._jax_device(jax_array) == jax.devices()[0]
 
 
 def test_dali_sequential_tensors_to_jax_array():
@@ -56,10 +56,10 @@ def test_dali_sequential_tensors_to_jax_array():
     dali_tensor_gpu = pipe.run()[0].as_tensor()
 
     # when
-    jax_array = dax.integration._to_jax_array(dali_tensor_gpu)
+    jax_array = dax.integration._to_jax_array(dali_tensor_gpu, False)
 
     # then
-    assert jax_array.device() == jax.devices()[0]
+    assert dax.integration._jax_device(jax_array) == jax.devices()[0]
 
     for i in range(batch_size):
         assert jax.numpy.array_equal(

dali/test/python/jax_plugin/test_iterator.py

Lines changed: 9 additions & 4 deletions

@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,10 +21,12 @@
 
 from utils import iterator_function_def
 
+import nvidia.dali.plugin.jax as dax
 from nvidia.dali.plugin.jax import DALIGenericIterator
 from nvidia.dali.pipeline import pipeline_def
 from nvidia.dali.plugin.base_iterator import LastBatchPolicy
 from nose_utils import raises
+from nose2.tools import params
 
 import itertools
 
@@ -39,7 +41,7 @@ def run_and_assert_sequential_iterator(iter, num_iters=4):
         jax_array = data["data"]
 
         # then
-        assert jax_array.device() == jax.devices()[0]
+        assert dax.integration._jax_device(jax_array) == jax.devices()[0]
 
         for i in range(batch_size):
             assert jax.numpy.array_equal(
@@ -49,9 +51,12 @@ def run_and_assert_sequential_iterator(iter, num_iters=4):
     assert batch_id == num_iters - 1
 
 
-def test_dali_sequential_iterator():
+@params((False,), (True,))
+def test_dali_sequential_iterator(exec_dynamic):
     # given
-    pipe = pipeline_def(iterator_function_def)(batch_size=batch_size, num_threads=4, device_id=0)
+    pipe = pipeline_def(iterator_function_def)(
+        batch_size=batch_size, num_threads=4, device_id=0, exec_dynamic=exec_dynamic
+    )
     iter = DALIGenericIterator([pipe], ["data"], reader_name="reader")
 
     # then
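
For reference, nose2's `params` decorator used above runs the test once per argument tuple; a minimal sketch of the parameterization pattern:

    from nose2.tools import params

    @params((False,), (True,))
    def test_exec_dynamic_flag(exec_dynamic):
        # Executed twice: once with exec_dynamic=False, once with True.
        assert exec_dynamic in (False, True)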

dali/test/python/jax_plugin/test_iterator_decorator.py

Lines changed: 18 additions & 11 deletions

@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
 
 from nvidia.dali.plugin.jax import DALIGenericIterator, data_iterator
 from test_iterator import run_and_assert_sequential_iterator
+from nose2.tools import params
 
 import inspect
 
@@ -46,25 +47,29 @@ def iterator_function():
     run_and_assert_sequential_iterator(iter)
 
 
-def test_dali_iterator_decorator_declarative_with_default_args():
+@params((False,), (True,))
+def test_dali_iterator_decorator_declarative_with_default_args(exec_dynamic):
     # given
     @data_iterator(output_map=["data"], reader_name="reader")
     def iterator_function():
         return iterator_function_def()
 
-    iter = iterator_function(batch_size=batch_size)
+    iter = iterator_function(batch_size=batch_size, exec_dynamic=exec_dynamic)
 
     # then
     run_and_assert_sequential_iterator(iter)
 
 
-def test_dali_iterator_decorator_declarative_pipeline_fn_with_argument():
+@params((False,), (True,))
+def test_dali_iterator_decorator_declarative_pipeline_fn_with_argument(exec_dynamic):
     # given
     @data_iterator(output_map=["data"], reader_name="reader")
     def iterator_function(num_shards):
         return iterator_function_def(num_shards=num_shards)
 
-    iter = iterator_function(num_shards=2, num_threads=4, device_id=0, batch_size=batch_size)
+    iter = iterator_function(
+        num_shards=2, num_threads=4, device_id=0, batch_size=batch_size, exec_dynamic=exec_dynamic
+    )
 
     # then
     run_and_assert_sequential_iterator(iter)
@@ -91,9 +96,10 @@ def test_iterator_decorator_api_match_iterator_init():
     iterator_decorator_args.remove("devices")
 
     # then
-    assert (
-        iterator_decorator_args == iterator_init_args
-    ), "Arguments for the iterator decorator and the iterator __init__ method do not match"
+    assert iterator_decorator_args == iterator_init_args, (
+        f"Arguments for the iterator decorator and the iterator __init__ method do not match:"
+        f"\n------\n{iterator_decorator_args}\n-- vs --\n{iterator_init_args}\n------"
+    )
 
     # Get docs for the decorator "Parameters" section
     # Skip the first argument, which differs (pipelines vs. pipeline_fn)
@@ -107,6 +113,7 @@ def test_iterator_decorator_api_match_iterator_init():
     iterator_init_docs = iterator_init_docs.split("output_map")[1]
     iterator_init_docs = iterator_init_docs.split("sharding")[0]
 
-    assert (
-        iterator_decorator_docs == iterator_init_docs
-    ), "Documentation for the iterator decorator and the iterator __init__ method does not match"
+    assert iterator_decorator_docs == iterator_init_docs, (
+        "Documentation for the iterator decorator and the iterator __init__ method does not match:"
+        f"\n------\n{iterator_decorator_docs}\n-- vs --\n{iterator_init_docs}\n------"
+    )

dali/test/python/jax_plugin/test_multigpu.py

Lines changed: 5 additions & 5 deletions

@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -99,8 +99,8 @@ def test_dali_sequential_sharded_tensors_to_jax_sharded_array_manuall():
     dali_tensor_gpu_0 = pipe_0.run()[0].as_tensor()
     dali_tensor_gpu_1 = pipe_1.run()[0].as_tensor()
 
-    jax_shard_0 = dax.integration._to_jax_array(dali_tensor_gpu_0)
-    jax_shard_1 = dax.integration._to_jax_array(dali_tensor_gpu_1)
+    jax_shard_0 = dax.integration._to_jax_array(dali_tensor_gpu_0, False)
+    jax_shard_1 = dax.integration._to_jax_array(dali_tensor_gpu_1, False)
 
     assert jax_shard_0.device() == jax.devices()[0]
     assert jax_shard_1.device() == jax.devices()[1]
@@ -224,8 +224,8 @@ def run_sharding_test(sharding):
     dali_shard_1 = get_dali_tensor_gpu(1, (1), np.int32, 1)
 
     shards = [
-        dax.integration._to_jax_array(dali_shard_0),
-        dax.integration._to_jax_array(dali_shard_1),
+        dax.integration._to_jax_array(dali_shard_0, False),
+        dax.integration._to_jax_array(dali_shard_1, False),
     ]
 
     assert shards[0].device() == jax.devices()[0]

dali/test/python/jax_plugin/utils.py

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ def get_dali_tensor_gpu(value, shape, dtype, device_id=0) -> TensorGPU:
     with provided value.
     """
 
-    @pipeline_def(num_threads=1, batch_size=1)
+    @pipeline_def(num_threads=1, batch_size=1, exec_dynamic=True)
     def dali_pipeline():
         values = types.Constant(value=np.full(shape, value, dtype), device="gpu")
 

qa/TL3_JAX_multiprocess/jax_server.py

Lines changed: 4 additions & 4 deletions

@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -82,10 +82,10 @@ def run_distributed_sharing_test(sharding, process_id):
     dali_local_shards = []
     for id, device in enumerate(jax.local_devices()):
         current_shard = dax.integration._to_jax_array(
-            get_dali_tensor_gpu(process_id, (1), np.int32, id)
+            get_dali_tensor_gpu(process_id, (1), np.int32, id), False
         )
 
-        assert current_shard.device() == device
+        assert dax.integration._jax_device(current_shard) == device
 
         dali_local_shards.append(current_shard)
 
@@ -97,7 +97,7 @@ def run_distributed_sharing_test(sharding, process_id):
 
     for id, buffer in enumerate(dali_sharded_array.device_buffers):
         assert buffer == jnp.array([process_id])
-        assert buffer.device() == jax.local_devices()[id]
+        assert dax.integration._jax_device(buffer) == jax.local_devices()[id]
 
 
 def test_positional_sharding_workflow(process_id):
