Describe the bug
Subgroups work when their parameters are provided as CLI args, but fail when the same parameter is passed through a config file. See below:
To Reproduce
from typing import Union
import dataclasses
import simple_parsing

@dataclasses.dataclass
class ModelTypeA:
    model_a_param: str

@dataclasses.dataclass
class ModelTypeB:
    model_b_param: str

@dataclasses.dataclass
class TrainConfig:
    model_type: Union[ModelTypeA, ModelTypeB] = simple_parsing.subgroups(
        {"type_a": ModelTypeA, "type_b": ModelTypeB},
        default_factory=ModelTypeA,
        positional=False,
    )

# this works:
simple_parsing.parse(config_class=TrainConfig, args=['--model_a_param', 'test'])
# this doesn't work and throws an error
simple_parsing.parse(config_class=TrainConfig, add_config_path_arg=True, args=['--config_path', 'config.yaml'])
# config.yaml:
# model_a_param: test
Expected behavior
I expect the config file to behave the same way as the CLI args.
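The expected behaviour amounts to resolving a flat key such as `model_a_param` to the subgroup field that owns it, as the CLI does for `--model_a_param`. Below is a minimal stdlib-only sketch of that resolution as a pre-processing step. It assumes the active subgroup is the default `ModelTypeA` and that the loader wants the value nested under `model_type` (an assumption read off the error message, not confirmed for simple_parsing 0.1.6). `nest_flat_config` is a hypothetical helper, not part of simple_parsing.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class ModelTypeA:
    model_a_param: str


def nest_flat_config(
    flat: dict[str, Any], subgroup_field: str, subgroup_cls: type
) -> dict[str, Any]:
    """Move flat keys owned by `subgroup_cls` under `subgroup_field`.

    Hypothetical pre-processing helper; other keys stay at the top level.
    """
    owned = {f.name for f in dataclasses.fields(subgroup_cls)}
    nested = {key: value for key, value in flat.items() if key not in owned}
    sub = {key: value for key, value in flat.items() if key in owned}
    if sub:
        # Assumes `subgroup_field` is not already a key of `flat`.
        nested[subgroup_field] = sub
    return nested


# Resolve the flat config.yaml contents against the default subgroup.
print(nest_flat_config({"model_a_param": "test"}, "model_type", ModelTypeA))
# {'model_type': {'model_a_param': 'test'}}
```

Whether simple_parsing then accepts this nested layout for a subgroup field would still need to be checked against the library itself.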
Actual behavior
As described above, the second invocation leads to the following error:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/farzad/Library/Caches/pypoetry/virtualenvs/ultravox-o1nHeuex-py3.11/lib/python3.11/site-packages/simple_parsing/parsing.py", line 1030, in parse
    parsed_args = parser.parse_args(args)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/farzad/.pyenv/versions/3.11.4/lib/python3.11/argparse.py", line 1869, in parse_args
    args, argv = self.parse_known_args(args, namespace)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/farzad/Library/Caches/pypoetry/virtualenvs/ultravox-o1nHeuex-py3.11/lib/python3.11/site-packages/simple_parsing/parsing.py", line 321, in parse_known_args
    self.set_defaults(config_file)
  File "/Users/farzad/Library/Caches/pypoetry/virtualenvs/ultravox-o1nHeuex-py3.11/lib/python3.11/site-packages/simple_parsing/parsing.py", line 405, in set_defaults
    wrapper.set_default(default_for_dataclass)
  File "/Users/farzad/Library/Caches/pypoetry/virtualenvs/ultravox-o1nHeuex-py3.11/lib/python3.11/site-packages/simple_parsing/wrappers/dataclass_wrapper.py", line 313, in set_default
    raise RuntimeError(
RuntimeError: ['model_a_param'] are not fields of <class '__main__.TrainConfig'> at path 'config'!
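The error message suggests that, when the config file is loaded, its top-level keys are compared against the direct fields of `TrainConfig` (here only `model_type`), whereas the CLI accepts the flattened option `--model_a_param`. The stdlib-only sketch below mimics that check to illustrate why the flat file is rejected. It is an interpretation of the error text, not simple_parsing's actual implementation; `unknown_top_level_keys` is a hypothetical helper, and `simple_parsing.subgroups(...)` is left out so the snippet runs on its own.

```python
import dataclasses
from typing import Any, Union


@dataclasses.dataclass
class ModelTypeA:
    model_a_param: str


@dataclasses.dataclass
class ModelTypeB:
    model_b_param: str


@dataclasses.dataclass
class TrainConfig:
    # simple_parsing.subgroups(...) is omitted so the sketch runs without
    # simple_parsing; only the field name matters for this check.
    model_type: Union[ModelTypeA, ModelTypeB]


def unknown_top_level_keys(cls: type, config: dict[str, Any]) -> list[str]:
    """Return the config keys that are not direct fields of dataclass `cls`.

    Hypothetical helper: mimics the check implied by the error message,
    not simple_parsing's actual code.
    """
    field_names = {f.name for f in dataclasses.fields(cls)}
    return [key for key in config if key not in field_names]


# Flat layout, as in config.yaml from the report: flagged as unknown.
print(unknown_top_level_keys(TrainConfig, {"model_a_param": "test"}))
# ['model_a_param']

# Layout nested under the subgroup field name: passes this check.
print(unknown_top_level_keys(TrainConfig, {"model_type": {"model_a_param": "test"}}))
# []
```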
Desktop:
- Simple parsing version: 0.1.6
- Python version: 3.11.4
- Machine: M1 Mac