7
7
import argparse
8
8
import dataclasses
9
9
import functools
10
+ import inspect
10
11
import itertools
11
12
import shlex
12
13
import sys
15
16
from collections import defaultdict
16
17
from logging import getLogger
17
18
from pathlib import Path
18
- from typing import Any , Callable , Sequence , Type , overload
19
-
19
+ from typing import Any , Callable , Mapping , Sequence , Type , overload
20
+ from typing_extensions import TypeGuard
21
+ import warnings
20
22
from simple_parsing .helpers .subgroups import SubgroupKey
23
+ from simple_parsing .replace import SUBGROUP_KEY_FLAG
21
24
from simple_parsing .wrappers .dataclass_wrapper import DataclassWrapperType
22
25
23
26
from . import utils
24
27
from .conflicts import ConflictResolution , ConflictResolver
25
28
from .help_formatter import SimpleHelpFormatter
26
- from .helpers .serialization .serializable import read_file
29
+ from .helpers .serialization .serializable import DC_TYPE_KEY , read_file
27
30
from .utils import (
31
+ K ,
32
+ V ,
28
33
Dataclass ,
29
34
DataclassT ,
35
+ PossiblyNestedDict ,
30
36
dict_union ,
31
37
is_dataclass_instance ,
32
38
is_dataclass_type ,
@@ -593,7 +599,7 @@ def _resolve_subgroups(
593
599
594
600
This modifies the wrappers in-place, by possibly adding children to the wrappers in the
595
601
list.
596
- Returns a list with the modified wrappers.
602
+ Returns a list with the (now modified) wrappers.
597
603
598
604
Each round does the following:
599
605
1. Resolve any conflicts using the conflict resolver. Two subgroups at the same nesting
@@ -618,13 +624,7 @@ def _resolve_subgroups(
618
624
# times.
619
625
subgroup_choice_parser = argparse .ArgumentParser (
620
626
add_help = False ,
621
- # conflict_resolution=self.conflict_resolution,
622
- # add_option_string_dash_variants=self.add_option_string_dash_variants,
623
- # argument_generation_mode=self.argument_generation_mode,
624
- # nested_mode=self.nested_mode,
625
627
formatter_class = self .formatter_class ,
626
- # add_config_path_arg=self.add_config_path_arg,
627
- # config_path=self.config_path,
628
628
# NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues
629
629
# for example if you have --a_or_b and A has a field --a then it will error out if you
630
630
# pass --a=1 because 1 isn't a choice for the a_or_b argument (because --a matches it
@@ -644,10 +644,27 @@ def _resolve_subgroups(
644
644
flags = subgroup_field .option_strings
645
645
argument_options = subgroup_field .arg_options
646
646
647
+ # Sanity checks:
647
648
if subgroup_field .subgroup_default is dataclasses .MISSING :
648
649
assert argument_options ["required" ]
650
+ elif isinstance (argument_options ["default" ], dict ):
651
+ # BUG #276: The default here is a dict because it came from a config file.
652
+ # Here we want the subgroup field to have a 'str' default, because we just want
653
+ # to be able to choose between the subgroup names.
654
+ _default = argument_options ["default" ]
655
+ _default_key = _infer_subgroup_key_to_use_from_config (
656
+ default_in_config = _default ,
657
+ # subgroup_default=subgroup_field.subgroup_default,
658
+ subgroup_choices = subgroup_field .subgroup_choices ,
659
+ )
660
+ # We'd like this field to (at least temporarily) have a different default
661
+ # value that is the subgroup key instead of the dictionary.
662
+ argument_options ["default" ] = _default_key
663
+
649
664
else :
650
- assert argument_options ["default" ] is subgroup_field .subgroup_default
665
+ assert (
666
+ argument_options ["default" ] is subgroup_field .subgroup_default
667
+ ), argument_options ["default" ]
651
668
assert not is_dataclass_instance (argument_options ["default" ])
652
669
653
670
# TODO: Do we really need to care about this "SUPPRESS" stuff here?
@@ -1177,3 +1194,146 @@ def _create_dataclass_instance(
1177
1194
return None
1178
1195
logger .debug (f"Calling constructor: { constructor } (**{ constructor_args } )" )
1179
1196
return constructor (** constructor_args )
1197
+
1198
+
1199
+ def _has_values_of_type (
1200
+ mapping : Mapping [K , Any ], value_type : type [V ] | tuple [type [V ], ...]
1201
+ ) -> TypeGuard [Mapping [K , V ]]:
1202
+ # Utility functions used to narrow the type of dictionaries.
1203
+ return all (isinstance (v , value_type ) for v in mapping .values ())
1204
+
1205
+
1206
+ def _has_keys_of_type (
1207
+ mapping : Mapping [Any , V ], key_type : type [K ] | tuple [type [K ], ...]
1208
+ ) -> TypeGuard [Mapping [K , V ]]:
1209
+ # Utility functions used to narrow the type of dictionaries.
1210
+ return all (isinstance (k , key_type ) for k in mapping .keys ())
1211
+
1212
+
1213
def _has_items_of_type(
    mapping: Mapping[Any, Any],
    item_type: tuple[type[K] | tuple[type[K], ...], type[V] | tuple[type[V], ...]],
) -> TypeGuard[Mapping[K, V]]:
    """Type-narrowing helper: check both the keys and the values of a dictionary or mapping.

    `item_type` is a `(key_type, value_type)` pair; each element is forwarded to
    `_has_keys_of_type` / `_has_values_of_type` respectively.
    """
    expected_key_type, expected_value_type = item_type
    if not _has_keys_of_type(mapping, expected_key_type):
        # Keys are checked first; values are not inspected when a key already fails.
        return False
    return _has_values_of_type(mapping, expected_value_type)
1220
+
1221
+
1222
def _infer_subgroup_key_to_use_from_config(
    default_in_config: dict[str, Any],
    subgroup_choices: Mapping[SubgroupKey, type[Dataclass] | functools.partial[Dataclass]],
) -> SubgroupKey:
    """Choose the subgroup key that corresponds to a dict default read from a config file.

    The strategies below are tried in order; the first one that yields a key wins:

    1. If the config dict contains `SUBGROUP_KEY_FLAG`, its value is returned as-is.
    2. If the config dict compares equal to one of the subgroup choices, that choice's key is
       returned.
    3. If every subgroup choice is a dataclass type, the key of the first choice whose
       `"<module>.<qualname>"` equals the `DC_TYPE_KEY` entry of the config dict is returned.
    4. Otherwise, the key whose default constructor arguments share the most values with the
       config dict is returned, and a `RuntimeWarning` is emitted.

    Args:
        default_in_config: The dictionary found in the config for the subgroup field.
        subgroup_choices: Mapping from subgroup key to the dataclass type (or partial) it selects.

    Returns:
        The subgroup key to use as the default for the subgroup field.

    Raises:
        AssertionError: If strategies 1 and 2 fail and `DC_TYPE_KEY` is not in the config dict, or
            if all choices are dataclass types and none of them matches the saved type name.
    """
    # 1. An explicit key stored in the config takes precedence over any inference.
    if SUBGROUP_KEY_FLAG in default_in_config:
        return default_in_config[SUBGROUP_KEY_FLAG]

    # 2. Direct equality between the config value and one of the choices.
    for subgroup_key, subgroup_value in subgroup_choices.items():
        if default_in_config == subgroup_value:
            return subgroup_key

    assert (
        DC_TYPE_KEY in default_in_config
    ), f"FIXME: assuming that the {DC_TYPE_KEY} is in the config dict."
    default_type_name: str = default_in_config[DC_TYPE_KEY]

    if _has_values_of_type(subgroup_choices, type) and all(
        dataclasses.is_dataclass(subgroup_option) for subgroup_option in subgroup_choices.values()
    ):
        # 3. Simpler case: All the subgroup options are dataclass types. We just get the key that
        # matches the type that was saved in the config dict.
        matching_keys: list[SubgroupKey] = [
            key
            for key, option in subgroup_choices.items()
            if isinstance(option, type)
            and f"{option.__module__}.{option.__qualname__}" == default_type_name
        ]
        # NOTE: There could be duplicates, e.g. `subgroups({"a": A, "aa": A})`: the first key
        # (in the mapping's iteration order) is used.
        assert matching_keys, (
            f"None of the subgroup choices {subgroup_choices} has the type "
            f"{default_type_name!r} that was found in the config."
        )
        return matching_keys[0]

    # 4. Fall back to the subgroup whose default constructor arguments have the most values in
    # common with the config dict.
    n_matching_values = {
        key: _num_matching_values(
            default_in_config, _default_constructor_argument_values(subgroup_value)
        )
        for key, subgroup_value in subgroup_choices.items()
    }
    # `max` returns the first key with the highest count, i.e. ties are broken by the mapping's
    # iteration order (same result as a stable descending sort followed by `[0]`).
    closest_subgroup_key = max(n_matching_values, key=n_matching_values.__getitem__)
    warnings.warn(
        # TODO: Return the dataclass type instead, and be done with it!
        RuntimeWarning(
            f"TODO: The config file contains a default value for a subgroup that isn't in the "
            f"dict of subgroup options. Because of how subgroups are currently implemented, we "
            f"need to find the key in the subgroup choice dict ({subgroup_choices}) that most "
            f"closely matches the value {default_in_config}."
            f"The current implementation tries to use the dataclass type of this closest match "
            f"to parse the additional values from the command-line. "
            f"{default_in_config}. Consider adding the "
            f"{SUBGROUP_KEY_FLAG}: <key of the subgroup to use>"
        )
    )
    return closest_subgroup_key
1297
+
1298
+
1299
def _default_constructor_argument_values(
    some_dataclass_type: type[Dataclass] | functools.partial[Dataclass],
) -> PossiblyNestedDict[str, Any]:
    """Collect the default constructor arguments of a dataclass type (or a partial of one).

    For a dataclass type, the result maps each field name to:
    - the field's `default`, when it has one;
    - a nested dict (built recursively) when the field's `default_factory` is a
      `functools.partial` of a dataclass type, or when the field's annotated type is a dataclass
      type.
    Fields that match none of these cases are left out of the result.

    For a `functools.partial` of a dataclass type, the defaults of the wrapped class are computed
    first and then overridden by the arguments bound in the partial.

    Args:
        some_dataclass_type: A dataclass type, or a `functools.partial` wrapping one.

    Returns:
        A possibly nested dict of default constructor argument values.

    Raises:
        AssertionError: If the argument is neither a dataclass type nor a partial of one.
        TypeError: Propagated from `inspect.Signature.bind_partial` if the arguments stored in
            the partial do not fit the signature of the wrapped class.
    """
    if isinstance(some_dataclass_type, functools.partial) and is_dataclass_type(
        some_dataclass_type.func
    ):
        wrapped_class = some_dataclass_type.func
        defaults = _default_constructor_argument_values(wrapped_class)
        # Map the positional and keyword arguments stored in the partial to parameter names.
        overrides = (
            inspect.signature(wrapped_class)
            .bind_partial(*some_dataclass_type.args, **some_dataclass_type.keywords)
            .arguments
        )
        defaults.update(overrides)
        return defaults

    assert is_dataclass_type(some_dataclass_type)
    result = {}
    for field in dataclasses.fields(some_dataclass_type):
        default_factory = field.default_factory
        if field.default is not dataclasses.MISSING:
            result[field.name] = field.default
        elif isinstance(default_factory, functools.partial) and is_dataclass_type(
            default_factory.func
        ):
            # Recurse on the partial itself (not on `field.type`) so that the arguments bound in
            # the partial are reflected in the nested defaults, and so that this also works when
            # the annotation is not itself a dataclass type (e.g. a string or a union).
            result[field.name] = _default_constructor_argument_values(default_factory)
        elif is_dataclass_type(field.type):
            result[field.name] = _default_constructor_argument_values(field.type)
    return result
1329
+
1330
+
1331
+ def _num_matching_values (subgroup_default : dict [str , Any ], subgroup_choice : dict [str , Any ]) -> int :
1332
+ """Returns the number of matching entries in the subgroup dict w/ the default from the
1333
+ config."""
1334
+ return sum (
1335
+ _num_matching_values (default_v , subgroup_choice [k ])
1336
+ if isinstance (subgroup_choice .get (k ), dict ) and isinstance (default_v , dict )
1337
+ else int (subgroup_choice .get (k ) == default_v )
1338
+ for k , default_v in subgroup_default .items ()
1339
+ )
0 commit comments