[Question] Importing trained model in SB3 to Matlab #2105

Closed
4 tasks done
mohanafathollahi opened this issue Mar 21, 2025 · 2 comments
Labels
question Further information is requested

Comments

@mohanafathollahi

❓ Question

Hi,

I trained a SAC model in SB3 using a custom Gym environment. After training, I wanted to import the model into MATLAB. I considered the following approaches:

  1. I clip actions in my custom Gym environment. Since ONNX does not support post-processing such as clipping, the ONNX output was not correct, and applying the clip to the ONNX output gave nonsensical results.
  2. SB3 uses PyTorch, and my model is not an image classification or segmentation model, so it is not supported by importNetworkFromPyTorch in MATLAB.
  3. I think TensorFlow models do not have this limitation in MATLAB. I tried switching to SB2 to get a TensorFlow model, but SB2 needs tensorflow==1.5, and apparently I would need to downgrade Python from 3.8 to an older version.
  4. Manually rebuilding the network from the trained weights and biases, which is time-consuming because I have 4 models.

Do you have any recommendations for making it easier to import a trained SB3 model into MATLAB?

Thank you

Checklist

@mohanafathollahi mohanafathollahi added the question Further information is requested label Mar 21, 2025
@araffin
Member

araffin commented Mar 22, 2025

> I used clip action in my custom gym environment. Since ONNX does not support post processing, like clipping, the output of onnx was not correct and applying clip on the generated onnx output gave nonsense result.

Can't ONNX trace torch.clamp()?
Also, no clipping is needed for SAC, since it squashes actions with tanh: https://araffin.github.io/post/sac-massive-sim/
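A minimal numpy illustration of why a separate clip is redundant for SAC, assuming the squashed-Gaussian policy that SB3's SAC uses: the deterministic action is tanh of the actor's mean output, so it already lies in [-1, 1] before it is rescaled to the action-space bounds. The raw values below are made up for illustration.

```python
import numpy as np

def squash(mean_actions: np.ndarray) -> np.ndarray:
    """Deterministic SAC-style action: tanh of the actor's mean output, bounded in [-1, 1]."""
    return np.tanh(mean_actions)

# Hypothetical pre-squash outputs, including very large magnitudes.
raw = np.array([-50.0, -1.5, 0.0, 0.3, 1e6])
actions = squash(raw)

# Clipping to [-1, 1] afterwards changes nothing, because tanh already bounds the output.
clipped = np.clip(actions, -1.0, 1.0)
print(np.array_equal(actions, clipped))  # True
```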

> SB3 using pytorch and my project is not image classification

The SAC actor is just fully connected layers, so I'm surprised it would not work. See the docs on tracing: https://stable-baselines3.readthedocs.io/en/master/guide/export.html#trace-export-to-c

https://de.mathworks.com/help/deeplearning/ref/importnetworkfrompytorch.html

> Manually building the network by optimum weights and biases, which is a bit time consuming because I have 4 models.

That's probably the easiest/fastest. Why 4 models? Don't you only need the actor? (And doing it once should work for any trained SAC model.)
Use print(model.policy) to see the architecture: https://stable-baselines3.readthedocs.io/en/master/guide/export.html#manual-export
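A minimal sketch of the first step of the manual route, assuming the policy weights have already been converted to numpy (for example with {k: v.cpu().numpy() for k, v in model.policy.state_dict().items()}) and assuming SB3's SAC parameter names look like actor.latent_pi.0.weight (check them against print(model.policy)). The helper to_matlab_vars is hypothetical: it keeps only the actor's tensors and turns the dotted names into valid MATLAB variable names.

```python
import re
import numpy as np

def to_matlab_vars(state_dict, prefix="actor."):
    """Keep only the actor's tensors and rename them to valid MATLAB variable names.

    `state_dict` is assumed to map names like "actor.latent_pi.0.weight" to numpy arrays.
    """
    out = {}
    for name, value in state_dict.items():
        if not name.startswith(prefix):
            continue  # skip critic / target-critic parameters, not needed at inference
        # MATLAB variable names may only contain letters, digits and underscores.
        var = re.sub(r"[^0-9A-Za-z_]", "_", name[len(prefix):])
        out[var] = np.asarray(value)
    return out

# Hypothetical stand-in for a converted state_dict.
fake = {
    "actor.latent_pi.0.weight": np.zeros((256, 3)),
    "actor.mu.bias": np.zeros(1),
    "critic.qf0.0.weight": np.zeros((256, 4)),
}
mats = to_matlab_vars(fake)
print(sorted(mats))  # ['latent_pi_0_weight', 'mu_bias']
```

The resulting dict could then be written with scipy.io.savemat("actor.mat", mats) and read in MATLAB with load. torch.nn.Linear stores weights as (out_features, in_features), so in MATLAB each layer becomes W*x + b with x a column vector.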

SAC actor (stable_baselines3/sac/policies.py):

class Actor(BasePolicy)

In JAX (more readable): https://github.com/araffin/sbx/blob/8238fccc19048340870e4869813835b8fb02e577/sbx/common/policies.py#L251-L262
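A numpy sketch of the deterministic forward pass that the linked actor code computes, assuming SB3's default SAC settings: ReLU hidden layers (latent_pi), a linear mean head (mu), then tanh squashing, with plain Box observations, no gSDE and no VecNormalize (observation normalization would need to be reproduced separately). The function name and the random weights are hypothetical; in practice the (W, b) pairs would come from the trained actor.

```python
import numpy as np

def sac_actor_forward(obs, hidden_layers, mu_layer):
    """Deterministic action of an MLP SAC actor, in [-1, 1].

    hidden_layers: list of (W, b) pairs, W shaped (out_features, in_features)
                   as stored by torch.nn.Linear; each layer is followed by ReLU.
    mu_layer: (W, b) for the linear mean head (no activation).
    """
    h = np.asarray(obs, dtype=np.float64)
    for W, b in hidden_layers:
        h = np.maximum(W @ h + b, 0.0)  # Linear + ReLU
    W_mu, b_mu = mu_layer
    mean_actions = W_mu @ h + b_mu
    return np.tanh(mean_actions)  # squash to [-1, 1]

# Hypothetical random weights: 3-dim observation, two hidden layers of 8 units, 1-dim action.
rng = np.random.default_rng(0)
hidden = [
    (rng.standard_normal((8, 3)), rng.standard_normal(8)),
    (rng.standard_normal((8, 8)), rng.standard_normal(8)),
]
mu = (rng.standard_normal((1, 8)), rng.standard_normal(1))
action = sac_actor_forward([0.1, -0.2, 0.5], hidden, mu)
print(action.shape)  # (1,)
```

The output is still in [-1, 1]; it has to be rescaled to the action-space bounds before it matches model.predict.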

@mohanafathollahi
Author

Thank you for your response.
I figured out that the predict function in common/policies.py performs additional steps, such as calling the unscale_action function.
This function rescales the actions produced by the model from [-1, 1] to [low, high], the lower and upper limits of the action space. After applying it to the output of the ONNX model, I was able to reproduce the output of model.predict.
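The rescaling step can be reproduced outside SB3 in one line. A minimal numpy sketch, assuming a Box action space with finite low and high; as far as I know the formula mirrors SB3's BasePolicy.unscale_action, and the bounds below are made up for illustration.

```python
import numpy as np

def unscale_action(scaled_action, low, high):
    """Map an action from [-1, 1] back to the Box bounds [low, high]."""
    return low + 0.5 * (scaled_action + 1.0) * (high - low)

# Hypothetical 2-D action space.
low = np.array([0.0, -2.0])
high = np.array([10.0, 2.0])
print(unscale_action(np.array([-1.0, 1.0]), low, high))  # [0. 2.]
print(unscale_action(np.array([0.0, 0.0]), low, high))   # [5. 0.]
```

-1 maps to low, 1 maps to high and 0 to the midpoint, so the same expression can be applied to the ONNX output in MATLAB.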

@araffin araffin closed this as completed Mar 24, 2025

2 participants