Use classes instead of lambdas for schedules #2125
…non-portable schedules

Previously, using closures (e.g., lambdas) for learning_rate or clip_range caused segmentation faults when loading models across different platforms (e.g., macOS to Linux), because cloudpickle could not safely serialize/deserialize them.

This commit rewrites:
- `constant_fn` as a `ConstantSchedule` class
- `get_schedule_fn` as a `FloatConverterSchedule` class
- `get_linear_fn` as a `LinearSchedule` class

All schedules are now proper callable classes, making them portable and safely pickleable. Old functions are kept (marked as deprecated) for backward compatibility when loading existing models.
@@ -273,7 +273,7 @@ def logger(self) -> Logger:

     def _setup_lr_schedule(self) -> None:
         """Transform to callable if needed."""
-        self.lr_schedule = get_schedule_fn(self.learning_rate)
+        self.lr_schedule = FloatConverterSchedule(self.learning_rate)
I would rather have a new `get_schedule` helper here, because `FloatConverterSchedule` does more than what its name suggests.
PS: you don't have to force push for every edit, you can simply push new commits to the same branch
or maybe rename `FloatConverterSchedule` to something else (the issue is that the `Schedule` type is already taken)... maybe `FloatSchedule`?
I have renamed it to `FloatSchedule`, as recommended, but maybe `ScheduleWrapper` could be a better name, since it is basically a wrapper that ensures that a constant is transformed to a `Schedule`, and that any callable returns float values.
I wanted to change the original logic as little as possible, but maybe the original function is not clean enough, since it is doing multiple things.
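The wrapper behaviour discussed in this thread (turn a constant into a schedule, and make sure any callable returns a float) could look roughly like the sketch below. This is an illustration only, not the code from the PR: the attribute name `value_schedule` and the helper `_Constant` are hypothetical names chosen for the example.

```python
from typing import Callable, Union


class _Constant:
    """Hypothetical helper: a picklable callable that ignores its input."""

    def __init__(self, val: float) -> None:
        self.val = val

    def __call__(self, progress_remaining: float) -> float:
        return self.val


class FloatSchedule:
    """Sketch of a wrapper that accepts a constant or a callable
    and always returns a Python float when called."""

    def __init__(self, value_schedule: Union[Callable[[float], float], float, int]) -> None:
        if isinstance(value_schedule, FloatSchedule):
            # Avoid wrapping a wrapper: reuse the inner schedule.
            self.value_schedule = value_schedule.value_schedule
        elif isinstance(value_schedule, (float, int)):
            # A constant becomes a schedule object instead of a lambda.
            self.value_schedule = _Constant(float(value_schedule))
        else:
            if not callable(value_schedule):
                raise TypeError("value_schedule must be a number or a callable")
            self.value_schedule = value_schedule

    def __call__(self, progress_remaining: float) -> float:
        # Cast so that e.g. integer or array-scalar outputs become floats.
        return float(self.value_schedule(progress_remaining))


lr_schedule = FloatSchedule(3e-4)
clip_schedule = FloatSchedule(lambda progress: 1)  # callable returning an int
```

Note that in this sketch the wrapper only guarantees a float return value; if the wrapped callable is itself a lambda, the resulting object still carries that lambda, so portability depends on passing a class-based schedule or a constant.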
stable_baselines3/common/utils.py
@@ -78,6 +78,35 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) ->
         param_group["lr"] = learning_rate


+class FloatConverterSchedule:
This class actually does more, no? It enforces that we have a schedule object that can be pickled, no?
Maybe rename it to `FloatSchedule` and update the docstring to explain its utility (in addition to casting to float)
same comment as above
thanks for the PR =)
Looks good overall, some minor comments only. I'll try to test it myself later.
PS: we would need to update SB3 contrib too after
- Renamed FloatConverterSchedule to FloatSchedule to better reflect its purpose.
- Moved parameter documentation to the class-level docstring for proper Sphinx support
LGTM, thanks =)
Once it is merged, could you also open a PR for SB3 contrib and the RL Zoo?
Thanks, I can send a pull request. I guess a commit is not enough: a stable-baselines3 2.6.1 release is required to update the dependencies: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/setup.py#L70, and after that, I can update all
I've released an alpha that you can use:
Description
Refactor schedule-related helper functions into proper classes. This ensures full portability and prevents segfaults when loading models across different operating systems. Introduces `ConstantSchedule`, `CappedLinearSchedule`, and `FloatConverterSchedule` for supporting portability across different operating systems.

This commit rewrites:
- `constant_fn` as a `ConstantSchedule` class
- `get_schedule_fn` as a `FloatConverterSchedule` class
- `get_linear_fn` as a `LinearSchedule` class

All schedules are now proper callable classes, making them portable and safely pickleable. Old functions are kept (marked as deprecated) for backward compatibility when loading existing models.
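As a rough illustration of the "callable class instead of closure" idea described above, a constant and a linear schedule might be written as below. This is a minimal sketch and not the code merged in the PR: the constructor parameters (`val`, `start`, `end`, `end_fraction`) are assumptions for the example, and `end_fraction` is assumed to be strictly positive.

```python
class ConstantSchedule:
    """Sketch: always returns the same value, whatever the progress."""

    def __init__(self, val: float) -> None:
        self.val = val

    def __call__(self, progress_remaining: float) -> float:
        return self.val

    def __repr__(self) -> str:
        return f"ConstantSchedule(val={self.val})"


class LinearSchedule:
    """Sketch: interpolate linearly from `start` to `end`, then stay at `end`.

    `progress_remaining` is assumed to go from 1.0 (start of training)
    down to 0.0 (end of training). `end_fraction` is the fraction of
    training after which the value stays at `end` (assumed > 0).
    """

    def __init__(self, start: float, end: float, end_fraction: float) -> None:
        self.start = start
        self.end = end
        self.end_fraction = end_fraction

    def __call__(self, progress_remaining: float) -> float:
        elapsed = 1.0 - progress_remaining
        if elapsed > self.end_fraction:
            return self.end
        return self.start + elapsed * (self.end - self.start) / self.end_fraction

    def __repr__(self) -> str:
        return (
            f"LinearSchedule(start={self.start}, end={self.end}, "
            f"end_fraction={self.end_fraction})"
        )


constant = ConstantSchedule(0.2)
linear = LinearSchedule(start=1.0, end=0.0, end_fraction=0.5)
```

Because the state lives in plain instance attributes rather than in a closure, such objects can be serialized by reference to their class plus their attribute values, which is the property the PR relies on for portability.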
Motivation and Context
Fixes cross-platform segmentation faults (#2115) caused by non-portable closures (like lambdas) being serialized into model files. After this change, saved models are robust, portable, and no longer crash at load time if they are moved across different operating systems.
Types of changes
Checklist
- `make format` (required)
- `make check-codestyle` and `make lint` (required)
- `make pytest` and `make type` both pass. (required)
- `make doc` (required)

Note: You can run most of the checks using `make commit-checks`.

Note: we are using a maximum length of 127 characters per line