23
23
import collections
24
24
import inspect
25
25
import random
26
- from dataclasses import asdict , dataclass
26
+ from dataclasses import asdict , dataclass , field
27
27
from multiprocessing import Pool
28
- from pathlib import Path
29
- from typing import TYPE_CHECKING , Callable , Dict , List , Optional , Tuple , Union
28
+ from typing import TYPE_CHECKING , Callable , Dict , List , Optional , Tuple
30
29
30
+ from datasets import DatasetDict
31
31
from huggingface_hub import TextGenerationInputGrammarType
32
32
from pytablewriter import MarkdownTableWriter
33
33
54
54
RequestType ,
55
55
SampleUid ,
56
56
)
57
- from lighteval .utils .utils import as_list , download_dataset_worker
57
+ from lighteval .utils .utils import ListLike , as_list , download_dataset_worker
58
58
59
59
60
60
if TYPE_CHECKING :
@@ -82,55 +82,58 @@ class LightevalTaskConfig:
82
82
original_num_docs (int): Number of documents in the task
83
83
effective_num_docs (int): Number of documents used in a specific evaluation
84
84
truncated_num_docs (bool): Whether less than the total number of documents were used
85
- output_regex (str)
86
- frozen (bool)
87
85
trust_dataset (bool): Whether to trust the dataset at execution or not
88
86
version (int): The version of the task. Defaults to 0. Can be increased if the underlying dataset or the prompt changes.
87
+ output_regex (str)
88
+ frozen (bool)
89
89
"""
90
90
91
91
name : str
92
- prompt_function : Callable # [[dict, str], Doc]
92
+ prompt_function : Callable [[dict , str ], Doc ]
93
93
hf_repo : str
94
94
hf_subset : str
95
- metric : Tuple [Union [Metric , Metrics ]]
96
- hf_avail_splits : Optional [Tuple [str ]] = None
97
- evaluation_splits : Optional [Tuple [str ]] = None
95
+ metric : ListLike [Metric | Metrics ]
96
+
97
+ # Additional hf dataset config
98
+ hf_revision : Optional [str ] = None
99
+ hf_filter : Optional [Callable [[dict ], bool ]] = None
100
+ hf_avail_splits : Optional [ListLike [str ]] = field (default_factory = lambda : ["train" , "validation" , "test" ])
101
+ # We default to false, to reduce security issues
102
+ trust_dataset : bool = False
103
+
104
+ # Splits
105
+ evaluation_splits : ListLike [str ] = field (default_factory = lambda : ["validation" ])
98
106
few_shots_split : Optional [str ] = None
99
107
few_shots_select : Optional [str ] = None
108
+
109
+ # Generation args
100
110
generation_size : Optional [int ] = None
101
111
generation_grammar : Optional [TextGenerationInputGrammarType ] = None
102
- stop_sequence : Optional [Tuple [str ]] = None
112
+ stop_sequence : Optional [ListLike [str ]] = None
103
113
output_regex : Optional [str ] = None
104
114
num_samples : Optional [list [int ]] = None
105
115
106
- frozen : bool = False
107
- suite : Optional [Tuple [str ]] = None
116
+ suite : ListLike [str ] = field (default_factory = lambda : ["custom" ])
108
117
109
118
original_num_docs : int = - 1
110
119
effective_num_docs : int = - 1
111
120
112
- trust_dataset : bool = None
113
-
114
- must_remove_duplicate_docs : bool = None
121
+ must_remove_duplicate_docs : bool = False
115
122
116
123
version : int = 0
117
124
118
- def __post_init__ (self ):
119
- if self .suite is None :
120
- self .suite = ["custom" ]
121
- if self .hf_avail_splits is None :
122
- self .hf_avail_splits = ["train" , "validation" , "test" ]
123
- if self .evaluation_splits is None :
124
- self .evaluation_splits = ["validation" ]
125
+ # Currently unused
126
+ frozen : bool = False
125
127
128
+ def __post_init__ (self ):
126
129
# If we got a Metrics enums instead of a Metric, we convert
127
130
self .metric = [metric .value if isinstance (metric , Metrics ) else metric for metric in self .metric ]
128
131
129
132
# Convert list to tuple for hashing
130
133
self .metric = tuple (self .metric )
131
134
self .hf_avail_splits = tuple (self .hf_avail_splits ) if self .hf_avail_splits is not None else None
132
- self .evaluation_splits = tuple (self .evaluation_splits ) if self . evaluation_splits is not None else None
133
- self .suite = tuple (self .suite ) if self . suite is not None else None
135
+ self .evaluation_splits = tuple (self .evaluation_splits )
136
+ self .suite = tuple (self .suite )
134
137
self .stop_sequence = tuple (self .stop_sequence ) if self .stop_sequence is not None else None
135
138
136
139
def print (self ):
@@ -175,31 +178,27 @@ def __init__( # noqa: C901
175
178
"""
176
179
self .name = name
177
180
self .version = cfg .version
178
- self .is_main_process = False
179
181
self .cache_dir = cache_dir
180
182
self ._cfg = cfg
181
183
182
184
# Dataset info
183
- self .hf_repo = cfg .hf_repo
184
- self .hf_subset = cfg .hf_subset
185
- self .dataset_path = self .hf_repo
186
- self .dataset_config_name = self .hf_subset
187
- self .dataset = None # Delayed download
185
+ self .dataset_path = cfg .hf_repo
186
+ self .dataset_config_name = cfg .hf_subset
187
+ self .dataset_revision = cfg .hf_revision
188
+ self .dataset_filter = cfg .hf_filter
188
189
self .trust_dataset = cfg .trust_dataset
190
+ self .dataset : Optional [DatasetDict ] = None # Delayed download
189
191
hlog (f"{ self .dataset_path } { self .dataset_config_name } " )
190
192
self ._fewshot_docs = None
191
193
self ._docs = None
192
194
193
- # Managing splits and few shot
194
- self .all_available_splits = as_list (cfg .hf_avail_splits )
195
- if cfg .evaluation_splits is None :
196
- raise ValueError (f"The evaluation split for task { self .name } is None. Please select a valid split." )
197
-
198
195
self .evaluation_split = as_list (cfg .evaluation_splits )
196
+
197
+ self .fewshot_split : list [str ] | None
199
198
if cfg .few_shots_split is not None :
200
199
self .fewshot_split = as_list (cfg .few_shots_split )
201
200
else :
202
- self .fewshot_split = as_list ( self .get_first_possible_fewshot_splits () )
201
+ self .fewshot_split = self .get_first_possible_fewshot_splits (cfg . hf_avail_splits or [] )
203
202
self .fewshot_selection = cfg .few_shots_select
204
203
205
204
# Metrics
@@ -223,30 +222,20 @@ def __init__( # noqa: C901
223
222
if "maj@" in metric_name :
224
223
self .num_samples .append (int (metric_name .replace ("maj@" , "" ).split ("_" )[0 ]))
225
224
226
- if not isinstance (cfg .prompt_function , Callable ):
227
- raise TypeError (
228
- f"Prompt formatting function ({ str (cfg .prompt_function )} ) should have been passed as a callable, was { type (cfg .prompt_function )} instead."
229
- )
230
225
self .formatter = cfg .prompt_function
231
226
232
227
self .generation_size = cfg .generation_size
233
228
self .generation_grammar = cfg .generation_grammar
234
229
self .stop_sequence = cfg .stop_sequence
235
- self .output_regex = cfg .output_regex
236
230
self .must_remove_duplicate_docs = cfg .must_remove_duplicate_docs
237
- if self .must_remove_duplicate_docs is None :
238
- self .must_remove_duplicate_docs = False
239
-
240
- # Save options
241
- self .save_queries : bool = False
242
- self .logfile_name : Optional [Path ] = None
243
- self .is_main_process : bool = False
244
231
245
232
@property
246
233
def cfg (self ):
247
234
return self ._cfg
248
235
249
- def get_first_possible_fewshot_splits (self , number_of_splits : int = 1 ) -> list [str ]:
236
+ def get_first_possible_fewshot_splits (
237
+ self , available_splits : ListLike [str ], number_of_splits : int = 1
238
+ ) -> list [str ] | None :
250
239
"""
251
240
Parses the possible fewshot split keys in order: train, then validation
252
241
keys and matches them with the available keys. Returns the first
@@ -260,7 +249,7 @@ def get_first_possible_fewshot_splits(self, number_of_splits: int = 1) -> list[s
260
249
list[str]: List of the first available fewshot splits.
261
250
"""
262
251
# Possible few shot splits are the available splits not used for evaluation
263
- possible_fewshot_splits = [k for k in self . all_available_splits if k not in self .evaluation_split ]
252
+ possible_fewshot_splits = [k for k in available_splits if k not in self .evaluation_split ]
264
253
stored_splits = []
265
254
266
255
# We look at these keys in order (first the training sets, then the validation sets)
@@ -289,7 +278,13 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
289
278
list[Doc]: List of documents.
290
279
"""
291
280
if self .dataset is None :
292
- self .dataset = download_dataset_worker ((self .dataset_path , self .dataset_config_name , self .trust_dataset ))
281
+ self .dataset = download_dataset_worker (
282
+ self .dataset_path ,
283
+ self .dataset_config_name ,
284
+ self .trust_dataset ,
285
+ self .dataset_filter ,
286
+ self .dataset_revision ,
287
+ )
293
288
splits = as_list (splits )
294
289
295
290
docs = []
@@ -326,7 +321,7 @@ def fewshot_docs(self) -> list[Doc]:
326
321
self ._fewshot_docs = []
327
322
328
323
# If we have no available few shot split, the few shot data is the eval data!
329
- if self .fewshot_split in [ None , [ None ]] :
324
+ if self .fewshot_split is None :
330
325
self ._fewshot_docs = self ._get_docs_from_split (self .evaluation_split , few_shots = True )
331
326
else : # Normal case
332
327
self ._fewshot_docs = self ._get_docs_from_split (self .fewshot_split , few_shots = True )
@@ -552,14 +547,29 @@ def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int =
552
547
553
548
if dataset_loading_processes <= 1 :
554
549
datasets = [
555
- download_dataset_worker ((task .dataset_path , task .dataset_config_name , task .trust_dataset ))
550
+ download_dataset_worker (
551
+ task .dataset_path ,
552
+ task .dataset_config_name ,
553
+ task .trust_dataset ,
554
+ task .dataset_filter ,
555
+ task .dataset_revision ,
556
+ )
556
557
for task in tasks
557
558
]
558
559
else :
559
560
with Pool (processes = dataset_loading_processes ) as pool :
560
- datasets = pool .map (
561
+ datasets = pool .starmap (
561
562
download_dataset_worker ,
562
- [(task .dataset_path , task .dataset_config_name , task .trust_dataset ) for task in tasks ],
563
+ [
564
+ (
565
+ task .dataset_path ,
566
+ task .dataset_config_name ,
567
+ task .trust_dataset ,
568
+ task .dataset_filter ,
569
+ task .dataset_revision ,
570
+ )
571
+ for task in tasks
572
+ ],
563
573
)
564
574
565
575
for task , dataset in zip (tasks , datasets ):
0 commit comments