
Row indexing a dataset with numpy integers #7423

Open
@DavidRConnell

Feature request

Allow indexing datasets with a scalar numpy integer type.

Motivation

Indexing a dataset with a scalar numpy.int* object raises a TypeError. This is due to the check in key_to_query_type in datasets/formatting/formatting.py:

def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
    if isinstance(key, int):
        return "row"
    elif isinstance(key, str):
        return "column"
    elif isinstance(key, (slice, range, Iterable)):
        return "batch"
    _raise_bad_key_type(key)

In the row case, the function checks whether key is an int, which is false when key is integer-like but not a built-in Python int. This is counterintuitive because a numpy array of np.int64 values can be used for the batch case.

For example:

import numpy as np

import datasets

dataset = datasets.Dataset.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})

# Regular indexing
dataset[0]
dataset[:2]

# Indexing with numpy data types (expect same results)
idx = np.asarray([0, 1])
dataset[idx]  # Succeeds when using an array of np.int64 values
dataset[idx[0]]  # Fails with TypeError when using scalar np.int64
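
The difference between the two numpy cases can be reproduced without datasets. The following is a minimal sketch, assuming only that numpy is installed; the variable names are illustrative and not part of the library:

```python
import numbers
from collections.abc import Iterable

import numpy as np

idx = np.asarray([0, 1])
scalar = idx[0]  # a numpy integer scalar (e.g. np.int64), not a built-in int

# The array is an Iterable, so it matches the "batch" branch.
is_array_iterable = isinstance(idx, Iterable)

# The check currently used for the "row" branch rejects the numpy scalar.
is_builtin_int = isinstance(scalar, int)

# numpy registers its integer types with numbers.Integral,
# so this broader check accepts the scalar.
is_integral = isinstance(scalar, numbers.Integral)

# User-side workaround: convert to a built-in int before indexing.
converted = int(scalar)

print(is_array_iterable, is_builtin_int, is_integral, type(converted).__name__)
```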

For the user, this can be solved by wrapping idx[0] in int, but the check in key_to_query_type could also be changed to accept a less strict definition of int:

+import numbers
+
def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
-   if isinstance(key, int):
+   if isinstance(key, numbers.Integral):
        return "row"
    elif isinstance(key, str):
        return "column"
    elif isinstance(key, (slice, range, Iterable)):
        return "batch"
    _raise_bad_key_type(key)

For comparison, pandas has an is_integer check that relies on is_integer_object, defined in pandas/_libs/utils.pxd:

cdef inline bint is_integer_object(object obj) noexcept:
    """
    Cython equivalent of

    `isinstance(val, (int, np.integer)) and not isinstance(val, (bool, np.timedelta64))`

    Parameters
    ----------
    val : object

    Returns
    -------
    is_integer : bool

    Notes
    -----
    This counts np.timedelta64 objects as integers.
    """
    return (not PyBool_Check(obj) and isinstance(obj, (int, cnp.integer))
            and not is_timedelta64_object(obj))

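This would be less flexible, as it explicitly checks for the numpy integer type, but it is worth noting that pandas needed to ensure the key is not a bool. Because bool is a subclass of int in Python, both isinstance(key, int) and isinstance(key, numbers.Integral) accept True and False.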
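
If excluding bools were wanted here as well, the numbers.Integral check could be combined with an explicit bool test. The following is only a sketch under that assumption: is_integer_key and key_to_query_type_sketch are hypothetical names, and the TypeError stands in for the library's _raise_bad_key_type.

```python
import numbers
from collections.abc import Iterable


def is_integer_key(key) -> bool:
    """Hypothetical helper: accept integer-like keys but reject bools."""
    return isinstance(key, numbers.Integral) and not isinstance(key, bool)


def key_to_query_type_sketch(key) -> str:
    """Illustrative variant of key_to_query_type, not the library's code."""
    if is_integer_key(key):
        return "row"
    elif isinstance(key, str):
        return "column"
    elif isinstance(key, (slice, range, Iterable)):
        return "batch"
    # Stand-in for the library's _raise_bad_key_type(key).
    raise TypeError(f"Unsupported key type: {type(key).__name__}")
```

With this variant, True and False fall through to the error branch instead of being classified as row keys. Whether that is desirable for datasets is a separate decision.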

Your contribution

I can submit a pull request with the above changes after checking that indexing succeeds with the numpy integer type. If a different integer check would be preferred, I could add that instead.

If there is a reason not to want this behavior, that is fine too.
