Skip to content

Enhance error handling and input validation across multiple modules #7602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/datasets/arrow_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ def _pct_to_abs_pct1(boundary, num_examples):
# Using math.trunc here, since -99.5% should give -99%, not -100%.
if num_examples < 100:
msg = (
'Using "pct1_dropremainder" rounding on a split with less than 100 '
"elements is forbidden: it always results in an empty dataset."
f'Cannot use "pct1_dropremainder" rounding on a small dataset (size={num_examples} < 100). '
'This would result in an empty dataset. Consider using absolute values or a different splitting strategy.'
)
raise ValueError(msg)
return boundary * math.trunc(num_examples / 100.0)
Expand Down
12 changes: 5 additions & 7 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
def __getitem__(self, k) -> Dataset:
    """Return the `Dataset` stored under the split name ``k``.

    Args:
        k (``str`` or ``NamedSplit``): Name of the split to fetch. (Any key type
            is forwarded to ``dict.__getitem__`` when the dict is empty, so the
            normal ``KeyError`` is raised in that case.)

    Returns:
        Dataset: The dataset stored under that split name.

    Raises:
        KeyError: If ``k`` is not a split name (str or NamedSplit) while splits
            exist — the message lists the available splits so the caller can
            select one first.
    """
    if isinstance(k, (str, NamedSplit)) or len(self) == 0:
        return super().__getitem__(k)
    else:
        # Build a readable, sorted listing of split names for the error message.
        # (The else branch implies len(self) > 0, so the fallback text is a
        # defensive default only.)
        available_splits = [f"'{split}'" for split in sorted(self.keys())]
        available_splits_str = ", ".join(available_splits) if available_splits else "no available splits"
        raise KeyError(
            f"Invalid key: '{k}'. Expected a split name (str or NamedSplit). "
            f"Available splits: {available_splits_str}. "
            "Please select a split first, e.g. with dataset_dict['train']"
        )

@property
Expand Down
10 changes: 7 additions & 3 deletions src/datasets/features/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,13 @@ def decode_example(
if not self.decode:
raise RuntimeError("Decoding is disabled for this feature. Please use Audio(decode=True) instead.")

path, file = (value["path"], BytesIO(value["bytes"])) if value["bytes"] is not None else (value["path"], None)
if path is None and file is None:
raise ValueError(f"An audio sample should have one of 'path' or 'bytes' but both are None in {value}.")
data = value
if "path" not in data and "bytes" not in data:
raise ValueError(
f"Audio data must contain either 'path' or 'bytes' key with non-None value. "
f"Received input: {data}. "
"Please provide either a file path in 'path' or raw audio data in 'bytes'."
)

try:
import librosa
Expand Down
35 changes: 30 additions & 5 deletions src/datasets/search.py
Original file line number Diff line number Diff line change
def search(self, query: np.array, k=10, **kwargs) -> SearchResults:
    """Find the nearest examples indices to the query.

    Args:
        query (`np.array`): The query as a numpy array. Should be a 1D array or a 2D array with shape (1, N).
        k (`int`): The number of examples to retrieve.

    Output:
        scores (`List[List[float]`): The retrieval scores of the retrieved examples.
        indices (`List[List[int]]`): The indices of the retrieved examples.

    Raises:
        TypeError: If the query is not a numpy array.
        ValueError: If the query shape is invalid or doesn't match the index dimensions.
    """
    if not isinstance(query, np.ndarray):
        raise TypeError(f"Query must be a numpy array, got {type(query).__name__}")

    # Normalize the query into a batch of exactly one vector, shape (1, N).
    if len(query.shape) == 1:
        queries = query.reshape(1, -1)
    elif len(query.shape) == 2:
        if query.shape[0] != 1:
            raise ValueError(
                f"For 2D queries, the first dimension must be 1 (batch size). Got shape {query.shape}"
            )
        queries = query
    else:
        raise ValueError(
            f"Query must be 1D or 2D. Got {len(query.shape)}D array with shape {query.shape}"
        )

    # FAISS expects C-contiguous input; copy only when necessary.
    if not queries.flags.c_contiguous:
        queries = np.asarray(queries, order="C")

    try:
        scores, indices = self.faiss_index.search(queries, k, **kwargs)
    except RuntimeError as e:
        # Re-raise dimension mismatches as ValueError with the index's expected
        # dimensionality, so callers get an actionable message.
        if "query size" in str(e).lower() or "dimension" in str(e).lower():
            raise ValueError(
                f"Query dimension mismatch. Expected {self.faiss_index.d} dimensions, "
                f"got {queries.shape[1]}. {str(e)}"
            ) from e
        raise

    return SearchResults(scores[0], indices[0].astype(int))

def search_batch(self, queries: np.array, k=10, **kwargs) -> BatchedSearchResults:
Expand Down
15 changes: 12 additions & 3 deletions src/datasets/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,18 @@ def __setitem__(self, key, value):
raise ValueError(self._error_msg.format(key=key))
return super().__setitem__(key, value)

def update(self, *args, **kwargs):
    """Update the dictionary, routing every assignment through ``__setitem__``.

    Mirrors ``dict.update(other=None, **kwargs)``: accepts at most one
    positional argument — a mapping or an iterable of key-value pairs —
    plus keyword arguments. Every key is set via ``self[key] = value`` so
    any per-key validation this class performs in ``__setitem__`` (e.g.
    forbidding duplicate keys) is applied consistently.

    Returns:
        None (matching the ``dict.update`` convention).

    Raises:
        TypeError: If more than one positional argument is given, or the
            positional argument is neither a mapping nor an iterable of
            key-value pairs.
    """
    if len(args) > 1:
        raise TypeError(f"update expected at most 1 positional argument, got {len(args)}")
    if args:
        try:
            # dict() accepts both mappings and iterables of key-value pairs,
            # matching the error message below.
            items = dict(args[0])
        except (TypeError, ValueError) as e:
            raise TypeError(
                f"argument must be a mapping or an iterable of key-value pairs, got {type(args[0]).__name__}"
            ) from e
        for key, value in items.items():
            self[key] = value
    for key, value in kwargs.items():
        self[key] = value


Expand Down