Skip to content

Commit 529d795

Browse files
authored
Align autogenerated dimension names when dims and default_dims are provided (#2395)
* Align autogenerated dimension names when dims and default_dims are provided * Add to CHANGELOG.md
1 parent 0868c9e commit 529d795

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Change Log
22

3+
## Unreleased
4+
5+
### Maintenance and fixes
6+
- Make `arviz.data.generate_dims_coords` handle `dims` and `default_dims` consistently ([2395](https://github.com/arviz-devs/arviz/pull/2395))
7+
38
## v0.20.0 (2024 Sep 28)
49

510
### New features

arviz/data/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,10 @@ def generate_dims_coords(
201201
for i, dim_len in enumerate(shape):
202202
idx = i + len([dim for dim in default_dims if dim in dims])
203203
if len(dims) < idx + 1:
204-
dim_name = f"{var_name}_dim_{idx}"
204+
dim_name = f"{var_name}_dim_{i}"
205205
dims.append(dim_name)
206206
elif dims[idx] is None:
207-
dim_name = f"{var_name}_dim_{idx}"
207+
dim_name = f"{var_name}_dim_{i}"
208208
dims[idx] = dim_name
209209
dim_name = dims[idx]
210210
if dim_name not in coords:

arviz/tests/base_tests/test_data.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@
3232
extract,
3333
)
3434

35-
from ...data.base import dict_to_dataset, generate_dims_coords, infer_stan_dtypes, make_attrs
35+
from ...data.base import (
36+
dict_to_dataset,
37+
generate_dims_coords,
38+
infer_stan_dtypes,
39+
make_attrs,
40+
numpy_to_data_array,
41+
)
3642
from ...data.datasets import LOCAL_DATASETS, REMOTE_DATASETS, RemoteFileMetadata
3743
from ..helpers import ( # pylint: disable=unused-import
3844
chains,
@@ -231,6 +237,17 @@ def test_dims_coords_skip_event_dims(shape):
231237
assert "z" not in coords
232238

233239

240+
@pytest.mark.parametrize("dims", [None, ["chain", "draw"], ["chain", "draw", None]])
241+
def test_numpy_to_data_array_with_dims(dims):
242+
da = numpy_to_data_array(
243+
np.empty((4, 500, 7)),
244+
var_name="a",
245+
dims=dims,
246+
default_dims=["chain", "draw"],
247+
)
248+
assert list(da.dims) == ["chain", "draw", "a_dim_0"]
249+
250+
234251
def test_make_attrs():
235252
extra_attrs = {"key": "Value"}
236253
attrs = make_attrs(attrs=extra_attrs)

0 commit comments

Comments
 (0)