
Commit df4d57b

Merge branch 'main' into deciding-target-for-fewshot-sorting
2 parents: 8d0307e + 194e742

File tree

4 files changed: +40 / -21 lines


MANIFEST.in

Lines changed: 1 addition & 0 deletions

```diff
@@ -1 +1,2 @@
 include src/lighteval/tasks/tasks_table.jsonl
+include src/lighteval/metrics/*.jsonl
```

README.md

Lines changed: 3 additions & 2 deletions

````diff
@@ -101,7 +101,8 @@ accelerate launch --multi_gpu --num_processes=<num_gpus> -m \
     --output_dir output_dir
 ```

-Here, `--tasks` refers to either a _comma-separated_ list of supported tasks from the [metadata table](src/lighteval/tasks/tasks_table.jsonl) in the format:
+Here, `--tasks` refers to either a _comma-separated_ list of supported tasks from the [tasks_list](examples/tasks/all_tasks.txt) in the format:
+Task details can also be found in the [file implementing them](src/lighteval/tasks/default_tasks.py).

 ```
 suite|task|num_few_shot|{0 or 1 to automatically reduce `num_few_shot` if prompt is too long}
@@ -113,7 +114,7 @@ or a file path like [`examples/tasks/recommended_set.txt`](./examples/tasks/reco
 accelerate launch --multi_gpu --num_processes=8 -m \
     lighteval accelerate \
     --model_args "pretrained=gpt2" \
-    --tasks "lighteval|truthfulqa:mc|0|0" \
+    --tasks "leaderboard|truthfulqa:mc|0|0" \
     --override_batch_size 1 \
     --output_dir="./evals/"
````

src/lighteval/metrics/metrics_sample.py

Lines changed: 20 additions & 7 deletions

```diff
@@ -324,6 +324,7 @@ def __init__(
         normalize_gold: callable = None,
         normalize_pred: callable = None,
         aggregation_function: callable = None,
+        tokenizer: object = None,
     ):
         """A ROUGE wrapper method. Relies on `rouge_scorer`.

@@ -338,6 +339,8 @@ def __init__(
                 Defaults to None if no normalization is applied.
             normalize_pred (callable, optional): Function to use to normalize the predicted strings.
                 Defaults to None if no normalization is applied.
+            tokenizer (object, optional): An object with a `tokenize` method to be used by the rouge scorer. If None,
+                rouge-scorer's default tokenizer will be used.
         """
         if aggregation_function and bootstrap:
             hlog_warn("Can't use both bootstrapping and an aggregation function in Rouge. Keeping bootstrap.")
@@ -350,7 +353,7 @@ def __init__(
             raise ValueError(
                 f"Rouge was initialised with method {methods}, which is not in {','.join(self.ALLOWED_ROUGE_METHODS)}"
             )
-        self.scorer = rouge_scorer.RougeScorer([methods])
+        self.scorer = rouge_scorer.RougeScorer([methods], tokenizer=tokenizer)
         self.multiple_golds = multiple_golds
         self.bootstrap = bootstrap
         self.normalize_gold = normalize_gold
```
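The new `tokenizer` argument is forwarded straight to `rouge_scorer.RougeScorer`, which accepts any object exposing a `tokenize` method. A minimal sketch of how a caller might exercise the hook; `WhitespaceTokenizer` is a hypothetical stand-in, not part of lighteval or rouge_score:

```python
from rouge_score import rouge_scorer


class WhitespaceTokenizer:
    """Toy tokenizer: rouge_score only requires a `tokenize` method."""

    def tokenize(self, text: str) -> list[str]:
        return text.lower().split()


# Passing tokenizer=None (the default) keeps rouge-scorer's built-in tokenizer.
scorer = rouge_scorer.RougeScorer(["rouge1"], tokenizer=WhitespaceTokenizer())
score = scorer.score("the cat sat on the mat", "a cat sat on a mat")
print(score["rouge1"].fmeasure)
```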
```diff
@@ -416,8 +419,18 @@ def __init__(
         normalize_gold: callable = None,
         normalize_pred: callable = None,
     ):
-        """A BERT scorer class. Relies on some called extracted from `bert-score`. By default, will use the
-        `microsoft/deberta-large-mnli` as scorer
+        r"""A BERT scorer class. Relies on some code extracted from `bert-score`. By default, will use
+        `microsoft/deberta-large-mnli` as the scorer. For each tokenized (pred, target) pair, it computes Precision,
+        Recall and F1 as follows:
+
+            Precision = \sum_{t=1}^{len(pred)} \frac{\max(Cos.Sim.(pred_t, target))}{IDF(pred_t)}
+
+            Recall = \sum_{t=1}^{len(target)} \frac{\max(Cos.Sim.(target_t, pred))}{IDF(target_t)}
+
+            F1 = \frac{2 * Precision * Recall}{Precision + Recall}
+
+        in which `Cos.Sim.` is the Cosine Similarity metric and `IDF(.)` represents the Inverse Document
+        Frequency of its input token. It defaults to 1 for all tokens and 0 for EOS and SEP tokens.

         Args:
             normalize_gold (callable, optional): Function to use to normalize the reference strings.
```
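For readers who want the formulas in runnable form, here is a minimal sketch (not `bert-score`'s actual implementation) of the greedy matching described in the docstring, assuming L2-normalized token embeddings and uniform IDF weights; with uniform weights, each sum normalized by token count is simply the mean of best-match similarities:

```python
import torch


def bert_prf(pred_emb: torch.Tensor, target_emb: torch.Tensor) -> tuple[float, float, float]:
    """pred_emb: (len_pred, dim); target_emb: (len_target, dim); rows unit-norm."""
    sim = pred_emb @ target_emb.T  # pairwise cosine similarities
    precision = sim.max(dim=1).values.mean().item()  # best match per pred token
    recall = sim.max(dim=0).values.mean().item()  # best match per target token
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
```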
```diff
@@ -563,19 +576,19 @@ def __init__(
         self.strip_prediction = strip_prediction
         self.sample_aggregations = {"longest_common_prefix_length": max, "edit_distance": min, "edit_similarity": max}

-    def compute(self, gold: list[str], predictions: list[str], **kwargs) -> dict:
+    def compute(self, golds: list[str], predictions: list[str], **kwargs) -> dict:
         """Computes all the requested metrics on the golds and prediction.

         Args:
-            gold (list[str]): A list of possible golds. If it contains more than one item, only the first one is kept.
+            golds (list[str]): A list of possible golds. If it contains more than one item, only the first one is kept.
             predictions (list[str]): Predicted strings.

         Returns:
             dict: The different scores computed
         """
-        if len(gold) > 0:
+        if len(golds) > 1:
             hlog_warn("Provided more than one gold to compute a string distance metric. Just using the first one.")
-        reference = gold[0]
+        reference = golds[0]

         result = {m: [] for m in self.metric_types}
         for sequence in predictions:
```
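Note that the guard change is a real bugfix, not just a rename: `len(gold) > 0` fired the warning on every call, including the normal single-gold case. A toy illustration of the corrected logic (standalone, not lighteval's class):

```python
import warnings


def first_gold(golds: list[str]) -> str:
    # The old check was `len(golds) > 0`, which warned even for a single gold.
    if len(golds) > 1:
        warnings.warn("Provided more than one gold; just using the first one.")
    return golds[0]


assert first_gold(["ref"]) == "ref"  # no warning
assert first_gold(["ref a", "ref b"]) == "ref a"  # warns, keeps the first
```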

src/lighteval/models/utils.py

Lines changed: 16 additions & 12 deletions

```diff
@@ -29,7 +29,7 @@
 from transformers import AutoConfig


-def _get_dtype(dtype: Union[str, torch.dtype], config: Optional[AutoConfig] = None) -> torch.dtype:
+def _get_dtype(dtype: Union[str, torch.dtype, None], config: Optional[AutoConfig] = None) -> Optional[torch.dtype]:
     """
     Get the torch dtype based on the input arguments.

@@ -41,17 +41,21 @@ def _get_dtype(dtype: Union[str, torch.dtype], config: Optional[AutoConfig] = No
         torch.dtype: The torch dtype based on the input arguments.
     """

-    if config is not None:  # For quantized models
-        if hasattr(config, "quantization_config"):
-            _torch_dtype = None  # must be inferred
-        else:
-            _torch_dtype = config.torch_dtype
-    elif isinstance(dtype, str) and dtype not in ["auto", "4bit", "8bit"]:
-        # Convert `str` args torch dtype: `float16` -> `torch.float16`
-        _torch_dtype = getattr(torch, dtype)
-    else:
-        _torch_dtype = dtype
-    return _torch_dtype
+    if config is not None and hasattr(config, "quantization_config"):
+        # must be inferred
+        return None
+
+    if dtype is not None:
+        if isinstance(dtype, str) and dtype not in ["auto", "4bit", "8bit"]:
+            # Convert `str` args torch dtype: `float16` -> `torch.float16`
+            return getattr(torch, dtype)
+        elif isinstance(dtype, torch.dtype):
+            return dtype
+
+    if config is not None:
+        return config.torch_dtype
+
+    return None


 def _simplify_name(name_or_path: str) -> str:
```
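Beyond the tidier control flow, the rewrite changes behavior for `"auto"`/`"4bit"`/`"8bit"` strings when no config is given: the old code returned the string itself, the new code returns None. A few illustrative calls, assuming `_get_dtype` is importable from `lighteval.models.utils` as the diff suggests:

```python
import torch

from lighteval.models.utils import _get_dtype

assert _get_dtype("float16") is torch.float16  # str converted via getattr(torch, ...)
assert _get_dtype(torch.bfloat16) is torch.bfloat16  # torch.dtype passed through
assert _get_dtype("auto") is None  # old code returned the string "auto" here
assert _get_dtype(None) is None  # no dtype, no config -> nothing to infer
```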
