
Commit 7639f90

hynky1999 and Hynek Kydlicek authored

Task config (#289)

* add new params to config class
* clean up task/config
* connect dataset revision and filter
* add tests for filtering/revision
* nit
* nit+1
* remove redundant check

Co-authored-by: Hynek Kydlicek <[email protected]>

1 parent ba0baab · commit 7639f90

File tree

3 files changed: +153 −62 lines

src/lighteval/tasks/lighteval_task.py

Lines changed: 66 additions & 56 deletions
@@ -23,11 +23,11 @@
 import collections
 import inspect
 import random
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from multiprocessing import Pool
-from pathlib import Path
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

+from datasets import DatasetDict
 from huggingface_hub import TextGenerationInputGrammarType
 from pytablewriter import MarkdownTableWriter

@@ -54,7 +54,7 @@
     RequestType,
     SampleUid,
 )
-from lighteval.utils.utils import as_list, download_dataset_worker
+from lighteval.utils.utils import ListLike, as_list, download_dataset_worker


 if TYPE_CHECKING:
@@ -82,55 +82,58 @@ class LightevalTaskConfig:
         original_num_docs (int): Number of documents in the task
         effective_num_docs (int): Number of documents used in a specific evaluation
         truncated_num_docs (bool): Whether less than the total number of documents were used
-        output_regex (str)
-        frozen (bool)
         trust_dataset (bool): Whether to trust the dataset at execution or not
         version (int): The version of the task. Defaults to 0. Can be increased if the underlying dataset or the prompt changes.
+        output_regex (str)
+        frozen (bool)
     """

     name: str
-    prompt_function: Callable  # [[dict, str], Doc]
+    prompt_function: Callable[[dict, str], Doc]
     hf_repo: str
     hf_subset: str
-    metric: Tuple[Union[Metric, Metrics]]
-    hf_avail_splits: Optional[Tuple[str]] = None
-    evaluation_splits: Optional[Tuple[str]] = None
+    metric: ListLike[Metric | Metrics]
+
+    # Additional hf dataset config
+    hf_revision: Optional[str] = None
+    hf_filter: Optional[Callable[[dict], bool]] = None
+    hf_avail_splits: Optional[ListLike[str]] = field(default_factory=lambda: ["train", "validation", "test"])
+    # We default to false, to reduce security issues
+    trust_dataset: bool = False
+
+    # Splits
+    evaluation_splits: ListLike[str] = field(default_factory=lambda: ["validation"])
     few_shots_split: Optional[str] = None
     few_shots_select: Optional[str] = None
+
+    # Generation args
     generation_size: Optional[int] = None
     generation_grammar: Optional[TextGenerationInputGrammarType] = None
-    stop_sequence: Optional[Tuple[str]] = None
+    stop_sequence: Optional[ListLike[str]] = None
     output_regex: Optional[str] = None
     num_samples: Optional[list[int]] = None

-    frozen: bool = False
-    suite: Optional[Tuple[str]] = None
+    suite: ListLike[str] = field(default_factory=lambda: ["custom"])

     original_num_docs: int = -1
     effective_num_docs: int = -1

-    trust_dataset: bool = None
-
-    must_remove_duplicate_docs: bool = None
+    must_remove_duplicate_docs: bool = False

     version: int = 0

-    def __post_init__(self):
-        if self.suite is None:
-            self.suite = ["custom"]
-        if self.hf_avail_splits is None:
-            self.hf_avail_splits = ["train", "validation", "test"]
-        if self.evaluation_splits is None:
-            self.evaluation_splits = ["validation"]
+    # Currently unused
+    frozen: bool = False

+    def __post_init__(self):
         # If we got a Metrics enums instead of a Metric, we convert
         self.metric = [metric.value if isinstance(metric, Metrics) else metric for metric in self.metric]

         # Convert list to tuple for hashing
         self.metric = tuple(self.metric)
         self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None
-        self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits is not None else None
-        self.suite = tuple(self.suite) if self.suite is not None else None
+        self.evaluation_splits = tuple(self.evaluation_splits)
+        self.suite = tuple(self.suite)
         self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else None

     def print(self):
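Note that the new defaults go through `field(default_factory=...)` because Python dataclasses reject a bare list as a field default. A minimal sketch of a config exercising the new `hf_revision` and `hf_filter` fields (the repo id, revision hash, and prompt function below are illustrative placeholders, not real values):

from lighteval.tasks.lighteval_task import LightevalTaskConfig

# Hypothetical task config: pins the dataset to one commit and keeps
# only short rows. Repo and revision are made-up placeholders.
config = LightevalTaskConfig(
    name="my_task",
    prompt_function=lambda item, task_name: item["text"],
    hf_repo="my-org/my-dataset",
    hf_subset="default",
    metric=[],
    hf_revision="0123abc",  # placeholder commit sha
    hf_filter=lambda row: len(row["text"]) < 512,
    evaluation_splits=["test"],
)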
@@ -175,31 +178,27 @@ def __init__(  # noqa: C901
         """
         self.name = name
         self.version = cfg.version
-        self.is_main_process = False
         self.cache_dir = cache_dir
         self._cfg = cfg

         # Dataset info
-        self.hf_repo = cfg.hf_repo
-        self.hf_subset = cfg.hf_subset
-        self.dataset_path = self.hf_repo
-        self.dataset_config_name = self.hf_subset
-        self.dataset = None  # Delayed download
+        self.dataset_path = cfg.hf_repo
+        self.dataset_config_name = cfg.hf_subset
+        self.dataset_revision = cfg.hf_revision
+        self.dataset_filter = cfg.hf_filter
         self.trust_dataset = cfg.trust_dataset
+        self.dataset: Optional[DatasetDict] = None  # Delayed download
         hlog(f"{self.dataset_path} {self.dataset_config_name}")
         self._fewshot_docs = None
         self._docs = None

-        # Managing splits and few shot
-        self.all_available_splits = as_list(cfg.hf_avail_splits)
-        if cfg.evaluation_splits is None:
-            raise ValueError(f"The evaluation split for task {self.name} is None. Please select a valid split.")
-
         self.evaluation_split = as_list(cfg.evaluation_splits)
+
+        self.fewshot_split: list[str] | None
         if cfg.few_shots_split is not None:
             self.fewshot_split = as_list(cfg.few_shots_split)
         else:
-            self.fewshot_split = as_list(self.get_first_possible_fewshot_splits())
+            self.fewshot_split = self.get_first_possible_fewshot_splits(cfg.hf_avail_splits or [])
         self.fewshot_selection = cfg.few_shots_select

         # Metrics
@@ -223,30 +222,20 @@ def __init__(  # noqa: C901
             if "maj@" in metric_name:
                 self.num_samples.append(int(metric_name.replace("maj@", "").split("_")[0]))

-        if not isinstance(cfg.prompt_function, Callable):
-            raise TypeError(
-                f"Prompt formatting function ({str(cfg.prompt_function)}) should have been passed as a callable, was {type(cfg.prompt_function)} instead."
-            )
         self.formatter = cfg.prompt_function

         self.generation_size = cfg.generation_size
         self.generation_grammar = cfg.generation_grammar
         self.stop_sequence = cfg.stop_sequence
-        self.output_regex = cfg.output_regex
         self.must_remove_duplicate_docs = cfg.must_remove_duplicate_docs
-        if self.must_remove_duplicate_docs is None:
-            self.must_remove_duplicate_docs = False
-
-        # Save options
-        self.save_queries: bool = False
-        self.logfile_name: Optional[Path] = None
-        self.is_main_process: bool = False

     @property
     def cfg(self):
         return self._cfg

-    def get_first_possible_fewshot_splits(self, number_of_splits: int = 1) -> list[str]:
+    def get_first_possible_fewshot_splits(
+        self, available_splits: ListLike[str], number_of_splits: int = 1
+    ) -> list[str] | None:
         """
         Parses the possible fewshot split keys in order: train, then validation
         keys and matches them with the available keys. Returns the first
@@ -260,7 +249,7 @@ def get_first_possible_fewshot_splits(self, number_of_splits: int = 1) -> list[str]:
             list[str]: List of the first available fewshot splits.
         """
         # Possible few shot splits are the available splits not used for evaluation
-        possible_fewshot_splits = [k for k in self.all_available_splits if k not in self.evaluation_split]
+        possible_fewshot_splits = [k for k in available_splits if k not in self.evaluation_split]
         stored_splits = []

         # We look at these keys in order (first the training sets, then the validation sets)
@@ -289,7 +278,13 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
             list[Doc]: List of documents.
         """
         if self.dataset is None:
-            self.dataset = download_dataset_worker((self.dataset_path, self.dataset_config_name, self.trust_dataset))
+            self.dataset = download_dataset_worker(
+                self.dataset_path,
+                self.dataset_config_name,
+                self.trust_dataset,
+                self.dataset_filter,
+                self.dataset_revision,
+            )
         splits = as_list(splits)

         docs = []
@@ -326,7 +321,7 @@ def fewshot_docs(self) -> list[Doc]:
             self._fewshot_docs = []

             # If we have no available few shot split, the few shot data is the eval data!
-            if self.fewshot_split in [None, [None]]:
+            if self.fewshot_split is None:
                 self._fewshot_docs = self._get_docs_from_split(self.evaluation_split, few_shots=True)
             else:  # Normal case
                 self._fewshot_docs = self._get_docs_from_split(self.fewshot_split, few_shots=True)
@@ -552,14 +547,29 @@ def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int = 1):

     if dataset_loading_processes <= 1:
         datasets = [
-            download_dataset_worker((task.dataset_path, task.dataset_config_name, task.trust_dataset))
+            download_dataset_worker(
+                task.dataset_path,
+                task.dataset_config_name,
+                task.trust_dataset,
+                task.dataset_filter,
+                task.dataset_revision,
+            )
             for task in tasks
         ]
     else:
         with Pool(processes=dataset_loading_processes) as pool:
-            datasets = pool.map(
+            datasets = pool.starmap(
                 download_dataset_worker,
-                [(task.dataset_path, task.dataset_config_name, task.trust_dataset) for task in tasks],
+                [
+                    (
+                        task.dataset_path,
+                        task.dataset_config_name,
+                        task.trust_dataset,
+                        task.dataset_filter,
+                        task.dataset_revision,
+                    )
+                    for task in tasks
+                ],
             )

     for task, dataset in zip(tasks, datasets):
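The change from `pool.map` to `pool.starmap` follows from the worker's new explicit signature: `starmap` unpacks each tuple into positional arguments, while `map` passes the whole tuple as a single argument. A self-contained illustration of the difference (toy function, not lighteval code):

from multiprocessing import Pool

def add(a, b):
    return a + b

if __name__ == "__main__":
    with Pool(processes=2) as pool:
        # pool.map(add, [(1, 2)]) would call add((1, 2)) -> TypeError,
        # since the whole tuple arrives as one argument.
        # starmap unpacks each tuple: add(1, 2), then add(3, 4).
        print(pool.starmap(add, [(1, 2), (3, 4)]))  # [3, 7]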

src/lighteval/utils/utils.py

Lines changed: 26 additions & 6 deletions
@@ -13,10 +13,10 @@
 # limitations under the License.
 import os
 from dataclasses import asdict, dataclass, is_dataclass
-from typing import Any, Union
+from typing import Callable, TypeVar, Union

 import numpy as np
-from datasets import load_dataset
+from datasets import DatasetDict, load_dataset
 from pytablewriter import MarkdownTableWriter


@@ -109,7 +109,14 @@ def sanitize_numpy(example_dict: dict) -> dict:
     return output_dict


-def as_list(item: Union[list, tuple, Any]) -> list:
+ListLikeTypeVar = TypeVar("ListLikeTypeVar")
+ListLike = list[ListLikeTypeVar] | tuple[ListLikeTypeVar, ...]
+
+
+ElementType = TypeVar("ElementType")
+
+
+def as_list(item: ListLike[ElementType] | ElementType) -> list[ElementType]:
     """
     Convert the given item into a list.

@@ -126,8 +133,10 @@ def as_list(item: Union[list, tuple, Any]) -> list:
     """
     if isinstance(item, list):
         return item
+
     elif isinstance(item, tuple):
         return list(item)
+
     return [item]


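`ListLike` is a generic alias matching either a list or a tuple of the element type, so config fields can accept both and be normalized to tuples later for hashing. The behavior of `as_list` under the new annotation is unchanged; a quick illustration:

from lighteval.utils.utils import as_list

assert as_list(["a", "b"]) == ["a", "b"]  # lists pass through
assert as_list(("a", "b")) == ["a", "b"]  # tuples are converted
assert as_list("a") == ["a"]              # a bare element is wrapped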
@@ -205,21 +214,32 @@ def boolstring_to_bool(x: Union[str, bool, int]) -> Union[bool, None]:
         raise ValueError(f"You tried to convert {x} to a boolean but it's not possible.")


-def download_dataset_worker(args):
+def download_dataset_worker(
+    dataset_path: str,
+    dataset_config_name: str,
+    trust_dataset: bool,
+    dataset_filter: Callable[[dict], bool] | None = None,
+    revision: str | None = None,
+) -> DatasetDict:
     """
     Worker function to download a dataset from the HuggingFace Hub.
     Used for parallel dataset loading.
     """
-    dataset_path, dataset_config_name, trust_dataset = args
     dataset = load_dataset(
         path=dataset_path,
         name=dataset_config_name,
         data_dir=None,
         cache_dir=None,
         download_mode=None,
         trust_remote_code=trust_dataset,
+        revision=revision,
     )
-    return dataset
+
+    if dataset_filter is not None:
+        dataset = dataset.filter(dataset_filter)
+
+    # It returns DatasetDict because we don't specify a split
+    return dataset  # type: ignore


 def safe_divide(numerator: np.ndarray, denominator: float, default_value: float = 0.0) -> np.ndarray:
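With the packed-tuple argument replaced by an explicit signature, the worker is also straightforward to call directly. A sketch reusing the fixture repo and revision from this PR's tests (combining them with a filter here is our own assumption about the fixture's contents):

from lighteval.utils.utils import download_dataset_worker

dataset_dict = download_dataset_worker(
    "lighteval-tests-datasets/dataset-test-1",  # fixture repo from the tests
    "default",
    False,  # trust_dataset: now defaults to False for security
    dataset_filter=lambda row: row["text"] == "hi",
    revision="25175defadfde48b131b7cd7573ad6f59f868306",  # revision pinned in the tests
)
print(dataset_dict["train"]["text"])  # only rows with text == "hi" remain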

tests/tasks/test_lighteval_task.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
+
+
+def dummy_prompt_function(item, task_name):
+    return item["text"]
+
+
+def test_revision_check():
+    # Test with a different revision
+    cfg_with_revision = LightevalTaskConfig(
+        name="test_task_revision",
+        prompt_function=dummy_prompt_function,
+        hf_repo="lighteval-tests-datasets/dataset-test-1",
+        hf_subset="default",
+        evaluation_splits=["train"],
+        metric=[],
+        hf_revision="25175defadfde48b131b7cd7573ad6f59f868306",
+    )
+    task_with_revision = LightevalTask("test_task_revision", cfg_with_revision)
+    assert task_with_revision.eval_docs() == ["hi", "how are you?"]
+
+
+def test_dataset_filter():
+    # Setup
+
+    cfg = LightevalTaskConfig(
+        name="test_task",
+        prompt_function=dummy_prompt_function,
+        hf_repo="lighteval-tests-datasets/dataset-test-1",
+        hf_subset="default",
+        hf_filter=lambda x: x["text"] == "hi",
+        metric=[],
+        evaluation_splits=["train"],
+    )
+    task = LightevalTask("test_task", cfg)
+
+    filtered_docs = task.eval_docs()
+    assert len(filtered_docs) == 1
+    assert filtered_docs[0] == "hi"
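Assuming a standard pytest setup in the repo, the new tests can be run directly:

pytest tests/tasks/test_lighteval_task.py

Both tests load the fixture dataset from the Hub, so they need network access.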
