Skip to content

Use classes instead of lambdas for schedules #2125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 14, 2025

Conversation

akanto
Copy link
Contributor

@akanto akanto commented Apr 27, 2025

…rtable schedules

Previously, using closures (e.g., lambdas) for learning_rate or clip_range caused segmentation faults when loading models across different platforms (e.g., macOS to Linux), because cloudpickle could not safely serialize/deserialize them.

Description

Refactor schedule-related helper functions into proper classes. This ensures full portability and prevents segfaults when
loading models across different operating systems. Introduces ConstantSchedule, CappedLinearSchedule, and
FloatConverterSchedule for supporting portability across different operating systems.

This commit rewrites:

  • constant_fn as a ConstantSchedule class
  • get_schedule_fn as a FloatConverterSchedule class
  • get_linear_fn as a LinearSchedule class

All schedules are now proper callable classes, making them portable and safely pickleable. Old functions are kept (marked as deprecated) for backward compatibility when loading existing models.

Motivation and Context

Fixes cross-platform segmentation faults (#2115) caused by non-portable closures (like lambdas) being serialized into model files. After this change, saved models are robust, portable, and no longer crash at load time if they are moved across different operating systems.

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have opened an associated PR on the SB3-Contrib repository (if necessary)
  • I have opened an associated PR on the RL-Zoo3 repository (if necessary)
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

…non-portable schedules

Previously, using closures (e.g., lambdas) for learning_rate or clip_range caused
segmentation faults when loading models across different platforms (e.g., macOS to Linux), because cloudpickle could not safely serialize/deserialize them.

This commit rewrites:
- `constant_fn` as a `ConstantSchedule` class
- `get_schedule_fn` as a `FloatConverterSchedule` class
- `get_linear_fn` as a `LinearSchedule` class

All schedules are now proper callable classes, making them portable and safely pickleable.
Old functions are kept (marked as  deprecated) for backward compatibility when loading
existing models.
@akanto akanto force-pushed the save-load-portability branch from 63cfb2e to a6d8c07 Compare April 27, 2025 16:01
@araffin araffin changed the title Fixes #2115. Avoid segmentation fault when loading models with non-po… Use classes instead of lambas for schedules Apr 28, 2025
@araffin araffin self-requested a review May 5, 2025 08:06
@@ -273,7 +273,7 @@ def logger(self) -> Logger:

def _setup_lr_schedule(self) -> None:
"""Transform to callable if needed."""
self.lr_schedule = get_schedule_fn(self.learning_rate)
self.lr_schedule = FloatConverterSchedule(self.learning_rate)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather have a new get_schedule helper here because FloatConverterSchedule does more than what its name suggest

PS: you don't have to force push for every edit, you can simply push new commits to the same branch

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or maybe rename FloatConverterSchedule to something else (the issue is that Schedule type is already taken... maybe FloatSchedule?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have renamed it to FloatSchedule, as recommended, but maybe a ScheduleWrapper could be a better name, since it is basically a wrapper that ensures that a constant is transformed to Schedule, and ensures that any callable returning float values.

I wanted to change the original logic as minimal as possible, but maybe the original function is not clean enough, since it is doing multiple things.

@@ -78,6 +78,35 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) ->
param_group["lr"] = learning_rate


class FloatConverterSchedule:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class actually does more no? it enforces that we have a schedule object that can be pickled, no?
Maybe rename it to FloatSchedule and update the docstring to explain its utility (in addition to casting to float)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as above

Copy link
Member

@araffin araffin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the PR =)
Looks good overall, some minor comments only. I'll try to test it myself later.

PS: we would need to update SB3 contrib too after

akanto and others added 2 commits May 10, 2025 14:58
- Renamed FloatConverterSchedule to FloatSchedule to better reflect its purpose.
- Moved parameter documentation to the class-level docstring for proper Sphinx support
@araffin araffin self-requested a review May 12, 2025 18:46
@araffin araffin changed the title Use classes instead of lambas for schedules Use classes instead of lambdas for schedules May 14, 2025
Copy link
Member

@araffin araffin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks =)

Once it is merged, could you also open a PR for SB3 contrib and the RL Zoo?

@araffin araffin merged commit f9c4ca5 into DLR-RM:master May 14, 2025
4 checks passed
@akanto
Copy link
Contributor Author

akanto commented May 15, 2025

Thanks, I can send a pull request. I guess a commit is not enough, but the stable-baslines3 release of 2.6.1 is required to update the dependencies: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/setup.py#L70, and after that, I can update all get_schedule_fn etc. functions there.

@araffin
Copy link
Member

araffin commented May 15, 2025

I've released an alpha that you can use:
https://pypi.org/project/stable-baselines3/2.6.1a1/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants