Skip to content

Commit f5104a5

Browse files
leor-c and araffin authored
Allow to set a device when loading a model (#154)
* Added a 'device' keyword argument to BaseAlgorithm.load(). Edited the save and load test to also test the load method with all possible devices. Added the changes to the changelog.
* Improved the load test to ensure that the model loads to the correct device.
* Improved the test: it is now more correct. If the get_device policy changes, the test will not break.
* Update tests/test_save_load.py (@araffin's suggestion during the PR process). Co-authored-by: Antonin RAFFIN <[email protected]>
* Update tests/test_save_load.py. Co-authored-by: Antonin RAFFIN <[email protected]>
* Bug fixes: when comparing devices, compare only the device type, since get_device() doesn't provide the device index. The code now loads all of the model parameters from the saved state dict straight into the required device (fixed load_from_zip_file).
* PR fixes: bug fix - an unrelated test failed when running on GPU. Updated the assertion to consider only the device types. Also corrected a related bug in the 'get_device()' method.
* Update changelog.rst. Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 583d4b8 commit f5104a5

File tree

6 files changed

+38
-18
lines changed

6 files changed

+38
-18
lines changed

docs/misc/changelog.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ New Features:
1414
^^^^^^^^^^^^^
1515
- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed
1616
- Added ``StopTrainingOnMaxEpisodes`` to callback collection (@xicocaio)
17+
- Added ``device`` keyword argument to ``BaseAlgorithm.load()`` (@liorcohen5)
1718
- Callbacks have access to rollout collection locals as in SB2. (@PartiallyTyped)
1819

1920
Bug Fixes:
2021
^^^^^^^^^^
2122
- Fixed a bug where the environment was reset twice when using ``evaluate_policy``
2223
- Fix logging of ``clip_fraction`` in PPO (@diditforlulz273)
24+
- Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., ``device="cuda:0"`` (@liorcohen5)
2325

2426
Deprecations:
2527
^^^^^^^^^^^^^
@@ -399,4 +401,4 @@ And all the contributors:
399401
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
400402
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
401403
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
402-
@diditforlulz273
404+
@diditforlulz273 @liorcohen5

stable_baselines3/common/base_class.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,16 +316,19 @@ def predict(
316316
return self.policy.predict(observation, state, mask, deterministic)
317317

318318
@classmethod
319-
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAlgorithm":
319+
def load(
320+
cls, load_path: str, env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", **kwargs
321+
) -> "BaseAlgorithm":
320322
"""
321323
Load the model from a zip-file
322324
323325
:param load_path: the location of the saved data
324326
:param env: the new environment to run the loaded model on
325327
(can be None if you only need prediction from a trained model) has priority over any saved environment
328+
:param device: (Union[th.device, str]) Device on which the code should run.
326329
:param kwargs: extra arguments to change the model when loading
327330
"""
328-
data, params, tensors = load_from_zip_file(load_path)
331+
data, params, tensors = load_from_zip_file(load_path, device=device)
329332

330333
if "policy_kwargs" in data:
331334
for arg_to_remove in ["device"]:
@@ -352,7 +355,7 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAl
352355
model = cls(
353356
policy=data["policy_class"],
354357
env=env,
355-
device="auto",
358+
device=device,
356359
_init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
357360
)
358361

stable_baselines3/common/save_util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0)
352352
def load_from_zip_file(
353353
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
354354
load_data: bool = True,
355+
device: Union[th.device, str] = "auto",
355356
verbose=0,
356357
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
357358
"""
@@ -360,13 +361,14 @@ def load_from_zip_file(
360361
:param load_path: (str, pathlib.Path, io.BufferedIOBase) Where to load the model from
361362
:param load_data: Whether we should load and return data
362363
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
364+
:param device: (Union[th.device, str]) Device on which the code should run.
363365
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
364366
and dict of extra tensors
365367
"""
366368
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
367369

368370
# set device to cpu if cuda is not available
369-
device = get_device()
371+
device = get_device(device=device)
370372

371373
# Open the zip archive and load data
372374
try:

stable_baselines3/common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
145145
device = th.device(device)
146146

147147
# Cuda not available
148-
if device == th.device("cuda") and not th.cuda.is_available():
148+
if device.type == th.device("cuda").type and not th.cuda.is_available():
149149
return th.device("cpu")
150150

151151
return device

tests/test_predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_predict(model_class, env_id, device):
4646
# Test detection of different shapes by the predict method
4747
model = model_class("MlpPolicy", env_id, device=device)
4848
# Check that the policy is on the right device
49-
assert get_device(device) == model.policy.device
49+
assert get_device(device).type == model.policy.device.type
5050

5151
env = gym.make(env_id)
5252
vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])

tests/test_save_load.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from stable_baselines3.common.base_class import BaseAlgorithm
1414
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
1515
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
16+
from stable_baselines3.common.utils import get_device
1617
from stable_baselines3.common.vec_env import DummyVecEnv
1718

1819
MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
@@ -70,21 +71,33 @@ def test_save_load(tmp_path, model_class):
7071
# Check
7172
model.save(tmp_path / "test_save.zip")
7273
del model
73-
model = model_class.load(str(tmp_path / "test_save.zip"), env=env)
7474

75-
# check if params are still the same after load
76-
new_params = model.policy.state_dict()
75+
# Check if the model loads as expected for every possible choice of device:
76+
for device in ["auto", "cpu", "cuda"]:
77+
model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device)
7778

78-
# Check that all params are the same as before save load procedure now
79-
for key in params:
80-
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
79+
# check if the model was loaded to the correct device
80+
assert model.device.type == get_device(device).type
81+
assert model.policy.device.type == get_device(device).type
8182

82-
# check if model still selects the same actions
83-
new_selected_actions, _ = model.predict(observations, deterministic=True)
84-
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
83+
# check if params are still the same after load
84+
new_params = model.policy.state_dict()
8585

86-
# check if learn still works
87-
model.learn(total_timesteps=1000, eval_freq=500)
86+
# Check that all params are the same as before save load procedure now
87+
for key in params:
88+
assert new_params[key].device.type == get_device(device).type
89+
assert th.allclose(
90+
params[key].to("cpu"), new_params[key].to("cpu")
91+
), "Model parameters not the same after save and load."
92+
93+
# check if model still selects the same actions
94+
new_selected_actions, _ = model.predict(observations, deterministic=True)
95+
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
96+
97+
# check if learn still works
98+
model.learn(total_timesteps=1000, eval_freq=500)
99+
100+
del model
88101

89102
# clear file from os
90103
os.remove(tmp_path / "test_save.zip")

0 commit comments

Comments
 (0)