Skip to content

Commit 254bb10

Browse files
Gregwararaffin
andauthored
Replacing the policy registry with policy "aliases" (#842)
* Replacing the policy registry with policy "aliases" * Fixing import order and SAC * Changing arg. order to be sure policy_aliases is a kwarg * Import orders * Removing pytype error check * Reformat * Fix alias import * Not using mutable {} as default for policy_aliases * Empty aliases initialization * Using static attributes for policy_aliases * Fixing isort * Fixing back bad merge * Running isort * Fixing aliases for A2C and PPO * Using f-string * Moving policy_aliases definition position * Adding change in the changelog * Update version Co-authored-by: Antonin Raffin <[email protected]>
1 parent 44e53ff commit 254bb10

File tree

16 files changed

+71
-126
lines changed

16 files changed

+71
-126
lines changed

docs/misc/changelog.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ Changelog
44
==========
55

66

7-
Release 1.5.1a0 (WIP)
7+
Release 1.5.1a1 (WIP)
88
---------------------------
99

1010
Breaking Changes:
1111
^^^^^^^^^^^^^^^^^
12+
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
13+
``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
1214

1315
New Features:
1416
^^^^^^^^^^^^^

stable_baselines3/a2c/a2c.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch.nn import functional as F
66

77
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
8-
from stable_baselines3.common.policies import ActorCriticPolicy
8+
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
99
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
1010
from stable_baselines3.common.utils import explained_variance
1111

@@ -51,6 +51,12 @@ class A2C(OnPolicyAlgorithm):
5151
:param _init_setup_model: Whether or not to build the network at the creation of the instance
5252
"""
5353

54+
policy_aliases: Dict[str, Type[BasePolicy]] = {
55+
"MlpPolicy": ActorCriticPolicy,
56+
"CnnPolicy": ActorCriticCnnPolicy,
57+
"MultiInputPolicy": MultiInputActorCriticPolicy,
58+
}
59+
5460
def __init__(
5561
self,
5662
policy: Union[str, Type[ActorCriticPolicy]],

stable_baselines3/a2c/policies.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,7 @@
11
# This file is here just to define MlpPolicy/CnnPolicy
22
# that work for A2C
3-
from stable_baselines3.common.policies import (
4-
ActorCriticCnnPolicy,
5-
ActorCriticPolicy,
6-
MultiInputActorCriticPolicy,
7-
register_policy,
8-
)
3+
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
94

105
MlpPolicy = ActorCriticPolicy
116
CnnPolicy = ActorCriticCnnPolicy
127
MultiInputPolicy = MultiInputActorCriticPolicy
13-
14-
register_policy("MlpPolicy", ActorCriticPolicy)
15-
register_policy("CnnPolicy", ActorCriticCnnPolicy)
16-
register_policy("MultiInputPolicy", MultiInputPolicy)

stable_baselines3/common/base_class.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from stable_baselines3.common.logger import Logger
1818
from stable_baselines3.common.monitor import Monitor
1919
from stable_baselines3.common.noise import ActionNoise
20-
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
20+
from stable_baselines3.common.policies import BasePolicy
2121
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
2222
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
2323
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -60,7 +60,6 @@ class BaseAlgorithm(ABC):
6060
:param policy: Policy object
6161
:param env: The environment to learn from
6262
(if registered in Gym, can be str. Can be None for loading trained models)
63-
:param policy_base: The base policy used by this method
6463
:param learning_rate: learning rate for the optimizer,
6564
it can be a function of the current progress remaining (from 1 to 0)
6665
:param policy_kwargs: Additional arguments to be passed to the policy on creation
@@ -83,11 +82,13 @@ class BaseAlgorithm(ABC):
8382
:param supported_action_spaces: The action spaces supported by the algorithm.
8483
"""
8584

85+
# Policy aliases (see _get_policy_from_name())
86+
policy_aliases: Dict[str, Type[BasePolicy]] = {}
87+
8688
def __init__(
8789
self,
8890
policy: Type[BasePolicy],
8991
env: Union[GymEnv, str, None],
90-
policy_base: Type[BasePolicy],
9192
learning_rate: Union[float, Schedule],
9293
policy_kwargs: Optional[Dict[str, Any]] = None,
9394
tensorboard_log: Optional[str] = None,
@@ -101,9 +102,8 @@ def __init__(
101102
sde_sample_freq: int = -1,
102103
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
103104
):
104-
105-
if isinstance(policy, str) and policy_base is not None:
106-
self.policy_class = get_policy_from_name(policy_base, policy)
105+
if isinstance(policy, str):
106+
self.policy_class = self._get_policy_from_name(policy)
107107
else:
108108
self.policy_class = policy
109109

@@ -325,6 +325,23 @@ def _excluded_save_params(self) -> List[str]:
325325
"_custom_logger",
326326
]
327327

328+
def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
329+
"""
330+
Get a policy class from its name representation.
331+
332+
The goal here is to standardize policy naming, e.g.
333+
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
334+
and they receive respective policies that work for them.
335+
336+
:param policy_name: Alias of the policy
337+
:return: A policy class (type)
338+
"""
339+
340+
if policy_name in self.policy_aliases:
341+
return self.policy_aliases[policy_name]
342+
else:
343+
raise ValueError(f"Policy {policy_name} unknown")
344+
328345
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
329346
"""
330347
Get the name of the torch variables that will be saved with

stable_baselines3/common/off_policy_algorithm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
2828
:param policy: Policy object
2929
:param env: The environment to learn from
3030
(if registered in Gym, can be str. Can be None for loading trained models)
31-
:param policy_base: The base policy used by this method
3231
:param learning_rate: learning rate for the optimizer,
3332
it can be a function of the current progress remaining (from 1 to 0)
3433
:param buffer_size: size of the replay buffer
@@ -76,7 +75,6 @@ def __init__(
7675
self,
7776
policy: Type[BasePolicy],
7877
env: Union[GymEnv, str],
79-
policy_base: Type[BasePolicy],
8078
learning_rate: Union[float, Schedule],
8179
buffer_size: int = 1_000_000, # 1e6
8280
learning_starts: int = 100,
@@ -107,7 +105,6 @@ def __init__(
107105
super(OffPolicyAlgorithm, self).__init__(
108106
policy=policy,
109107
env=env,
110-
policy_base=policy_base,
111108
learning_rate=learning_rate,
112109
policy_kwargs=policy_kwargs,
113110
tensorboard_log=tensorboard_log,

stable_baselines3/common/on_policy_algorithm.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from stable_baselines3.common.base_class import BaseAlgorithm
99
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
1010
from stable_baselines3.common.callbacks import BaseCallback
11-
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
11+
from stable_baselines3.common.policies import ActorCriticPolicy
1212
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
1313
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
1414
from stable_baselines3.common.vec_env import VecEnv
@@ -34,7 +34,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
3434
instead of action noise exploration (default: False)
3535
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
3636
Default: -1 (only sample at the beginning of the rollout)
37-
:param policy_base: The base policy used by this method
3837
:param tensorboard_log: the log location for tensorboard (if None, no logging)
3938
:param create_eval_env: Whether to create a second environment that will be
4039
used for evaluating the agent periodically. (Only available when passing string for the environment)
@@ -62,7 +61,6 @@ def __init__(
6261
max_grad_norm: float,
6362
use_sde: bool,
6463
sde_sample_freq: int,
65-
policy_base: Type[BasePolicy] = ActorCriticPolicy,
6664
tensorboard_log: Optional[str] = None,
6765
create_eval_env: bool = False,
6866
monitor_wrapper: bool = True,
@@ -77,7 +75,6 @@ def __init__(
7775
super(OnPolicyAlgorithm, self).__init__(
7876
policy=policy,
7977
env=env,
80-
policy_base=policy_base,
8178
learning_rate=learning_rate,
8279
policy_kwargs=policy_kwargs,
8380
verbose=verbose,

stable_baselines3/common/policies.py

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -894,68 +894,3 @@ def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
894894
with th.no_grad():
895895
features = self.extract_features(obs)
896896
return self.q_networks[0](th.cat([features, actions], dim=1))
897-
898-
899-
_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]
900-
901-
902-
def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]:
903-
"""
904-
Returns the registered policy from the base type and name.
905-
See `register_policy` for registering policies and explanation.
906-
907-
:param base_policy_type: the base policy class
908-
:param name: the policy name
909-
:return: the policy
910-
"""
911-
if base_policy_type not in _policy_registry:
912-
raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")
913-
if name not in _policy_registry[base_policy_type]:
914-
raise KeyError(
915-
f"Error: unknown policy type {name},"
916-
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!"
917-
)
918-
return _policy_registry[base_policy_type][name]
919-
920-
921-
def register_policy(name: str, policy: Type[BasePolicy]) -> None:
922-
"""
923-
Register a policy, so it can be called using its name.
924-
e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...).
925-
926-
The goal here is to standardize policy naming, e.g.
927-
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
928-
and they receive respective policies that work for them.
929-
Consider following:
930-
931-
OnlinePolicy
932-
-- OnlineMlpPolicy ("MlpPolicy")
933-
-- OnlineCnnPolicy ("CnnPolicy")
934-
OfflinePolicy
935-
-- OfflineMlpPolicy ("MlpPolicy")
936-
-- OfflineCnnPolicy ("CnnPolicy")
937-
938-
Two policies have name "MlpPolicy" and two have "CnnPolicy".
939-
In `get_policy_from_name`, the parent class (e.g. OnlinePolicy)
940-
is given and used to select and return the correct policy.
941-
942-
:param name: the policy name
943-
:param policy: the policy class
944-
"""
945-
sub_class = None
946-
for cls in BasePolicy.__subclasses__():
947-
if issubclass(policy, cls):
948-
sub_class = cls
949-
break
950-
if sub_class is None:
951-
raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")
952-
953-
if sub_class not in _policy_registry:
954-
_policy_registry[sub_class] = {}
955-
if name in _policy_registry[sub_class]:
956-
# Check if the registered policy is same
957-
# we try to register. If not so,
958-
# do not override and complain.
959-
if _policy_registry[sub_class][name] != policy:
960-
raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.")
961-
_policy_registry[sub_class][name] = policy

stable_baselines3/dqn/dqn.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
from stable_baselines3.common.buffers import ReplayBuffer
1010
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
11+
from stable_baselines3.common.policies import BasePolicy
1112
from stable_baselines3.common.preprocessing import maybe_transpose
1213
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
1314
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
14-
from stable_baselines3.dqn.policies import DQNPolicy
15+
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy
1516

1617

1718
class DQN(OffPolicyAlgorithm):
@@ -59,6 +60,12 @@ class DQN(OffPolicyAlgorithm):
5960
:param _init_setup_model: Whether or not to build the network at the creation of the instance
6061
"""
6162

63+
policy_aliases: Dict[str, Type[BasePolicy]] = {
64+
"MlpPolicy": MlpPolicy,
65+
"CnnPolicy": CnnPolicy,
66+
"MultiInputPolicy": MultiInputPolicy,
67+
}
68+
6269
def __init__(
6370
self,
6471
policy: Union[str, Type[DQNPolicy]],
@@ -91,7 +98,6 @@ def __init__(
9198
super(DQN, self).__init__(
9299
policy,
93100
env,
94-
DQNPolicy,
95101
learning_rate,
96102
buffer_size,
97103
learning_starts,

stable_baselines3/dqn/policies.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch as th
55
from torch import nn
66

7-
from stable_baselines3.common.policies import BasePolicy, register_policy
7+
from stable_baselines3.common.policies import BasePolicy
88
from stable_baselines3.common.torch_layers import (
99
BaseFeaturesExtractor,
1010
CombinedExtractor,
@@ -296,8 +296,3 @@ def __init__(
296296
optimizer_class,
297297
optimizer_kwargs,
298298
)
299-
300-
301-
register_policy("MlpPolicy", MlpPolicy)
302-
register_policy("CnnPolicy", CnnPolicy)
303-
register_policy("MultiInputPolicy", MultiInputPolicy)

stable_baselines3/ppo/policies.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,7 @@
11
# This file is here just to define MlpPolicy/CnnPolicy
22
# that work for PPO
3-
from stable_baselines3.common.policies import (
4-
ActorCriticCnnPolicy,
5-
ActorCriticPolicy,
6-
MultiInputActorCriticPolicy,
7-
register_policy,
8-
)
3+
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
94

105
MlpPolicy = ActorCriticPolicy
116
CnnPolicy = ActorCriticCnnPolicy
127
MultiInputPolicy = MultiInputActorCriticPolicy
13-
14-
register_policy("MlpPolicy", ActorCriticPolicy)
15-
register_policy("CnnPolicy", ActorCriticCnnPolicy)
16-
register_policy("MultiInputPolicy", MultiInputPolicy)

stable_baselines3/ppo/ppo.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch.nn import functional as F
88

99
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
10-
from stable_baselines3.common.policies import ActorCriticPolicy
10+
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
1111
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
1212
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
1313

@@ -65,6 +65,12 @@ class PPO(OnPolicyAlgorithm):
6565
:param _init_setup_model: Whether or not to build the network at the creation of the instance
6666
"""
6767

68+
policy_aliases: Dict[str, Type[BasePolicy]] = {
69+
"MlpPolicy": ActorCriticPolicy,
70+
"CnnPolicy": ActorCriticCnnPolicy,
71+
"MultiInputPolicy": MultiInputActorCriticPolicy,
72+
}
73+
6874
def __init__(
6975
self,
7076
policy: Union[str, Type[ActorCriticPolicy]],

stable_baselines3/sac/policies.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch import nn
77

88
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
9-
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
9+
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
1010
from stable_baselines3.common.preprocessing import get_action_dim
1111
from stable_baselines3.common.torch_layers import (
1212
BaseFeaturesExtractor,
@@ -514,8 +514,3 @@ def __init__(
514514
n_critics,
515515
share_features_extractor,
516516
)
517-
518-
519-
register_policy("MlpPolicy", MlpPolicy)
520-
register_policy("CnnPolicy", CnnPolicy)
521-
register_policy("MultiInputPolicy", MultiInputPolicy)

stable_baselines3/sac/sac.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from stable_baselines3.common.buffers import ReplayBuffer
99
from stable_baselines3.common.noise import ActionNoise
1010
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
11+
from stable_baselines3.common.policies import BasePolicy
1112
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
1213
from stable_baselines3.common.utils import polyak_update
13-
from stable_baselines3.sac.policies import SACPolicy
14+
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
1415

1516

1617
class SAC(OffPolicyAlgorithm):
@@ -72,6 +73,12 @@ class SAC(OffPolicyAlgorithm):
7273
:param _init_setup_model: Whether or not to build the network at the creation of the instance
7374
"""
7475

76+
policy_aliases: Dict[str, Type[BasePolicy]] = {
77+
"MlpPolicy": MlpPolicy,
78+
"CnnPolicy": CnnPolicy,
79+
"MultiInputPolicy": MultiInputPolicy,
80+
}
81+
7582
def __init__(
7683
self,
7784
policy: Union[str, Type[SACPolicy]],
@@ -106,7 +113,6 @@ def __init__(
106113
super(SAC, self).__init__(
107114
policy,
108115
env,
109-
SACPolicy,
110116
learning_rate,
111117
buffer_size,
112118
learning_starts,

0 commit comments

Comments
 (0)