
Commit 63cfb2e

Fixes DLR-RM#2115. Avoid segmentation fault when loading models with non-portable schedules
Previously, using closures (e.g., lambdas) for learning_rate or clip_range caused segmentation faults when loading models across different platforms (e.g., macOS to Linux), because cloudpickle could not safely serialize/deserialize them. This commit rewrites:

- `constant_fn` as a `ConstantSchedule` class
- `get_schedule_fn` as a `FloatConverterSchedule` class
- `get_linear_fn` as a `LinearSchedule` class

All schedules are now proper callable classes, making them portable and safely pickleable. The old functions are kept (marked as deprecated) for backward compatibility when loading existing models.
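To illustrate the user-facing difference, a minimal sketch (the `LinearSchedule` signature comes from this diff; the environment ID and schedule values are arbitrary):

from stable_baselines3 import PPO
from stable_baselines3.common.utils import LinearSchedule

# Before this commit: a lambda closure, which cloudpickle serializes by value
# (raw code object) -- fragile across OS and Python versions:
#   model = PPO("MlpPolicy", "CartPole-v1", learning_rate=lambda p: p * 3e-4)

# With this commit: a plain callable class; instances pickle as a class
# reference plus attribute dict, so they load anywhere the library is installed
model = PPO("MlpPolicy", "CartPole-v1", learning_rate=LinearSchedule(3e-4, 3e-5, 1.0))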
1 parent c1e503c commit 63cfb2e

File tree

6 files changed, +140 -8 lines changed


docs/misc/changelog.rst

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ New Features:
 Bug Fixes:
 ^^^^^^^^^^
 - Fixed docker GPU image (PyTorch GPU was not installed)
+- Fixed segmentation faults caused by non-portable schedules during model loading (#2115)
 
 `SB3-Contrib`_
 ^^^^^^^^^^^^^^

stable_baselines3/common/base_class.py

Lines changed: 2 additions & 2 deletions
@@ -25,9 +25,9 @@
 from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
 from stable_baselines3.common.utils import (
+    FloatConverterSchedule,
     check_for_correct_spaces,
     get_device,
-    get_schedule_fn,
     get_system_info,
     set_random_seed,
     update_learning_rate,
@@ -273,7 +273,7 @@ def logger(self) -> Logger:
 
     def _setup_lr_schedule(self) -> None:
         """Transform to callable if needed."""
-        self.lr_schedule = get_schedule_fn(self.learning_rate)
+        self.lr_schedule = FloatConverterSchedule(self.learning_rate)
 
     def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
         """

stable_baselines3/common/utils.py

Lines changed: 79 additions & 0 deletions
@@ -78,6 +78,35 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) ->
         param_group["lr"] = learning_rate
 
 
+class FloatConverterSchedule:
+    """
+    Wrapper that ensures the output of a Schedule is cast to float.
+    Can wrap either a constant value or an existing callable Schedule.
+    """
+
+    def __init__(self, value_schedule: Union[Schedule, float]):
+        """
+        :param value_schedule: Constant value or callable schedule
+            (e.g. LinearSchedule, ConstantSchedule)
+        """
+        if isinstance(value_schedule, FloatConverterSchedule):
+            self.value_schedule: Schedule = value_schedule.value_schedule
+        elif isinstance(value_schedule, (float, int)):
+            self.value_schedule = ConstantSchedule(float(value_schedule))
+        else:
+            assert callable(value_schedule)
+            self.value_schedule = value_schedule
+
+    def __call__(self, progress_remaining: float) -> float:
+        # Cast to float to avoid unpickling errors and enable weights_only=True, see GH#1900
+        # Some types have odd behaviors when part of a Schedule, like numpy floats
+        return float(self.value_schedule(progress_remaining))
+
+    def __repr__(self) -> str:
+        return f"FloatConverterSchedule({self.value_schedule})"
+
+
+# Deprecated: only kept for backward compatibility when unpickling old models, use FloatConverterSchedule instead
 def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
     """
     Transform (if needed) learning rate and clip range (for PPO)
@@ -98,6 +127,37 @@ def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
     return lambda progress_remaining: float(value_schedule(progress_remaining))
 
 
+class LinearSchedule:
+    """
+    LinearSchedule interpolates linearly between start and end
+    between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``.
+    This is used in DQN for linearly annealing the exploration fraction
+    (epsilon for the epsilon-greedy strategy).
+    """
+
+    def __init__(self, start: float, end: float, end_fraction: float):
+        """
+        :param start: value to start with if ``progress_remaining`` = 1
+        :param end: value to end with if ``progress_remaining`` = 0
+        :param end_fraction: fraction of ``progress_remaining``
+            where end is reached, e.g. 0.1 means end is reached after 10%
+            of the complete training process.
+        """
+        self.start = start
+        self.end = end
+        self.end_fraction = end_fraction
+
+    def __call__(self, progress_remaining: float) -> float:
+        if (1 - progress_remaining) > self.end_fraction:
+            return self.end
+        else:
+            return self.start + (1 - progress_remaining) * (self.end - self.start) / self.end_fraction
+
+    def __repr__(self) -> str:
+        return f"LinearSchedule(start={self.start}, end={self.end}, end_fraction={self.end_fraction})"
+
+
+# Deprecated: only kept for backward compatibility when unpickling old models, use LinearSchedule instead
 def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
     """
     Create a function that interpolates linearly between start and end
@@ -122,6 +182,25 @@ def func(progress_remaining: float) -> float:
     return func
 
 
+class ConstantSchedule:
+    """
+    Constant schedule that always returns the same value.
+    Useful for fixed learning rates or clip ranges.
+
+    :param val: constant value
+    """
+
+    def __init__(self, val: float):
+        self.val = val
+
+    def __call__(self, _: float) -> float:
+        return self.val
+
+    def __repr__(self) -> str:
+        return f"ConstantSchedule(val={self.val})"
+
+
+# Deprecated: only kept for backward compatibility when unpickling old models, use ConstantSchedule instead
 def constant_fn(val: float) -> Schedule:
     """
     Create a function that returns a constant
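To make the `LinearSchedule` interpolation concrete, a worked example with DQN-style epsilon values (illustrative numbers, not taken from this diff):

from stable_baselines3.common.utils import LinearSchedule

# Anneal epsilon from 1.0 down to 0.05 over the first 10% of training
eps = LinearSchedule(start=1.0, end=0.05, end_fraction=0.1)

print(eps(1.0))   # 1.0 -- training start (progress_remaining = 1)
print(eps(0.95))  # ~0.525 -- halfway through the annealing window:
                  #           1.0 + (1 - 0.95) * (0.05 - 1.0) / 0.1
print(eps(0.85))  # 0.05 -- (1 - 0.85) > 0.1, so the value is clamped at `end`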

stable_baselines3/dqn/dqn.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
+from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, polyak_update
 from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork
 
 SelfDQN = TypeVar("SelfDQN", bound="DQN")
@@ -146,7 +146,7 @@ def _setup_model(self) -> None:
         # Copy running stats, see GH issue #996
         self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
         self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"])
-        self.exploration_schedule = get_linear_fn(
+        self.exploration_schedule = LinearSchedule(
             self.exploration_initial_eps,
             self.exploration_final_eps,
             self.exploration_fraction,

stable_baselines3/ppo/ppo.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
 from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import explained_variance, get_schedule_fn
+from stable_baselines3.common.utils import FloatConverterSchedule, explained_variance
 
 SelfPPO = TypeVar("SelfPPO", bound="PPO")
 
@@ -174,12 +174,12 @@ def _setup_model(self) -> None:
         super()._setup_model()
 
         # Initialize schedules for policy/value clipping
-        self.clip_range = get_schedule_fn(self.clip_range)
+        self.clip_range = FloatConverterSchedule(self.clip_range)
         if self.clip_range_vf is not None:
             if isinstance(self.clip_range_vf, (float, int)):
                 assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"
 
-            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
+            self.clip_range_vf = FloatConverterSchedule(self.clip_range_vf)
 
     def train(self) -> None:
         """

tests/test_save_load.py

Lines changed: 53 additions & 1 deletion
@@ -19,7 +19,7 @@
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
 from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
-from stable_baselines3.common.utils import get_device
+from stable_baselines3.common.utils import ConstantSchedule, FloatConverterSchedule, get_device
 from stable_baselines3.common.vec_env import DummyVecEnv
 
 MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
@@ -821,3 +821,55 @@ def test_save_load_no_target_params(tmp_path):
     with pytest.warns(UserWarning):
         DQN.load(str(tmp_path / "test_save.zip"), env=env).learn(20)
     os.remove(tmp_path / "test_save.zip")
+
+
+@pytest.mark.parametrize("model_class", [PPO])
+def test_save_load_backward_compatible(tmp_path, model_class):
+    """
+    Test that lambdas still work when saving/loading models.
+    See GH#2115
+    """
+    env = DummyVecEnv([lambda: IdentityEnvBox(-1, 1)])
+
+    model = model_class("MlpPolicy", env, learning_rate=lambda _: 0.001, clip_range=lambda _: 0.3)
+    model.learn(total_timesteps=100)
+
+    model.save(tmp_path / "test_schedule_safe.zip")
+
+    model = model_class.load(tmp_path / "test_schedule_safe.zip", env=env)
+
+    assert model.learning_rate(0) == 0.001
+
+    assert isinstance(model.clip_range, FloatConverterSchedule)
+    assert model.clip_range.value_schedule(0) == 0.3
+
+
+@pytest.mark.parametrize("model_class", [PPO])
+def test_save_load_clip_range_portable(tmp_path, model_class):
+    """
+    Test that models using callable schedule classes (e.g., ConstantSchedule, LinearSchedule)
+    are saved and loaded correctly without segfaults across different machines.
+
+    This ensures that we don't serialize fragile lambda closures.
+    See GH#2115
+    """
+    # Create a simple env
+    env = DummyVecEnv([lambda: IdentityEnvBox(-1, 1)])
+
+    model = model_class("MlpPolicy", env)
+    model.learn(total_timesteps=100)
+
+    # Make sure that classes, not lambdas, are used by default
+    assert isinstance(model.clip_range, FloatConverterSchedule)
+    assert isinstance(model.clip_range.value_schedule, ConstantSchedule)
+    assert model.clip_range.value_schedule.val == 0.2
+
+    model.save(tmp_path / "test_schedule_safe.zip")
+
+    model = model_class.load(tmp_path / "test_schedule_safe.zip", env=env)
+
+    # Check that the model is loaded correctly
+    assert isinstance(model.clip_range, FloatConverterSchedule)
+    assert isinstance(model.clip_range.value_schedule, ConstantSchedule)
+    assert model.clip_range.value_schedule.val == 0.2
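The portability claim behind these tests can be checked directly against the standard library (a sketch; cloudpickle is the serializer SB3 uses for such objects):

import pickle

import cloudpickle

from stable_baselines3.common.utils import ConstantSchedule

# A plain class instance pickles as a class reference plus attribute dict,
# so the standard pickle module handles it and it loads on any machine
# where stable-baselines3 is installed:
schedule = ConstantSchedule(0.2)
restored = pickle.loads(pickle.dumps(schedule))
assert restored(0.0) == 0.2

fn = lambda _: 0.2
# pickle.dumps(fn) raises PicklingError: a lambda has no importable name.
# cloudpickle falls back to serializing the raw code object by value,
# which is exactly the non-portable representation this commit removes:
blob = cloudpickle.dumps(fn)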
