Description
Feature request
Allow indexing datasets with a scalar numpy integer type.
Motivation
Indexing a dataset with a scalar numpy.int* object raises a TypeError. This is caused by the type check in key_to_query_type in datasets/formatting/formatting.py:
```python
def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
    if isinstance(key, int):
        return "row"
    elif isinstance(key, str):
        return "column"
    elif isinstance(key, (slice, range, Iterable)):
        return "batch"
    _raise_bad_key_type(key)
```
In the row case, it checks whether key is an int, which is False when key is integer-like but not a builtin Python int. This is counterintuitive, because a NumPy array of np.int64 values is accepted in the batch case.
For example:
```python
import numpy as np
import datasets

dataset = datasets.Dataset.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})

# Regular indexing
dataset[0]
dataset[:2]

# Indexing with numpy data types (expect same results)
idx = np.asarray([0, 1])
dataset[idx]  # Succeeds when using an array of np.int64 values
dataset[idx[0]]  # Fails with TypeError when using scalar np.int64
```
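The failure comes down to how NumPy scalars relate to Python's type checks. The sketch below does not need datasets at all; it only assumes NumPy is installed and shows that a NumPy integer scalar fails isinstance(..., int) but passes isinstance(..., numbers.Integral), and that int() or operator.index() turns it back into a builtin int:

```python
import numbers
import operator

import numpy as np

idx = np.asarray([0, 1])
scalar = idx[0]  # NumPy integer scalar (np.int64 on most platforms), not a builtin int

# The current row check: NumPy integer scalars are not subclasses of builtin int
is_builtin_int = isinstance(scalar, int)
# NumPy registers its integer scalar types with the numbers.Integral ABC
is_integral = isinstance(scalar, numbers.Integral)

# User-side workarounds, e.g. dataset[int(idx[0])]
as_int = int(scalar)
as_index = operator.index(scalar)  # like int(), but raises TypeError for floats

print(type(scalar).__name__, is_builtin_int, is_integral)
print(type(as_int).__name__, type(as_index).__name__)
```

operator.index() may be the stricter choice for a workaround, since int() would also silently truncate a float key.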
For the user, this can be worked around by wrapping idx[0] in int(), but the check in key_to_query_type could also be relaxed to accept a less strict definition of an integer, such as numbers.Integral:
```diff
+import numbers
+
 def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
-    if isinstance(key, int):
+    if isinstance(key, numbers.Integral):
         return "row"
     elif isinstance(key, str):
         return "column"
     elif isinstance(key, (slice, range, Iterable)):
         return "batch"
     _raise_bad_key_type(key)
```
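To check the proposed change in isolation, here is a standalone sketch of the patched function. The _raise_bad_key_type below is a hypothetical stand-in that simply raises TypeError; its message is illustrative and is not the real helper from datasets:

```python
import numbers
from collections.abc import Iterable
from typing import Union

import numpy as np


def _raise_bad_key_type(key):
    # Hypothetical stand-in for the helper in datasets/formatting/formatting.py
    raise TypeError(f"Wrong key type: {key!r} of type {type(key)}")


def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
    # numbers.Integral covers builtin int (and bool) as well as NumPy integer scalars
    if isinstance(key, numbers.Integral):
        return "row"
    # str must be checked before Iterable, since strings are iterable
    elif isinstance(key, str):
        return "column"
    elif isinstance(key, (slice, range, Iterable)):
        return "batch"
    _raise_bad_key_type(key)


for key in [0, np.int64(0), np.int32(0), "a", slice(0, 2), np.asarray([0, 1])]:
    print(repr(key), "->", key_to_query_type(key))
```

Because bool is a subclass of int, a Python bool key is classified as "row" both before and after this change, so the relaxed check does not alter that behavior.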
For comparison, pandas has an is_integer check, which relies on is_integer_object defined in pandas/_libs/utils.pxd:
```cython
cdef inline bint is_integer_object(object obj) noexcept:
    """
    Cython equivalent of

    `isinstance(val, (int, np.integer)) and not isinstance(val, (bool, np.timedelta64))`

    Parameters
    ----------
    val : object

    Returns
    -------
    is_integer : bool

    Notes
    -----
    This counts np.timedelta64 objects as integers.
    """
    return (not PyBool_Check(obj) and isinstance(obj, (int, cnp.integer))
            and not is_timedelta64_object(obj))
```
This would be less flexible, since it explicitly checks for NumPy integer types, but it is worth noting that pandas needed to make sure the key is not a bool.
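If excluding bool turned out to matter here as well, a pure-Python check along the lines of the pandas docstring could look like the sketch below. is_integer_key is a hypothetical helper name, not an existing function in datasets or pandas. np.timedelta64 is excluded because NumPy places it under np.signedinteger in its scalar hierarchy:

```python
import numbers

import numpy as np


def is_integer_key(key) -> bool:
    # Hypothetical helper mirroring the intent of pandas' is_integer_object:
    # accept builtin and NumPy integers, reject bool and np.timedelta64
    return (
        isinstance(key, (int, np.integer))
        and not isinstance(key, bool)
        and not isinstance(key, np.timedelta64)
    )


# Compare against the looser numbers.Integral check for a few keys
for key in [3, np.int64(3), True, np.timedelta64(1, "D"), 1.5]:
    print(repr(key), is_integer_key(key), isinstance(key, numbers.Integral))
```

The trade-off is that this ties the check to NumPy specifically, whereas numbers.Integral also accepts any other integer type that registers with the ABC.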
Your contribution
I can submit a pull request with the above changes after checking that indexing succeeds with the numpy integer type, or, if a different integer check would be preferred, I could add that instead.
If there is a reason not to support this behavior, that is fine too.