
Commit f9c4ca5

akanto and araffin authored
Use classes instead of lambdas for schedules (#2125)
* Fixes #2115. Avoid segmentation fault when loading models with non-portable schedules.

  Previously, using closures (e.g., lambdas) for `learning_rate` or `clip_range` caused segmentation faults when loading models across different platforms (e.g., macOS to Linux), because cloudpickle could not safely serialize/deserialize them.

  This commit rewrites:
  - `constant_fn` as a `ConstantSchedule` class
  - `get_schedule_fn` as a `FloatConverterSchedule` class
  - `get_linear_fn` as a `LinearSchedule` class

  All schedules are now proper callable classes, making them portable and safely pickleable. The old functions are kept (marked as deprecated) for backward compatibility when loading existing models.

* Incorporate pull request comments:
  - Renamed `FloatConverterSchedule` to `FloatSchedule` to better reflect its purpose.
  - Moved parameter documentation to the class-level docstring for proper Sphinx support.

* Update changelog and test

* Add more tests and explicitly deprecate the lambdas

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 19df267 commit f9c4ca5
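To illustrate the intended usage after this change (a minimal sketch, not part of the commit; `CartPole-v1` and the hyperparameter values are only examples), a learning-rate schedule can now be passed as a picklable class instance instead of a lambda:

    from stable_baselines3 import PPO
    from stable_baselines3.common.utils import LinearSchedule

    # Before: learning_rate=lambda progress_remaining: 3e-4 * progress_remaining
    # (a closure that cloudpickle may fail to restore on another platform).
    # After: a plain, picklable schedule object with the same behavior.
    model = PPO(
        "MlpPolicy",
        "CartPole-v1",
        learning_rate=LinearSchedule(start=3e-4, end=0.0, end_fraction=1.0),
    )
    model.learn(total_timesteps=1_000)
    model.save("ppo_cartpole")  # the schedule is serialized as an ordinary object

The deprecated helpers map to the new classes one-to-one: `get_schedule_fn` to `FloatSchedule`, `get_linear_fn` to `LinearSchedule`, and `constant_fn` to `ConstantSchedule`, as listed in the changelog below.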

File tree

8 files changed, +189 −12 lines changed


docs/misc/changelog.rst

Lines changed: 5 additions & 3 deletions
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 2.6.1a0 (WIP)
+Release 2.6.1a1 (WIP)
 --------------------------
 
 Breaking Changes:
@@ -15,6 +15,7 @@ New Features:
 Bug Fixes:
 ^^^^^^^^^^
 - Fixed docker GPU image (PyTorch GPU was not installed)
+- Fixed segmentation faults caused by non-portable schedules during model loading (@akanto)
 
 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
@@ -27,6 +28,7 @@ Bug Fixes:
 
 Deprecations:
 ^^^^^^^^^^^^^
+- ``get_schedule_fn()``, ``get_linear_fn()``, ``constant_fn()`` are deprecated, please use ``FloatSchedule()``, ``LinearSchedule()``, ``ConstantSchedule()`` instead
 
 Others:
 ^^^^^^^
@@ -1814,7 +1816,7 @@ Contributors:
 -------------
 In random order...
 
-Thanks to the maintainers of V2: @hill-a @enerijunior @AdamGleave @Miffyli
+Thanks to the maintainers of V2: @hill-a @ernestum @AdamGleave @Miffyli
 
 And all the contributors:
 @taymuur @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
@@ -1838,4 +1840,4 @@ And all the contributors:
 @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
 @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
 @marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
-@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen
+@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto

stable_baselines3/common/base_class.py

Lines changed: 2 additions & 2 deletions
@@ -25,9 +25,9 @@
 from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
 from stable_baselines3.common.utils import (
+    FloatSchedule,
     check_for_correct_spaces,
     get_device,
-    get_schedule_fn,
     get_system_info,
     set_random_seed,
     update_learning_rate,
@@ -273,7 +273,7 @@ def logger(self) -> Logger:
 
     def _setup_lr_schedule(self) -> None:
         """Transform to callable if needed."""
-        self.lr_schedule = get_schedule_fn(self.learning_rate)
+        self.lr_schedule = FloatSchedule(self.learning_rate)
 
     def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
         """

stable_baselines3/common/utils.py

Lines changed: 85 additions & 0 deletions
@@ -3,6 +3,7 @@
 import platform
 import random
 import re
+import warnings
 from collections import deque
 from collections.abc import Iterable
 from itertools import zip_longest
@@ -78,6 +79,84 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) ->
         param_group["lr"] = learning_rate
 
 
+class FloatSchedule:
+    """
+    Wrapper that ensures the output of a Schedule is cast to float.
+    Can wrap either a constant value or an existing callable Schedule.
+
+    :param value_schedule: Constant value or callable schedule
+        (e.g. LinearSchedule, ConstantSchedule)
+    """
+
+    def __init__(self, value_schedule: Union[Schedule, float]):
+        if isinstance(value_schedule, FloatSchedule):
+            self.value_schedule: Schedule = value_schedule.value_schedule
+        elif isinstance(value_schedule, (float, int)):
+            self.value_schedule = ConstantSchedule(float(value_schedule))
+        else:
+            assert callable(value_schedule), f"The learning rate schedule must be a float or a callable, not {value_schedule}"
+            self.value_schedule = value_schedule
+
+    def __call__(self, progress_remaining: float) -> float:
+        # Cast to float to avoid unpickling errors and enable weights_only=True, see GH#1900
+        # Some types have odd behaviors when part of a Schedule, like numpy floats
+        return float(self.value_schedule(progress_remaining))
+
+    def __repr__(self) -> str:
+        return f"FloatSchedule({self.value_schedule})"
+
+
+class LinearSchedule:
+    """
+    LinearSchedule interpolates linearly between ``start`` and ``end``
+    between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``.
+    This is used in DQN for linearly annealing the exploration fraction
+    (epsilon for the epsilon-greedy strategy).
+
+    :param start: value to start with if ``progress_remaining`` = 1
+    :param end: value to end with if ``progress_remaining`` = 0
+    :param end_fraction: fraction of ``progress_remaining`` where end is reached, e.g. 0.1
+        means that end is reached after 10% of the complete training process.
+    """
+
+    def __init__(self, start: float, end: float, end_fraction: float) -> None:
+        self.start = start
+        self.end = end
+        self.end_fraction = end_fraction
+
+    def __call__(self, progress_remaining: float) -> float:
+        if (1 - progress_remaining) > self.end_fraction:
+            return self.end
+        else:
+            return self.start + (1 - progress_remaining) * (self.end - self.start) / self.end_fraction
+
+    def __repr__(self) -> str:
+        return f"LinearSchedule(start={self.start}, end={self.end}, end_fraction={self.end_fraction})"
+
+
+class ConstantSchedule:
+    """
+    Constant schedule that always returns the same value.
+    Useful for fixed learning rates or clip ranges.
+
+    :param val: constant value
+    """
+
+    def __init__(self, val: float):
+        self.val = val
+
+    def __call__(self, _: float) -> float:
+        return self.val
+
+    def __repr__(self) -> str:
+        return f"ConstantSchedule(val={self.val})"
+
+
+# ===== Deprecated schedule functions =====
+# Only kept for backward compatibility when unpickling old models,
+# use FloatSchedule and the other classes like `LinearSchedule` instead
+
+
 def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
     """
     Transform (if needed) learning rate and clip range (for PPO)
@@ -86,6 +165,7 @@ def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
     :param value_schedule: Constant value of schedule function
     :return: Schedule function (can return constant value)
     """
+    warnings.warn("get_schedule_fn() is deprecated, please use FloatSchedule() instead")
     # If the passed schedule is a float
     # create a constant function
     if isinstance(value_schedule, (float, int)):
@@ -112,6 +192,7 @@ def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
         of the complete training process.
     :return: Linear schedule function.
     """
+    warnings.warn("get_linear_fn() is deprecated, please use LinearSchedule() instead")
 
     def func(progress_remaining: float) -> float:
         if (1 - progress_remaining) > end_fraction:
@@ -130,13 +211,17 @@ def constant_fn(val: float) -> Schedule:
     :param val: constant value
     :return: Constant schedule function.
    """
+    warnings.warn("constant_fn() is deprecated, please use ConstantSchedule() instead")
 
     def func(_):
         return val
 
     return func
 
 
+# ==== End of deprecated schedule functions ====
+
+
 def get_device(device: Union[th.device, str] = "auto") -> th.device:
     """
     Retrieve PyTorch device.
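A short sketch of how the classes added above behave (illustrative only, derived from the code in this diff, not part of the commit):

    from stable_baselines3.common.utils import ConstantSchedule, FloatSchedule, LinearSchedule

    # A constant is wrapped into a ConstantSchedule, so the result is picklable
    lr = FloatSchedule(3e-4)
    assert lr(0.5) == 3e-4

    # An existing schedule is kept as-is, only its output is cast to float
    eps = FloatSchedule(LinearSchedule(start=1.0, end=0.05, end_fraction=0.1))
    assert eps(1.0) == 1.0   # progress_remaining = 1 -> start of training
    assert eps(0.0) == 0.05  # progress_remaining = 0 -> end of training

    # Wrapping a FloatSchedule does not nest wrappers: the inner schedule is reused
    assert FloatSchedule(lr).value_schedule is lr.value_schedule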

stable_baselines3/dqn/dqn.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
+from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, polyak_update
 from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork
 
 SelfDQN = TypeVar("SelfDQN", bound="DQN")
@@ -146,7 +146,7 @@ def _setup_model(self) -> None:
         # Copy running stats, see GH issue #996
         self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
         self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"])
-        self.exploration_schedule = get_linear_fn(
+        self.exploration_schedule = LinearSchedule(
             self.exploration_initial_eps,
             self.exploration_final_eps,
             self.exploration_fraction,
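For reference, a sketch of what the new exploration schedule evaluates to with DQN's default hyperparameters (exploration_initial_eps=1.0, exploration_final_eps=0.05, exploration_fraction=0.1); this mirrors the old `get_linear_fn` closure:

    from stable_baselines3.common.utils import LinearSchedule

    # Same kind of object DQN._setup_model() now builds (default hyperparameters assumed)
    exploration_schedule = LinearSchedule(1.0, 0.05, 0.1)

    # progress_remaining goes from 1 (start of training) to 0 (end of training)
    print(exploration_schedule(1.0))   # 1.0   -> fully random actions at the start
    print(exploration_schedule(0.95))  # 0.525 -> halfway through the annealing phase
    print(exploration_schedule(0.85))  # 0.05  -> annealing done after 10% of training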

stable_baselines3/ppo/ppo.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
 from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import explained_variance, get_schedule_fn
+from stable_baselines3.common.utils import FloatSchedule, explained_variance
 
 SelfPPO = TypeVar("SelfPPO", bound="PPO")
 
@@ -174,12 +174,12 @@ def _setup_model(self) -> None:
         super()._setup_model()
 
         # Initialize schedules for policy/value clipping
-        self.clip_range = get_schedule_fn(self.clip_range)
+        self.clip_range = FloatSchedule(self.clip_range)
         if self.clip_range_vf is not None:
             if isinstance(self.clip_range_vf, (float, int)):
                 assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"
 
-            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
+            self.clip_range_vf = FloatSchedule(self.clip_range_vf)
 
     def train(self) -> None:
         """

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.6.1a0
+2.6.1a1

tests/test_save_load.py

Lines changed: 54 additions & 1 deletion
@@ -19,7 +19,7 @@
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
 from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
-from stable_baselines3.common.utils import get_device
+from stable_baselines3.common.utils import ConstantSchedule, FloatSchedule, get_device
 from stable_baselines3.common.vec_env import DummyVecEnv
 
 MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
@@ -821,3 +821,56 @@ def test_save_load_no_target_params(tmp_path):
     with pytest.warns(UserWarning):
         DQN.load(str(tmp_path / "test_save.zip"), env=env).learn(20)
     os.remove(tmp_path / "test_save.zip")
+
+
+@pytest.mark.parametrize("model_class", [PPO])
+def test_save_load_backward_compatible(tmp_path, model_class):
+    """
+    Test that lambdas are working when saving/loading models.
+    See GH#2115
+    """
+
+    env = DummyVecEnv([lambda: IdentityEnvBox(-1, 1)])
+
+    model = model_class("MlpPolicy", env, n_steps=64, learning_rate=lambda _: 0.001, clip_range=lambda _: 0.3)
+    model.learn(total_timesteps=100)
+
+    model.save(tmp_path / "test_schedule_safe.zip")
+
+    model = model_class.load(tmp_path / "test_schedule_safe.zip", env=env)
+
+    assert model.learning_rate(0) == 0.001
+    assert model.learning_rate.__name__ == "<lambda>"
+
+    assert isinstance(model.clip_range, FloatSchedule)
+    assert model.clip_range.value_schedule(0) == 0.3
+
+
+@pytest.mark.parametrize("model_class", [PPO])
+def test_save_load_clip_range_portable(tmp_path, model_class):
+    """
+    Test that models using callable schedule classes (e.g., ConstantSchedule, LinearSchedule)
+    are saved and loaded correctly without segfaults across different machines.
+
+    This ensures that we don't serialize fragile lambda closures.
+    See GH#2115
+    """
+    # Create a simple env
+    env = DummyVecEnv([lambda: IdentityEnvBox(-1, 1)])
+
+    model = model_class("MlpPolicy", env)
+    model.learn(total_timesteps=100)
+
+    # Make sure that classes are used not lambdas by default
+    assert isinstance(model.clip_range, FloatSchedule)
+    assert isinstance(model.clip_range.value_schedule, ConstantSchedule)
+    assert model.clip_range.value_schedule.val == 0.2
+
+    model.save(tmp_path / "test_schedule_safe.zip")
+
+    model = model_class.load(tmp_path / "test_schedule_safe.zip", env=env)
+
+    # Check that the model is loaded correctly
+    assert isinstance(model.clip_range, FloatSchedule)
+    assert isinstance(model.clip_range.value_schedule, ConstantSchedule)
+    assert model.clip_range.value_schedule.val == 0.2

tests/test_utils.py

Lines changed: 37 additions & 0 deletions
@@ -16,8 +16,14 @@
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
 from stable_baselines3.common.utils import (
+    ConstantSchedule,
+    FloatSchedule,
+    LinearSchedule,
     check_shape_equal,
+    constant_fn,
+    get_linear_fn,
     get_parameters_by_name,
+    get_schedule_fn,
     get_system_info,
     is_vectorized_observation,
     polyak_update,
@@ -593,3 +599,34 @@ def test_check_shape_equal():
     space2 = spaces.Dict({"key1": spaces.Box(low=-1, high=2, shape=(3, 3)), "key2": spaces.Box(low=-1, high=2, shape=(2, 2))})
     with pytest.raises(AssertionError):
         check_shape_equal(space1, space2)
+
+
+def test_deprecated_schedules():
+    with pytest.warns(Warning):
+        get_schedule_fn(0.1)
+        get_schedule_fn(lambda _: 0.1)
+    with pytest.warns(Warning):
+        linear_fn = get_linear_fn(1.0, 0.0, 0.1)
+    linear_schedule = LinearSchedule(1.0, 0.0, 0.1)
+    float_schedule = FloatSchedule(linear_schedule)
+    assert np.allclose(linear_fn(0.95), 0.5)
+    assert np.allclose(linear_fn(0.95), linear_schedule(0.95))
+    assert np.allclose(linear_fn(0.95), float_schedule(0.95))
+    assert np.allclose(linear_fn(0.9), 0.0)
+    assert np.allclose(linear_fn(0.0), 0.0)
+    assert np.allclose(linear_fn(0.9), linear_schedule(0.9))
+    assert np.allclose(linear_fn(0.9), float_schedule(0.9))
+    with pytest.warns(Warning):
+        fn = constant_fn(1.0)
+    schedule = ConstantSchedule(1.0)
+    float_schedule = FloatSchedule(1.0)
+    float_schedule_2 = FloatSchedule(float_schedule)
+    assert id(float_schedule_2.value_schedule) == id(float_schedule.value_schedule)
+    assert np.allclose(fn(0.0), 1.0)
+    assert np.allclose(fn(0.0), schedule(0.0))
+    assert np.allclose(fn(0.0), float_schedule(0.0))
+    assert np.allclose(fn(0.0), float_schedule_2(0.0))
+    assert np.allclose(fn(0.5), 1.0)
+    assert np.allclose(fn(0.5), schedule(0.5))
+    assert np.allclose(fn(0.5), float_schedule(0.5))
+    assert np.allclose(fn(0.5), float_schedule_2(0.5))
