
[Bug]: Higher memory usage on sequential training runs #1966

Open
@NickLucche

Description


🐛 Bug

Hey,
thanks a lot for your work!
I am trying to debug an apparent memory leak (or at least higher memory usage) that shows up when I run the training code multiple times in a row, but I can't pinpoint its cause.
I've boiled my problem down to the snippet below. When I start sequential training runs, memory consumption ends up higher than for a single run of the same total length, even though I would expect all resources to be released once the PPO object is garbage collected.
I believe the only real difference from a standard setup in this example is the observation and action space, which mimic my use case.

Single-run memory usage, model.learn(total_timesteps=500_000):
[memory usage plot]

Multi-run memory usage, model.learn(total_timesteps=25_000) run N times; crashes early due to OOM:
[memory usage plot]
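
For reference, the resident set size can also be sampled from inside the script; the psutil call below is an assumption on my part, not what produced the plots above.

import os
import psutil  # assumption: psutil is installed

proc = psutil.Process(os.getpid())
print(f"RSS: {proc.memory_info().rss / 1e6:.0f} MB")  # e.g. print this after each run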

To Reproduce

import gymnasium as gym
from gymnasium.wrappers.time_limit import TimeLimit
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv
from gymnasium.spaces import MultiDiscrete, Box
import numpy as np

OBS_SHAPE = (320, 320, 2)

class DummyEnv(gym.Env):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Spaces mimicking the real use case: a small MultiDiscrete action
        # space and a large uint8 image observation.
        self.action_space = MultiDiscrete([6 * 8 + 1, 2, 2], dtype=np.uint8)
        self.observation_space = Box(low=0, high=255, shape=OBS_SHAPE, dtype=np.uint8)

    def reset(self, *args, **kwargs):
        return self.observation_space.sample(), {}

    def step(self, action, *args, **kwargs):
        assert self.action_space.contains(action), f"{action} ({type(action)}) invalid"
        # Random observation, constant reward, never terminates on its own
        # (episodes are truncated by the TimeLimit wrapper).
        state = self.observation_space.sample()
        reward = 0.0
        done = False
        return state, reward, done, False, {}
        

def make_env():
    env = DummyEnv()
    env = TimeLimit(env, 100)  # truncate episodes after 100 steps
    return Monitor(env)

def train(ts):
    # Fresh vectorized env and PPO model for every run.
    vec_env = SubprocVecEnv([make_env for _ in range(12)])
    model = PPO("CnnPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=ts)
    model.get_env().close()

if __name__ == "__main__":
    # 20 runs of 25k steps each (500k total), to compare against
    # a single 500k-step run.
    for i in range(20):
        print("Starting", i)
        train(25_000)
        print(i, "finished")
    # train(500_000)  # single-run baseline
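
As a diagnostic sketch (not part of the run shown above), the loop can also force cleanup between runs and report memory; the gc.collect() and torch.cuda.empty_cache() calls are my additions, and psutil is assumed to be installed:

import gc
import os

import psutil  # assumption: available for RSS measurement
import torch


def train_with_cleanup(ts):
    # Hypothetical variant of train() that releases everything explicitly
    # and reports the process RSS after each run.
    vec_env = SubprocVecEnv([make_env for _ in range(12)])
    model = PPO("CnnPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=ts)
    vec_env.close()           # shut down the subprocess workers
    del model, vec_env        # drop references before collecting
    gc.collect()              # force a full garbage-collection pass
    torch.cuda.empty_cache()  # return cached CUDA blocks to the driver
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1e6
    print(f"RSS after cleanup: {rss_mb:.0f} MB")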

Relevant log output / Error message

No response

System Info

- OS: Linux-6.8.7-arch1-1-x86_64-with-glibc2.39 #1 SMP PREEMPT_DYNAMIC Wed, 17 Apr 2024 15:20:28 +0000
- Python: 3.11.8
- Stable-Baselines3: 2.3.0
- PyTorch: 2.3.0+cu121
- GPU Enabled: True
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1

Checklist

  • My issue does not relate to a custom gym environment. (Use the custom gym env template instead)
  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • I have provided a minimal and working example to reproduce the bug
  • I've used the markdown code blocks for both code and stack traces.

Labels: bug (Something isn't working)