Skip to content

Use classes instead of lambdas for schedules #2125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.6.1a0 (WIP)
Release 2.6.1a1 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -15,6 +15,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed docker GPU image (PyTorch GPU was not installed)
- Fixed segmentation faults caused by non-portable schedules during model loading (@akanto)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand All @@ -27,6 +28,7 @@ Bug Fixes:

Deprecations:
^^^^^^^^^^^^^
- ``get_schedule_fn()``, ``get_linear_fn()``, ``constant_fn()`` are deprecated, please use ``FloatSchedule()``, ``LinearSchedule()``, ``ConstantSchedule()`` instead

Others:
^^^^^^^
Expand Down Expand Up @@ -1814,7 +1816,7 @@ Contributors:
-------------
In random order...

Thanks to the maintainers of V2: @hill-a @enerijunior @AdamGleave @Miffyli
Thanks to the maintainers of V2: @hill-a @ernestum @AdamGleave @Miffyli

And all the contributors:
@taymuur @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
Expand All @@ -1838,4 +1840,4 @@ And all the contributors:
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto
4 changes: 2 additions & 2 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
from stable_baselines3.common.utils import (
FloatSchedule,
check_for_correct_spaces,
get_device,
get_schedule_fn,
get_system_info,
set_random_seed,
update_learning_rate,
Expand Down Expand Up @@ -273,7 +273,7 @@ def logger(self) -> Logger:

def _setup_lr_schedule(self) -> None:
"""Transform to callable if needed."""
self.lr_schedule = get_schedule_fn(self.learning_rate)
self.lr_schedule = FloatSchedule(self.learning_rate)

def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
"""
Expand Down
85 changes: 85 additions & 0 deletions stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import platform
import random
import re
import warnings
from collections import deque
from collections.abc import Iterable
from itertools import zip_longest
Expand Down Expand Up @@ -78,6 +79,84 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) ->
param_group["lr"] = learning_rate


class FloatSchedule:
    """
    Callable wrapper guaranteeing that a schedule returns a plain Python ``float``.

    Accepts a number (turned into a ``ConstantSchedule``), another
    ``FloatSchedule`` (its inner schedule is reused, so wrappers never nest)
    or any other callable schedule.

    :param value_schedule: Constant value or callable schedule
        (e.g. LinearSchedule, ConstantSchedule)
    """

    def __init__(self, value_schedule: Union[Schedule, float]):
        if isinstance(value_schedule, FloatSchedule):
            # Reuse the already-wrapped schedule instead of stacking wrappers
            inner = value_schedule.value_schedule
        elif isinstance(value_schedule, (float, int)):
            # Plain numbers become a constant schedule
            inner = ConstantSchedule(float(value_schedule))
        else:
            assert callable(value_schedule), f"The learning rate schedule must be a float or a callable, not {value_schedule}"
            inner = value_schedule
        self.value_schedule: Schedule = inner

    def __call__(self, progress_remaining: float) -> float:
        """
        Evaluate the wrapped schedule and cast the result to ``float``.

        :param progress_remaining: value forwarded unchanged to the wrapped schedule
        :return: the wrapped schedule's output converted with ``float()``
        """
        # Cast to float to avoid unpickling errors and to enable weights_only=True, see GH#1900
        # Some types have odd behaviors when part of a Schedule, like numpy floats
        raw_value = self.value_schedule(progress_remaining)
        return float(raw_value)

    def __repr__(self) -> str:
        return f"FloatSchedule({self.value_schedule})"


class LinearSchedule:
    """
    LinearSchedule interpolates linearly between start and end
    while ``progress_remaining`` goes from 1 down to ``1 - end_fraction``,
    and returns ``end`` for the rest of training.
    This is used in DQN for linearly annealing the exploration rate
    (epsilon for the epsilon-greedy strategy).

    :param start: value to start with if ``progress_remaining`` = 1
    :param end: value to end with if ``progress_remaining`` = 0
    :param end_fraction: fraction of the training after which end is reached, e.g. 0.1
        then end is reached after 10% of the complete training process.
        With ``end_fraction`` = 0 the schedule always returns ``end``.
    """

    def __init__(self, start: float, end: float, end_fraction: float) -> None:
        self.start = start
        self.end = end
        self.end_fraction = end_fraction

    def __call__(self, progress_remaining: float) -> float:
        """
        Compute the scheduled value for the given training progress.

        :param progress_remaining: remaining progress, 1 at the beginning
            of training and 0 at the end
        :return: the interpolated value, or ``end`` once ``end_fraction``
            of the training has elapsed
        """
        elapsed = 1 - progress_remaining
        # An end_fraction of 0 means there is no interpolation phase at all:
        # return ``end`` directly instead of dividing by zero below.
        if self.end_fraction == 0 or elapsed > self.end_fraction:
            return self.end
        return self.start + elapsed * (self.end - self.start) / self.end_fraction

    def __repr__(self) -> str:
        return f"LinearSchedule(start={self.start}, end={self.end}, end_fraction={self.end_fraction})"


class ConstantSchedule:
    """
    Schedule whose output never changes, whatever the training progress.
    Handy for a fixed learning rate or a fixed clip range.

    :param val: constant value
    """

    def __init__(self, val: float):
        self.val = val

    def __call__(self, _: float) -> float:
        """
        Return the stored constant.

        :param _: progress value, ignored
        :return: the value given at construction time
        """
        constant = self.val
        return constant

    def __repr__(self) -> str:
        return f"ConstantSchedule(val={self.val})"


# ===== Deprecated schedule functions ====
# only kept for backward compatibility when unpickling old models, use FloatSchedule
# and other classes like `LinearSchedule()` instead


def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
"""
Transform (if needed) learning rate and clip range (for PPO)
Expand All @@ -86,6 +165,7 @@ def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
:param value_schedule: Constant value of schedule function
:return: Schedule function (can return constant value)
"""
warnings.warn("get_schedule_fn() is deprecated, please use FloatSchedule() instead")
# If the passed schedule is a float
# create a constant function
if isinstance(value_schedule, (float, int)):
Expand All @@ -112,6 +192,7 @@ def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
of the complete training process.
:return: Linear schedule function.
"""
warnings.warn("get_linear_fn() is deprecated, please use LinearSchedule() instead")

def func(progress_remaining: float) -> float:
if (1 - progress_remaining) > end_fraction:
Expand All @@ -130,13 +211,17 @@ def constant_fn(val: float) -> Schedule:
:param val: constant value
:return: Constant schedule function.
"""
warnings.warn("constant_fn() is deprecated, please use ConstantSchedule() instead")

def func(_):
return val

return func


# ==== End of deprecated schedule functions ====


def get_device(device: Union[th.device, str] = "auto") -> th.device:
"""
Retrieve PyTorch device.
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork

SelfDQN = TypeVar("SelfDQN", bound="DQN")
Expand Down Expand Up @@ -146,7 +146,7 @@ def _setup_model(self) -> None:
# Copy running stats, see GH issue #996
self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"])
self.exploration_schedule = get_linear_fn(
self.exploration_schedule = LinearSchedule(
self.exploration_initial_eps,
self.exploration_final_eps,
self.exploration_fraction,
Expand Down
6 changes: 3 additions & 3 deletions stable_baselines3/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
from stable_baselines3.common.utils import FloatSchedule, explained_variance

SelfPPO = TypeVar("SelfPPO", bound="PPO")

Expand Down Expand Up @@ -174,12 +174,12 @@ def _setup_model(self) -> None:
super()._setup_model()

# Initialize schedules for policy/value clipping
self.clip_range = get_schedule_fn(self.clip_range)
self.clip_range = FloatSchedule(self.clip_range)
if self.clip_range_vf is not None:
if isinstance(self.clip_range_vf, (float, int)):
assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
self.clip_range_vf = FloatSchedule(self.clip_range_vf)

def train(self) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.6.1a0
2.6.1a1
55 changes: 54 additions & 1 deletion tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.utils import ConstantSchedule, FloatSchedule, get_device
from stable_baselines3.common.vec_env import DummyVecEnv

MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
Expand Down Expand Up @@ -821,3 +821,56 @@ def test_save_load_no_target_params(tmp_path):
with pytest.warns(UserWarning):
DQN.load(str(tmp_path / "test_save.zip"), env=env).learn(20)
os.remove(tmp_path / "test_save.zip")


@pytest.mark.parametrize("model_class", [PPO])
def test_save_load_backward_compatible(tmp_path, model_class):
    """
    Check that schedules given as lambdas still work after a save/load round trip.
    See GH#2115
    """
    env = DummyVecEnv([lambda: IdentityEnvBox(-1, 1)])
    save_path = tmp_path / "test_schedule_safe.zip"

    # Train briefly with user-provided lambdas for both schedules
    trained = model_class("MlpPolicy", env, n_steps=64, learning_rate=lambda _: 0.001, clip_range=lambda _: 0.3)
    trained.learn(total_timesteps=100)
    trained.save(save_path)

    restored = model_class.load(save_path, env=env)

    # The learning rate is expected to come back as the original lambda
    assert restored.learning_rate(0) == 0.001
    assert restored.learning_rate.__name__ == "<lambda>"

    # The clip range is expected to be wrapped in a FloatSchedule around the lambda
    assert isinstance(restored.clip_range, FloatSchedule)
    assert restored.clip_range.value_schedule(0) == 0.3


@pytest.mark.parametrize("model_class", [PPO])
def test_save_load_clip_range_portable(tmp_path, model_class):
    """
    Check that the default clip range is stored as schedule classes
    (``FloatSchedule`` wrapping a ``ConstantSchedule``) rather than as a lambda,
    both before saving and after loading the model.

    The intent is to avoid serializing fragile lambda closures.
    See GH#2115
    """

    def check_default_clip_range(algo) -> None:
        # The same three checks are applied to the fresh and to the reloaded model
        assert isinstance(algo.clip_range, FloatSchedule)
        assert isinstance(algo.clip_range.value_schedule, ConstantSchedule)
        assert algo.clip_range.value_schedule.val == 0.2

    # Create a simple env
    env = DummyVecEnv([lambda: IdentityEnvBox(-1, 1)])
    save_path = tmp_path / "test_schedule_safe.zip"

    trained = model_class("MlpPolicy", env)
    trained.learn(total_timesteps=100)

    # Make sure that classes are used, not lambdas, by default
    check_default_clip_range(trained)

    trained.save(save_path)
    restored = model_class.load(save_path, env=env)

    # Check that the model is loaded correctly
    check_default_clip_range(restored)
37 changes: 37 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,14 @@
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
from stable_baselines3.common.utils import (
ConstantSchedule,
FloatSchedule,
LinearSchedule,
check_shape_equal,
constant_fn,
get_linear_fn,
get_parameters_by_name,
get_schedule_fn,
get_system_info,
is_vectorized_observation,
polyak_update,
Expand Down Expand Up @@ -593,3 +599,34 @@ def test_check_shape_equal():
space2 = spaces.Dict({"key1": spaces.Box(low=-1, high=2, shape=(3, 3)), "key2": spaces.Box(low=-1, high=2, shape=(2, 2))})
with pytest.raises(AssertionError):
check_shape_equal(space1, space2)


def test_deprecated_schedules():
    """
    Check that every deprecated schedule helper emits a warning and returns
    the same values as its class-based replacement.
    """
    # One ``pytest.warns`` context per call: a shared context is satisfied
    # as soon as a single call warns, which would hide a missing warning.
    with pytest.warns(Warning):
        get_schedule_fn(0.1)
    with pytest.warns(Warning):
        get_schedule_fn(lambda _: 0.1)

    # Linear schedule: deprecated function vs. LinearSchedule vs. wrapped LinearSchedule
    with pytest.warns(Warning):
        linear_fn = get_linear_fn(1.0, 0.0, 0.1)
    linear_schedule = LinearSchedule(1.0, 0.0, 0.1)
    float_schedule = FloatSchedule(linear_schedule)
    assert np.allclose(linear_fn(0.95), 0.5)
    assert np.allclose(linear_fn(0.95), linear_schedule(0.95))
    assert np.allclose(linear_fn(0.95), float_schedule(0.95))
    assert np.allclose(linear_fn(0.9), 0.0)
    assert np.allclose(linear_fn(0.0), 0.0)
    assert np.allclose(linear_fn(0.9), linear_schedule(0.9))
    assert np.allclose(linear_fn(0.9), float_schedule(0.9))

    # Constant schedule: deprecated function vs. ConstantSchedule vs. FloatSchedule
    with pytest.warns(Warning):
        fn = constant_fn(1.0)
    schedule = ConstantSchedule(1.0)
    float_schedule = FloatSchedule(1.0)
    float_schedule_2 = FloatSchedule(float_schedule)
    # Wrapping a FloatSchedule must reuse the very same inner schedule object
    assert float_schedule_2.value_schedule is float_schedule.value_schedule
    assert np.allclose(fn(0.0), 1.0)
    assert np.allclose(fn(0.0), schedule(0.0))
    assert np.allclose(fn(0.0), float_schedule(0.0))
    assert np.allclose(fn(0.0), float_schedule_2(0.0))
    assert np.allclose(fn(0.5), 1.0)
    assert np.allclose(fn(0.5), schedule(0.5))
    assert np.allclose(fn(0.5), float_schedule(0.5))
    assert np.allclose(fn(0.5), float_schedule_2(0.5))