Skip to content

Commit d57572e

Browse files
committed
Add temporary fix for the config _type_ issue
Signed-off-by: Fabrice Normandin <[email protected]>
1 parent be462e8 commit d57572e

File tree

5 files changed

+216
-34
lines changed

5 files changed

+216
-34
lines changed

simple_parsing/helpers/subgroups.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import _MISSING_TYPE, MISSING
77
from enum import Enum
88
from logging import getLogger as get_logger
9-
from typing import Any, Callable, TypeVar, Union, overload
9+
from typing import Any, Callable, Mapping, TypeVar, Union
1010

1111
from typing_extensions import TypeAlias
1212

@@ -21,7 +21,7 @@
2121

2222

2323
def subgroups(
24-
subgroups: dict[Key, type[DC] | functools.partial[DC]],
24+
subgroups: Mapping[Key, type[DC] | functools.partial[DC]],
2525
*args,
2626
default: Key | _MISSING_TYPE = MISSING,
2727
default_factory: type[DC] | functools.partial[DC] | _MISSING_TYPE = MISSING,
@@ -59,8 +59,8 @@ def subgroups(
5959
"dataclass."
6060
)
6161
if default not in subgroups.values():
62-
# TODO: (@lebrice): Do we really need to enforce this? What is the reasoning behind this
63-
# restriction again?
62+
# NOTE: The reason we enforce this is perhaps artificial, but it's because the way we
63+
# implement subgroups requires us to know the key that is selected in the dict.
6464
raise ValueError(f"Default value {default} needs to be a value in the subgroups dict.")
6565
elif default is not MISSING and default not in subgroups.keys():
6666
raise ValueError("default must be a key in the subgroups dict!")
@@ -162,7 +162,14 @@ def subgroups(
162162

163163
from .fields import choice
164164

165-
return choice(choices, *args, default=default, default_factory=default_factory, metadata=metadata, **kwargs) # type: ignore
165+
return choice(
166+
choices,
167+
*args,
168+
default=default,
169+
default_factory=default_factory,
170+
metadata=metadata,
171+
**kwargs,
172+
) # type: ignore
166173

167174

168175
def _get_dataclass_type_from_callable(
@@ -179,7 +186,8 @@ def _get_dataclass_type_from_callable(
179186
return dataclass_fn.func
180187
# partial to a function that should return a dataclass. Hopefully it has a return type
181188
# annotation, otherwise we'd have to call the function just to know the return type!
182-
# NOTE: recurse here, so it also works with `partial(partial(...))` and `partial(some_function)`
189+
# NOTE: recurse here, so it also works with `partial(partial(...))` and
190+
# `partial(some_function)`
183191
return _get_dataclass_type_from_callable(
184192
dataclass_fn=dataclass_fn.func, caller_frame=caller_frame
185193
)
@@ -194,7 +202,6 @@ def _get_dataclass_type_from_callable(
194202
# Recurse, so this also works with partial(partial(...)) (idk why you'd do that though.)
195203

196204
if isinstance(signature.return_annotation, str):
197-
198205
dataclass_fn_type = signature.return_annotation
199206
if caller_frame is not None:
200207
# Travel up until we find the right frame where the subgroup is defined.
@@ -212,7 +219,8 @@ def _get_dataclass_type_from_callable(
212219
caller_globals = caller_frame.f_globals
213220

214221
try:
215-
# NOTE: This doesn't seem to be very often different than just calling `get_type_hints`
222+
# NOTE: This doesn't seem to be very often different than just calling
223+
# `get_type_hints`
216224
type_hints = typing.get_type_hints(
217225
dataclass_fn, globalns=caller_globals, localns=caller_locals
218226
)
@@ -223,8 +231,8 @@ def _get_dataclass_type_from_callable(
223231
type_hints = typing.get_type_hints(dataclass_fn)
224232
dataclass_fn_type = type_hints["return"]
225233

226-
# Recursing here would be a bit extra, let's be real. Might be good enough to just assume that
227-
# the return annotation needs to be a dataclass.
234+
# Recursing here would be a bit extra, let's be real. Might be good enough to just assume
235+
# that the return annotation needs to be a dataclass.
228236
# return _get_dataclass_type_from_callable(dataclass_fn_type, caller_frame=caller_frame)
229237
assert is_dataclass_type(dataclass_fn_type)
230238
return dataclass_fn_type

simple_parsing/parsing.py

Lines changed: 171 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import argparse
88
import dataclasses
99
import functools
10+
import inspect
1011
import itertools
1112
import shlex
1213
import sys
@@ -15,18 +16,23 @@
1516
from collections import defaultdict
1617
from logging import getLogger
1718
from pathlib import Path
18-
from typing import Any, Callable, Sequence, Type, overload
19-
19+
from typing import Any, Callable, Mapping, Sequence, Type, overload
20+
from typing_extensions import TypeGuard
21+
import warnings
2022
from simple_parsing.helpers.subgroups import SubgroupKey
23+
from simple_parsing.replace import SUBGROUP_KEY_FLAG
2124
from simple_parsing.wrappers.dataclass_wrapper import DataclassWrapperType
2225

2326
from . import utils
2427
from .conflicts import ConflictResolution, ConflictResolver
2528
from .help_formatter import SimpleHelpFormatter
26-
from .helpers.serialization.serializable import read_file
29+
from .helpers.serialization.serializable import DC_TYPE_KEY, read_file
2730
from .utils import (
31+
K,
32+
V,
2833
Dataclass,
2934
DataclassT,
35+
PossiblyNestedDict,
3036
dict_union,
3137
is_dataclass_instance,
3238
is_dataclass_type,
@@ -593,7 +599,7 @@ def _resolve_subgroups(
593599
594600
This modifies the wrappers in-place, by possibly adding children to the wrappers in the
595601
list.
596-
Returns a list with the modified wrappers.
602+
Returns a list with the (now modified) wrappers.
597603
598604
Each round does the following:
599605
1. Resolve any conflicts using the conflict resolver. Two subgroups at the same nesting
@@ -618,13 +624,7 @@ def _resolve_subgroups(
618624
# times.
619625
subgroup_choice_parser = argparse.ArgumentParser(
620626
add_help=False,
621-
# conflict_resolution=self.conflict_resolution,
622-
# add_option_string_dash_variants=self.add_option_string_dash_variants,
623-
# argument_generation_mode=self.argument_generation_mode,
624-
# nested_mode=self.nested_mode,
625627
formatter_class=self.formatter_class,
626-
# add_config_path_arg=self.add_config_path_arg,
627-
# config_path=self.config_path,
628628
# NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues
629629
# for example if you have —a_or_b and A has a field —a then it will error out if you
630630
# pass —a=1 because 1 isn’t a choice for the a_or_b argument (because --a matches it
@@ -644,10 +644,27 @@ def _resolve_subgroups(
644644
flags = subgroup_field.option_strings
645645
argument_options = subgroup_field.arg_options
646646

647+
# Sanity checks:
647648
if subgroup_field.subgroup_default is dataclasses.MISSING:
648649
assert argument_options["required"]
650+
elif isinstance(argument_options["default"], dict):
651+
# BUG #276: The default here is a dict because it came from a config file.
652+
# Here we want the subgroup field to have a 'str' default, because we just want
653+
# to be able to choose between the subgroup names.
654+
_default = argument_options["default"]
655+
_default_key = _infer_subgroup_key_to_use_from_config(
656+
default_in_config=_default,
657+
# subgroup_default=subgroup_field.subgroup_default,
658+
subgroup_choices=subgroup_field.subgroup_choices,
659+
)
660+
# We'd like this field to (at least temporarily) have a different default
661+
# value that is the subgroup key instead of the dictionary.
662+
argument_options["default"] = _default_key
663+
649664
else:
650-
assert argument_options["default"] is subgroup_field.subgroup_default
665+
assert (
666+
argument_options["default"] is subgroup_field.subgroup_default
667+
), argument_options["default"]
651668
assert not is_dataclass_instance(argument_options["default"])
652669

653670
# TODO: Do we really need to care about this "SUPPRESS" stuff here?
@@ -1177,3 +1194,146 @@ def _create_dataclass_instance(
11771194
return None
11781195
logger.debug(f"Calling constructor: {constructor}(**{constructor_args})")
11791196
return constructor(**constructor_args)
1197+
1198+
1199+
def _has_values_of_type(
1200+
mapping: Mapping[K, Any], value_type: type[V] | tuple[type[V], ...]
1201+
) -> TypeGuard[Mapping[K, V]]:
1202+
# Utility functions used to narrow the type of dictionaries.
1203+
return all(isinstance(v, value_type) for v in mapping.values())
1204+
1205+
1206+
def _has_keys_of_type(
1207+
mapping: Mapping[Any, V], key_type: type[K] | tuple[type[K], ...]
1208+
) -> TypeGuard[Mapping[K, V]]:
1209+
# Utility functions used to narrow the type of dictionaries.
1210+
return all(isinstance(k, key_type) for k in mapping.keys())
1211+
1212+
1213+
def _has_items_of_type(
1214+
mapping: Mapping[Any, Any],
1215+
item_type: tuple[type[K] | tuple[type[K], ...], type[V] | tuple[type[V], ...]],
1216+
) -> TypeGuard[Mapping[K, V]]:
1217+
# Utility functions used to narrow the type of a dictionary or mapping.
1218+
key_type, value_type = item_type
1219+
return _has_keys_of_type(mapping, key_type) and _has_values_of_type(mapping, value_type)
1220+
1221+
1222+
def _infer_subgroup_key_to_use_from_config(
1223+
default_in_config: dict[str, Any],
1224+
# subgroup_default: Hashable,
1225+
subgroup_choices: Mapping[SubgroupKey, type[Dataclass] | functools.partial[Dataclass]],
1226+
) -> SubgroupKey:
1227+
config_default = default_in_config
1228+
1229+
if SUBGROUP_KEY_FLAG in default_in_config:
1230+
return default_in_config[SUBGROUP_KEY_FLAG]
1231+
1232+
for subgroup_key, subgroup_value in subgroup_choices.items():
1233+
if default_in_config == subgroup_value:
1234+
return subgroup_key
1235+
1236+
assert (
1237+
DC_TYPE_KEY in config_default
1238+
), f"FIXME: assuming that the {DC_TYPE_KEY} is in the config dict."
1239+
_default_type_name: str = config_default[DC_TYPE_KEY]
1240+
1241+
if _has_values_of_type(subgroup_choices, type) and all(
1242+
dataclasses.is_dataclass(subgroup_option) for subgroup_option in subgroup_choices.values()
1243+
):
1244+
# Simpler case: All the subgroup options are dataclass types. We just get the key that
1245+
# matches the type that was saved in the config dict.
1246+
subgroup_keys_with_value_matching_config_default_type: list[SubgroupKey] = [
1247+
k
1248+
for k, v in subgroup_choices.items()
1249+
if (isinstance(v, type) and f"{v.__module__}.{v.__qualname__}" == _default_type_name)
1250+
]
1251+
# NOTE: There could be duplicates I guess? Something like `subgroups({"a": A, "aa": A})`
1252+
assert len(subgroup_keys_with_value_matching_config_default_type) >= 1
1253+
return subgroup_keys_with_value_matching_config_default_type[0]
1254+
1255+
# IDEA: Try to find the best subgroup key to use, based on the number of matching constructor
1256+
# arguments between the default in the config and the defaults for each subgroup.
1257+
constructor_args_in_each_subgroup = {
1258+
key: _default_constructor_argument_values(subgroup_value)
1259+
for key, subgroup_value in subgroup_choices.items()
1260+
}
1261+
n_matching_values = {
1262+
k: _num_matching_values(config_default, constructor_args_in_subgroup_value)
1263+
for k, constructor_args_in_subgroup_value in constructor_args_in_each_subgroup.items()
1264+
}
1265+
closest_subgroups_first = sorted(
1266+
subgroup_choices.keys(),
1267+
key=n_matching_values.__getitem__,
1268+
reverse=True,
1269+
)
1270+
warnings.warn(
1271+
# TODO: Return the dataclass type instead, and be done with it!
1272+
RuntimeWarning(
1273+
f"TODO: The config file contains a default value for a subgroup that isn't in the "
1274+
f"dict of subgroup options. Because of how subgroups are currently implemented, we "
1275+
f"need to find the key in the subgroup choice dict ({subgroup_choices}) that most "
1276+
f"closely matches the value {config_default}."
1277+
f"The current implementation tries to use the dataclass type of this closest match "
1278+
f"to parse the additional values from the command-line. "
1279+
f"{default_in_config}. Consider adding the "
1280+
f"{SUBGROUP_KEY_FLAG}: <key of the subgroup to use>"
1281+
)
1282+
)
1283+
return closest_subgroups_first[0]
1284+
return closest_subgroups_first[0]
1285+
1286+
sorted(
1287+
[k for k, v in subgroup_choices.items()],
1288+
key=_num_matching_values,
1289+
reversed=True,
1290+
)
1291+
# _default_values = copy.deepcopy(config_default)
1292+
# _default_values.pop(DC_TYPE_KEY)
1293+
1294+
# default_constructor_args_for_each_subgroup = {
1295+
# k: _default_constructor_argument_values(dc_type) if dataclasses.is_dataclass(dc_type)
1296+
# }
1297+
1298+
1299+
def _default_constructor_argument_values(
1300+
some_dataclass_type: type[Dataclass] | functools.partial[Dataclass],
1301+
) -> PossiblyNestedDict[str, Any]:
1302+
result = {}
1303+
if isinstance(some_dataclass_type, functools.partial) and is_dataclass_type(
1304+
some_dataclass_type.func
1305+
):
1306+
constructor_arguments_from_classdef = _default_constructor_argument_values(
1307+
some_dataclass_type.func
1308+
)
1309+
# TODO: will probably raise an error!
1310+
constructor_arguments_from_partial = (
1311+
inspect.signature(some_dataclass_type.func)
1312+
.bind_partial(*some_dataclass_type.args, **some_dataclass_type.keywords)
1313+
.arguments
1314+
)
1315+
constructor_arguments_from_classdef.update(constructor_arguments_from_partial)
1316+
return constructor_arguments_from_classdef
1317+
1318+
assert is_dataclass_type(some_dataclass_type)
1319+
for field in dataclasses.fields(some_dataclass_type):
1320+
key = field.name
1321+
if field.default is not dataclasses.MISSING:
1322+
result[key] = field.default
1323+
elif is_dataclass_type(field.type) or (
1324+
isinstance(field.default_factory, functools.partial)
1325+
and dataclasses.is_dataclass(field.default_factory.func)
1326+
):
1327+
result[key] = _default_constructor_argument_values(field.type)
1328+
return result
1329+
1330+
1331+
def _num_matching_values(subgroup_default: dict[str, Any], subgroup_choice: dict[str, Any]) -> int:
1332+
"""Returns the number of matching entries in the subgroup dict w/ the default from the
1333+
config."""
1334+
return sum(
1335+
_num_matching_values(default_v, subgroup_choice[k])
1336+
if isinstance(subgroup_choice.get(k), dict) and isinstance(default_v, dict)
1337+
else int(subgroup_choice.get(k) == default_v)
1338+
for k, default_v in subgroup_default.items()
1339+
)

simple_parsing/replace.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,8 @@ def replace(obj: DataclassT, changes_dict: dict[str, Any] | None = None, **chang
112112
def replace_subgroups(
113113
obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None
114114
) -> DataclassT:
115-
"""
116-
This function replaces the dataclass of subgroups, union, and optional union.
117-
The `selections` dict can be in flat format or in nested format.
115+
"""This function replaces the dataclass of subgroups, union, and optional union. The
116+
`selections` dict can be in flat format or in nested format.
118117
119118
The values of selections can be `Key` of subgroups, dataclass type, and dataclass instance.
120119
"""
@@ -181,12 +180,17 @@ def replace_subgroups(
181180
return dataclasses.replace(obj, **replace_kwargs)
182181

183182

183+
SUBGROUP_KEY_FLAG = "__key__"
184+
185+
184186
def _unflatten_selection_dict(
185-
flattened: Mapping[str, V], keyword: str = "__key__", sep: str = ".", recursive: bool = True
187+
flattened: Mapping[str, V],
188+
keyword: str = SUBGROUP_KEY_FLAG,
189+
sep: str = ".",
190+
recursive: bool = True,
186191
) -> PossiblyNestedDict[str, V]:
187-
"""
188-
This function convert a flattened dict into a nested dict
189-
and it inserts the `keyword` as the selection into the nested dict.
192+
"""This function convert a flattened dict into a nested dict and it inserts the `keyword` as
193+
the selection into the nested dict.
190194
191195
>>> _unflatten_selection_dict({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'})
192196
{'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}}

simple_parsing/wrappers/field_wrapper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -968,10 +968,12 @@ def subgroup_choices(self) -> dict[Hashable, Callable[[], Dataclass] | Dataclass
968968
return self.field.metadata["subgroups"]
969969

970970
@property
971-
def subgroup_default(self) -> Hashable | Literal[dataclasses.MISSING] | None:
971+
def subgroup_default(self) -> Hashable | Literal[dataclasses.MISSING]:
972972
if not self.is_subgroup:
973973
raise RuntimeError(f"Field {self.field} doesn't have subgroups! ")
974-
return self.field.metadata.get("subgroup_default")
974+
subgroup_default_key = self.field.metadata.get("subgroup_default")
975+
assert subgroup_default_key is not None
976+
return subgroup_default_key
975977

976978
@property
977979
def type_arguments(self) -> tuple[type, ...] | None:

0 commit comments

Comments
 (0)