Update documentation #199
Merged
14 commits
e134a67 Update doc and add new example (araffin)
57724dd Add save/load replay buffer example (araffin)
4e7155a Add save format + export doc (araffin)
844b913 Add example for get/set parameters (araffin)
3573abc Merge branch 'master' into doc/additional-doc (araffin)
d613f18 Typos and minor edits (araffin)
8778dbc Add results sections (araffin)
4e7578e Add note about performance (araffin)
11d26ab Add DDPG results (araffin)
7338771 Merge branch 'master' into doc/additional-doc (araffin)
bf52c60 Merge branch 'master' into doc/additional-doc (araffin)
47baaca Address comments (araffin)
9013f97 Merge branch 'doc/additional-doc' of github.com:DLR-RM/stable-baselin… (araffin)
f11db60 Fix grammar/wording (Miffyli)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
.. _export: | ||
|
||
|
||
Exporting models | ||
================ | ||
|
||
After training an agent, you may want to deploy/use it in another language | ||
or framework, like `tensorflowjs <https://github.com/tensorflow/tfjs>`_. | ||
Stable Baselines3 does not include tools to export models to other frameworks, but | ||
this document covers the parts required for exporting, along with | ||
more detailed accounts from users of Stable Baselines3. | ||
|
||
|
||
Background | ||
---------- | ||
|
||
In Stable Baselines3, the controller is stored inside policies which convert | ||
observations into actions. Each learning algorithm (e.g. DQN, A2C, SAC) | ||
contains a policy object which represents the currently learned behavior, | ||
accessible via ``model.policy``. | ||
|
||
Policies hold enough information to do inference (i.e. predict actions), | ||
so it is enough to export these policies (cf. :ref:`examples <examples>`) | ||
to do inference in another framework. | ||
|
||
.. warning:: | ||
  When using CNN policies, the observation is normalized during pre-processing. | ||
  This pre-processing is done *inside* the policy (dividing by 255 to have values in [0, 1]). | ||
|
||
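A re-implementation of a CNN policy in another framework has to reproduce this scaling itself. The following is only a minimal sketch in plain Python; the function name ``normalize_image`` is hypothetical and not part of Stable Baselines3.

```python
def normalize_image(pixels):
    """Scale raw 8-bit pixel values from [0, 255] to [0, 1]."""
    return [value / 255.0 for value in pixels]


# Example 8-bit pixel values (placeholder data, not from a real environment).
raw_observation = [0, 51, 255]
scaled = normalize_image(raw_observation)
print(scaled)
```

If the exported network already includes this division, the caller should pass raw pixel values; applying the scaling twice would silently degrade the predictions.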
|
||
Export to ONNX | ||
----------------- | ||
|
||
TODO: help is welcome! | ||
|
||
|
||
Export to C++ | ||
----------------- | ||
|
||
TODO (using PyTorch JIT): help is welcome! | ||
|
||
|
||
Export to tensorflowjs / ONNX-JS | ||
-------------------------------- | ||
|
||
TODO: help from contributors is welcome! | ||
A possible starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js | ||
|
||
|
||
|
||
Manual export | ||
------------- | ||
|
||
You can also manually export required parameters (weights) and construct the | ||
network in your desired framework. | ||
|
||
You can access the parameters of the model via the agent's | ||
:func:`get_parameters <stable_baselines3.common.base_class.BaseAlgorithm.get_parameters>` method. | ||
As policies are also PyTorch modules, you can also access ``model.policy.state_dict()`` directly. | ||
To find the architecture of the networks for each algorithm, it is best to check the ``policies.py`` file located | ||
in the algorithm's folder. | ||
|
||
.. note:: | ||
|
||
  In most cases, we recommend using the PyTorch methods ``state_dict()`` and ``load_state_dict()`` from the policy, | ||
  unless you need to access the optimizers' state dict too. In that case, you need to call ``get_parameters()``. |
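
To illustrate the manual route, here is a rough sketch that rebuilds a forward pass from exported weights using only the Python standard library. It assumes the weights were first dumped to JSON, for instance by calling ``.tolist()`` on each tensor of ``model.policy.state_dict()``. The key names (``layer0.weight`` etc.), the numbers, and the two-layer tanh architecture are hypothetical placeholders; the real layout has to be read from the algorithm's ``policies.py``.

```python
import json
import math

# Stand-in for weights exported from a policy. Key names, values and the
# architecture are hypothetical; a real export would come from the tensors
# in ``model.policy.state_dict()`` converted with ``.tolist()``.
exported = json.dumps({
    "layer0.weight": [[0.5, -0.5], [1.0, 1.0]],
    "layer0.bias": [0.0, 0.0],
    "layer1.weight": [[1.0, -1.0]],
    "layer1.bias": [0.5],
})


def linear(weight, bias, inputs):
    """Compute weight @ inputs + bias; PyTorch stores weight as (out, in)."""
    return [
        sum(w * x for w, x in zip(row, inputs)) + b
        for row, b in zip(weight, bias)
    ]


def forward(params, observation):
    """Two-layer network: linear -> tanh -> linear."""
    pre_activation = linear(params["layer0.weight"], params["layer0.bias"], observation)
    hidden = [math.tanh(h) for h in pre_activation]
    return linear(params["layer1.weight"], params["layer1.bias"], hidden)


params = json.loads(exported)
action = forward(params, [1.0, 1.0])
print(action)
```

JSON is used here only because it can be read from almost any language; a binary format may be preferable for large networks.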
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
.. _save_format: | ||
|
||
|
||
On saving and loading | ||
===================== | ||
|
||
Stable Baselines3 (SB3) stores both neural network parameters and algorithm-related parameters such as | ||
exploration schedule, number of environments and observation/action space. This allows continual learning and easy | ||
use of trained agents without retraining, but it is not without its issues. The following describes the format | ||
used to save agents in SB3 along with its advantages and shortcomings. | ||
|
||
Terminology used in this page: | ||
|
||
- *parameters* refers to neural network parameters (also called "weights"). This is a dictionary | ||
  mapping variable names to PyTorch tensors. | ||
- *data* refers to RL algorithm parameters, e.g. learning rate, exploration schedule, action/observation space. | ||
  These depend on the algorithm used. This is a dictionary mapping class variable names to their values. | ||
|
||
|
||
Zip-archive | ||
----------- | ||
|
||
The format is a zip archive containing a JSON dump, PyTorch state dictionaries and PyTorch variables. The data dictionary (class parameters) | ||
is stored as a JSON file, while model parameters and optimizers are serialized with the ``torch.save()`` function. All of these files | ||
are stored in a single .zip archive. | ||
|
||
Any objects that are not JSON serializable are serialized with cloudpickle and stored as a base64-encoded | ||
string in the JSON file, along with some information about the serialized object. This allows | ||
inspecting stored objects without deserializing the object itself. | ||
|
||
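The idea of falling back to a base64-encoded pickle can be sketched as follows. This is not SB3's actual implementation: it uses the standard-library ``pickle`` in place of cloudpickle so that it runs without extra dependencies, and the helper names and the ``type``/``serialized`` keys are hypothetical.

```python
import base64
import json
import pickle


def to_json_safe(data):
    """Dump a dict to JSON, pickling any value that JSON cannot represent."""
    encoded = {}
    for key, value in data.items():
        try:
            json.dumps(value)
            encoded[key] = value
        except TypeError:
            blob = base64.b64encode(pickle.dumps(value)).decode("ascii")
            # Hypothetical key names. The readable "type" field can be
            # inspected without deserializing the blob.
            encoded[key] = {"type": str(type(value)), "serialized": blob}
    return json.dumps(encoded)


def from_json_safe(text):
    """Inverse of to_json_safe: unpickle the base64-encoded entries."""
    decoded = {}
    for key, value in json.loads(text).items():
        if isinstance(value, dict) and "serialized" in value:
            value = pickle.loads(base64.b64decode(value["serialized"]))
        decoded[key] = value
    return decoded


# A set is not JSON serializable, so it takes the pickle path.
data = {"learning_rate": 0.001, "allowed_actions": {1, 2, 3}}
dumped = to_json_safe(data)
restored = from_json_safe(dumped)
print(json.loads(dumped)["allowed_actions"]["type"])
```

Because each entry is encoded separately, a loader can skip or replace a single entry that fails to deserialize instead of rejecting the whole file.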
This format allows skipping elements in the file, i.e. we can skip deserializing objects that are | ||
broken/non-serializable. | ||
|
||
.. This can be done via ``custom_objects`` argument to load functions. | ||
|
||
|
||
File structure: | ||
|
||
:: | ||
|
||
    saved_model.zip/ | ||
    ├── data                             JSON file of class parameters (dictionary) | ||
    ├── *.optimizer.pth                  Serialized PyTorch optimizers | ||
    ├── policy.pth                       PyTorch state dictionary of the saved policy | ||
    ├── pytorch_variables.pth            Additional PyTorch variables | ||
    └── _stable_baselines3_version       SB3 version with which the model was saved | ||
|
||
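Since the archive is a regular zip file, its JSON and text entries can be read without PyTorch or Stable Baselines3. The sketch below builds a small mock archive in memory so that it is self-contained; the entry names follow the file structure above, while the contents (``learning_rate``, ``n_envs``, the version string) are placeholders rather than output from a real saved model.

```python
import io
import json
import zipfile

# Build a mock archive in memory. Entry names follow the structure above;
# the contents are placeholders, not output from a real saved model.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("data", json.dumps({"learning_rate": 0.0003, "n_envs": 1}))
    archive.writestr("_stable_baselines3_version", "0.0.0")

# Inspect the archive using only the standard library.
with zipfile.ZipFile(buffer) as archive:
    entry_names = archive.namelist()
    data = json.loads(archive.read("data").decode("utf-8"))
    version = archive.read("_stable_baselines3_version").decode("utf-8")

print(entry_names)
print(data["learning_rate"], version)
```

For a real model, ``buffer`` would be replaced by the path to the saved ``.zip`` file; the ``.pth`` entries still need ``torch.load()`` to be read.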
|
||
Pros: | ||
|
||
- More robust to unserializable objects (one bad object does not break everything). | ||
- Saved files can be inspected/extracted with zip-archive explorers and by other languages. | ||
|
||
|
||
Cons: | ||
|
||
- More complex implementation. | ||
- Still relies partly on cloudpickle for complex objects (e.g. custom functions), | ||
  which can lead to `incompatibilities <https://github.com/DLR-RM/stable-baselines3/issues/172>`_ between Python versions. |