[Bug]: loading saved model: Missing key(s) in state_dict: "q_net.q_net.0.weight" #2118

Closed
5 tasks done
JonDum opened this issue Apr 12, 2025 · 4 comments
Labels
bug Something isn't working

Comments


JonDum commented Apr 12, 2025

🐛 Bug

Hello, I hope you're having a fine day.

I seem to be running into an error while attempting to load from checkpoints.

I've tried diagnosing it myself: I reduced the configuration down to the bare defaults to rule out my custom policy_kwargs, and I even cleared out my pip/uv cache and reinstalled the dependencies from scratch... but I'm still getting the error, so I'm a bit lost now. If anyone has pointers as to what it could be, I'd be very appreciative.

To Reproduce

import dotenv
import rich.traceback as rich
import os
import wandb
import torch
import git
from os import environ
from datetime import datetime
from stable_baselines3 import PPO, DQN
from tap import Tap
from typing import Literal


from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from typing import Optional


dotenv.load_dotenv()
rich.install(show_locals=True)


class ArgumentParser(Tap):
    model: Literal["DQN", "PPO"] = "DQN"
    """Which model to use"""
    device: Literal["cpu", "cuda"] = "cpu"
    """Run on cpu only or use cuda"""
    total_timesteps: int = int(environ.get("STOP_ITERS", 100_000_000))
    """Number of steps to sample before stopping."""
    num_envs: int = int(environ.get("NUM_ENVS", 2))
    """How many parallel envs to sample from"""
    tensorboard_log_dir: str = str(environ.get("TENSORBOARD_LOG_DIR", "logs"))
    """ TB log dir """
    checkpoint_freq: int = int(environ.get("CHECKPOINT_FREQ", 250_000))
    """ Save checkpoint every N total steps """
    checkpoint_dir: str = str(environ.get("CHECKPOINT_LOG_DIR", "checkpoints"))
    """ Where to store checkpoints """
    load_checkpoint: Optional[str] = None
    """ Load from previous checkpoint before training """


if __name__ == "__main__":
    args = ArgumentParser().parse_args()

    env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)  # note: hard-coded, ignores args.num_envs

    if args.model == "DQN":

        if args.load_checkpoint:
            model = DQN.load(args.load_checkpoint, env=env,
                             print_system_info=True)
        else:
            policy_kwargs = dict(
                activation_fn=torch.nn.LeakyReLU,
                net_arch=[1024, 1024],
            )
            model = DQN(
                "MlpPolicy",
                env=env,
                verbose=1,
                tensorboard_log=args.tensorboard_log_dir,
                policy_kwargs=policy_kwargs, # I've tried with and without any custom args
                device=args.device,
                batch_size=256,
                # start at only 0.2 epsilon
                exploration_initial_eps=0.2,
                # Explore for 50% of training (default is 0.1)
                exploration_fraction=0.5,
            )

    repo = git.Repo(search_parent_directories=True)
    sha = repo.head.object.hexsha[:8]
    exp_name = f"{args.model}-{sha}-{datetime.now().isoformat()}"
    # Save a checkpoint every checkpoint_freq steps (default 250_000)
    checkpoint_callback = CheckpointCallback(
        save_freq=args.checkpoint_freq,
        save_path="./checkpoints/",
        name_prefix=exp_name,
        verbose=True
    )

    callbacks = [checkpoint_callback]

    # Compile the policy network
    model.policy = torch.compile(model.policy)

    with Timer():  # Timer: a simple timing context manager (definition not shown)
        model.learn(
            total_timesteps=args.total_timesteps,
            callback=CallbackList(callbacks),
            tb_log_name=exp_name,
            progress_bar=True
        )
        path = f"{args.checkpoint_dir}/{exp_name}_{args.total_timesteps}_steps.zip"
        model.save(path)
        print(f"Model saved to {path}")

    env.close()  # Close the environment

Relevant log output / Error message

RuntimeError: Error(s) in loading state_dict for DQNPolicy:
        Missing key(s) in state_dict: "q_net.q_net.0.weight", "q_net.q_net.0.bias", "q_net.q_net.2.weight", "q_net.q_net.2.bias", "q_net.q_net.4.weight", "q_net.q_net.4.bias",
"q_net_target.q_net.0.weight", "q_net_target.q_net.0.bias", "q_net_target.q_net.2.weight", "q_net_target.q_net.2.bias", "q_net_target.q_net.4.weight", "q_net_target.q_net.4.bias".
        Unexpected key(s) in state_dict: "_orig_mod.q_net.q_net.0.weight", "_orig_mod.q_net.q_net.0.bias", "_orig_mod.q_net.q_net.2.weight", "_orig_mod.q_net.q_net.2.bias",
"_orig_mod.q_net.q_net.4.weight", "_orig_mod.q_net.q_net.4.bias", "_orig_mod.q_net_target.q_net.0.weight", "_orig_mod.q_net_target.q_net.0.bias", "_orig_mod.q_net_target.q_net.2.weight",
"_orig_mod.q_net_target.q_net.2.bias", "_orig_mod.q_net_target.q_net.4.weight", "_orig_mod.q_net_target.q_net.4.bias".

System Info

Installed using uv

pyproject.toml

	"dotenv>=0.9.9",
	"gitpython>=3.1.44",
	"gputil>=1.4.0",
	"gymnasium>=1.1.1",
	"hydra-core>=1.3.2",
	"msgpack>=1.1.0",
	"rich>=14.0.0",
	"stable-baselines3>=2.6.0",
	"tensorboard>=2.19.0",
	"torchrl>=0.7.2",
	"tqdm>=4.67.1",
	"typed-argument-parser>=1.10.1",
	"wandb>=0.19.9",

system info

- OS: Linux-6.8.0-57-generic-x86_64-with-glibc2.39 #59-Ubuntu SMP PREEMPT_DYNAMIC Sat Mar 15 17:40:59 UTC 2025
- Python: 3.12.3
- Stable-Baselines3: 2.6.0
- PyTorch: 2.6.0+cu124
- GPU Enabled: True
- Numpy: 2.2.4
- Cloudpickle: 3.1.1
- Gymnasium: 1.1.1

Checklist

  • My issue does not relate to a custom gym environment. (Use the custom gym env template instead)
  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • I have provided a minimal and working example to reproduce the bug
  • I've used the markdown code blocks for both code and stack traces.
@JonDum JonDum added the bug Something isn't working label Apr 12, 2025
araffin (Member) commented Apr 13, 2025

Hello,
the compile line (model.policy = torch.compile(model.policy)) is your problem.
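A minimal sketch of why this fails, using plain dicts in place of real tensors (the key names are taken from the error message above): torch.compile wraps the policy in an OptimizedModule, so every key in the saved state_dict gains a "_orig_mod." prefix that the freshly constructed, uncompiled DQNPolicy does not expect.

```python
# Keys as saved from the compiled policy (see the traceback above)
saved_keys = {"_orig_mod.q_net.q_net.0.weight", "_orig_mod.q_net.q_net.0.bias"}
# Keys the freshly built, uncompiled DQNPolicy expects
expected_keys = {"q_net.q_net.0.weight", "q_net.q_net.0.bias"}

missing = expected_keys - saved_keys      # reported as "Missing key(s)"
unexpected = saved_keys - expected_keys   # reported as "Unexpected key(s)"
print(sorted(missing))
print(sorted(unexpected))
```

load_state_dict then reports exactly these two sets: the unprefixed names as missing and the prefixed ones as unexpected.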

JonDum (Author) commented Apr 13, 2025

Hi @araffin, thank you! That was indeed the culprit. For my reference, and for anyone stumbling upon this from Google: is there a recommended way of using compile with SB3 while saving/loading models? Or is it just not compatible at all?

From a little bit of research, it seems there are ways of still using compile by rewriting the state dicts back to their uncompiled key names when checkpointing 🤔
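That kind of rewrite can be sketched as a plain key rename; strip_compile_prefix below is a hypothetical helper name, not an SB3 or PyTorch API:

```python
def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Return a copy of state_dict with torch.compile's prefix removed,
    so the keys match an uncompiled policy again."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Example with key names from the error message above (values are dummies)
saved = {
    "_orig_mod.q_net.q_net.0.weight": "w0",
    "_orig_mod.q_net.q_net.0.bias": "b0",
}
print(sorted(strip_compile_prefix(saved)))
```

PyTorch also ships a utility for this pattern, torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "_orig_mod."), which removes the prefix in place.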

araffin (Member) commented Apr 14, 2025

is there a recommended way of using compile with SB3 w/ saving/loading models? Or is it just not compatible at all?

I haven't had time so far to solve that problem properly.
If you are looking for a speed boost, you can have a look at SBX: https://github.com/araffin/sbx (fewer features, but faster thanks to Jax)

From a little bit of research it seems like there are ways of still using compile by rewriting the network dicts back to the uncompiled versions when checkpointing 🤔

Could you link those resources here? They might be helpful for others.

@araffin araffin closed this as completed May 9, 2025
araffin (Member) commented May 14, 2025

see #2137 for a fix
related to #1438
