Add test for GAE + rename RolloutBuffer.dones for clarification #375

Merged: 13 commits merged on Apr 16, 2021
9 changes: 6 additions & 3 deletions docs/misc/changelog.rst
@@ -4,11 +4,12 @@ Changelog
==========


Release 1.1.0a3 (WIP)
Release 1.1.0a4 (WIP)
---------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Renamed ``_last_dones`` and ``dones`` to ``_last_episode_starts`` and ``episode_starts`` in ``RolloutBuffer``.

New Features:
^^^^^^^^^^^^^
@@ -30,15 +31,17 @@ Others:
^^^^^^^
- Added ``flake8-bugbear`` to tests dependencies to find likely bugs
- Added Code of Conduct
- Added tests for GAE and lambda return computation

Documentation:
^^^^^^^^^^^^^^
- Added gym pybullet drones project (@JacopoPan)
- Added link to SuperSuit in projects (@justinkterry)
- Fixed DQN example (thanks @ltbd78)
- Clarify channel-first/channel-last recommendation
- Clarified channel-first/channel-last recommendation
- Update sphinx environment installation instructions (@tom-doerr)
- Clarify pip installation in Zsh (@tom-doerr)
- Clarified pip installation in Zsh (@tom-doerr)
- Clarified return computation for on-policy algorithms (TD(lambda) estimate was used)
- Added example for using ``ProcgenEnv``


4 changes: 2 additions & 2 deletions stable_baselines3/common/base_class.py
@@ -130,7 +130,7 @@ def __init__(
self.tensorboard_log = tensorboard_log
self.lr_schedule = None # type: Optional[Schedule]
self._last_obs = None # type: Optional[np.ndarray]
self._last_dones = None # type: Optional[np.ndarray]
self._last_episode_starts = None # type: Optional[np.ndarray]
# When using VecNormalize:
self._last_original_obs = None # type: Optional[np.ndarray]
self._episode_num = 0
@@ -377,7 +377,7 @@ def _setup_learn(
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset()
self._last_dones = np.zeros((self.env.num_envs,), dtype=bool)
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs()
37 changes: 25 additions & 12 deletions stable_baselines3/common/buffers.py
@@ -294,7 +294,7 @@ def __init__(
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
self.returns, self.dones, self.values, self.log_probs = None, None, None, None
self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
self.generator_ready = False
self.reset()

@@ -303,7 +303,7 @@ def reset(self) -> None:
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -312,20 +312,25 @@ def reset(self) -> None:

def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
"""
Post-processing step: compute the returns (sum of discounted rewards)
and GAE advantage.
Adapted from Stable-Baselines PPO2.
Post-processing step: compute the lambda-return (TD(lambda) estimate)
and GAE(lambda) advantage.

Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
to compute the advantage. To obtain vanilla advantage (A(s) = R - V(s))
where R is the discounted reward with value bootstrap,
set ``gae_lambda=1.0`` during initialization.

:param last_values:
:param dones:
The TD(lambda) estimator also has two special cases:
- TD(1) is the Monte-Carlo estimate (sum of discounted rewards)
- TD(0) is the one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

:param last_values: state value estimation for the last step (one for each env)
:param dones: if the last step was a terminal step (one bool for each env).

"""
# convert to numpy
# Convert to numpy
last_values = last_values.clone().cpu().numpy().flatten()

last_gae_lam = 0
@@ -334,21 +339,29 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra
next_non_terminal = 1.0 - dones
next_values = last_values
else:
next_non_terminal = 1.0 - self.dones[step + 1]
next_non_terminal = 1.0 - self.episode_starts[step + 1]
next_values = self.values[step + 1]
delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
self.advantages[step] = last_gae_lam
# TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
# in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
self.returns = self.advantages + self.values

def add(
self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray, value: th.Tensor, log_prob: th.Tensor
self,
obs: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
episode_start: np.ndarray,
value: th.Tensor,
log_prob: th.Tensor,
) -> None:
"""
:param obs: Observation
:param action: Action
:param reward:
:param done: End of episode signal.
:param episode_start: Start of episode signal.
:param value: estimated value of the current state
following the current policy.
:param log_prob: log probability of the action
@@ -366,7 +379,7 @@ def add(
self.observations[self.pos] = np.array(obs).copy()
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
self.dones[self.pos] = np.array(done).copy()
self.episode_starts[self.pos] = np.array(episode_start).copy()
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.pos += 1
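
For reference, a minimal standalone sketch of the recursion in ``compute_returns_and_advantage`` above, for a single environment and plain NumPy arrays (the helper name ``compute_gae`` is hypothetical, not part of the library). With ``gae_lambda=1.0`` the return reduces to the Monte-Carlo (TD(1)) estimate, and with ``gae_lambda=0.0`` the advantage reduces to the one-step TD error ``r_t + gamma * v(s_{t+1}) - v(s_t)``.

```python
import numpy as np


def compute_gae(rewards, values, episode_starts, last_value, last_done, gamma=0.99, gae_lambda=0.95):
    """Illustrative single-environment re-implementation of the recursion in
    RolloutBuffer.compute_returns_and_advantage (1D NumPy arrays, not library code)."""
    n_steps = len(rewards)
    advantages = np.zeros(n_steps)
    last_gae_lam = 0.0
    for step in reversed(range(n_steps)):
        if step == n_steps - 1:
            next_non_terminal = 1.0 - float(last_done)
            next_value = last_value
        else:
            next_non_terminal = 1.0 - episode_starts[step + 1]
            next_value = values[step + 1]
        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    # TD(lambda) return = GAE(lambda) advantage + value baseline
    returns = advantages + values
    return advantages, returns


# With gae_lambda=1.0 and a zero value function, the return is the plain
# discounted sum of rewards (Monte-Carlo / TD(1) estimate).
rewards = np.array([0.0, 0.0, 1.0])
values = np.zeros(3)
episode_starts = np.array([1.0, 0.0, 0.0])
adv, ret = compute_gae(rewards, values, episode_starts, last_value=0.0, last_done=True, gamma=0.9, gae_lambda=1.0)
assert np.allclose(ret, [0.81, 0.9, 1.0])
```
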
4 changes: 2 additions & 2 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -180,9 +180,9 @@ def collect_rollouts(
if isinstance(self.action_space, gym.spaces.Discrete):
# Reshape in case of discrete action
actions = actions.reshape(-1, 1)
rollout_buffer.add(self._last_obs, actions, rewards, self._last_dones, values, log_probs)
rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
self._last_obs = new_obs
self._last_dones = dones
self._last_episode_starts = dones

with th.no_grad():
# Compute value for the last timestep
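
The core of the rename is a one-timestep shift: the flag stored with a transition no longer means "this step ended an episode" but "this step started one", i.e. the previous step's ``done`` signal (and ``True`` for the very first step after a reset). A toy illustration of that bookkeeping, using plain NumPy arrays instead of the real ``VecEnv``/``RolloutBuffer`` objects:

```python
import numpy as np

# Toy illustration: the dones returned by env.step() at time t become
# the episode_start flag stored with the transition at time t + 1.
dones_per_step = np.array([[False], [False], [True], [False]])  # one env, four steps
episode_starts = np.ones((1,), dtype=bool)  # right after reset, the first step starts an episode

stored_episode_starts = []
for dones in dones_per_step:
    # collect_rollouts stores the *current* episode_start flag with the transition...
    stored_episode_starts.append(episode_starts.copy())
    # ...and the dones of this step become the episode_start flag of the next one.
    episode_starts = dones

print(np.array(stored_episode_starts).astype(float).ravel())
# [1. 0. 0. 1.] -> the step right after a done is flagged as an episode start
```
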
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
1.1.0a3
1.1.0a4
114 changes: 114 additions & 0 deletions tests/test_gae.py
@@ -0,0 +1,114 @@
import gym
import numpy as np
import pytest
import torch as th

from stable_baselines3 import A2C, PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy


class CustomEnv(gym.Env):
def __init__(self, max_steps=8):
super(CustomEnv, self).__init__()
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.max_steps = max_steps
self.n_steps = 0

def seed(self, seed):
self.observation_space.seed(seed)

def reset(self):
self.n_steps = 0
return self.observation_space.sample()

def step(self, action):
self.n_steps += 1

done = False
reward = 0.0
if self.n_steps >= self.max_steps:
reward = 1.0
done = True

return self.observation_space.sample(), reward, done, {}


class CheckGAECallback(BaseCallback):
def __init__(self):
super(CheckGAECallback, self).__init__(verbose=0)

def _on_rollout_end(self):
buffer = self.model.rollout_buffer
rollout_size = buffer.size()

max_steps = self.training_env.envs[0].max_steps
gamma = self.model.gamma
gae_lambda = self.model.gae_lambda
value = self.model.policy.constant_value
# We know in advance that the agent will get a single
# reward at the very last timestep of the episode,
# so we can pre-compute the lambda-return and advantage
deltas = np.zeros((rollout_size,))
advantages = np.zeros((rollout_size,))
# Reward should be 1.0 on final timestep of episode
rewards = np.zeros((rollout_size,))
rewards[max_steps - 1 :: max_steps] = 1.0
# Note that these are episode starts (+1 timestep from done)
episode_starts = np.zeros((rollout_size,))
episode_starts[::max_steps] = 1.0

# Final step is always terminal (the next step would have episode_start = 1)
deltas[-1] = rewards[-1] - value
advantages[-1] = deltas[-1]
for n in reversed(range(rollout_size - 1)):
# Values are constants
episode_start_mask = 1.0 - episode_starts[n + 1]
deltas[n] = rewards[n] + gamma * value * episode_start_mask - value
advantages[n] = deltas[n] + gamma * gae_lambda * advantages[n + 1] * episode_start_mask

# TD(lambda) estimate, see Github PR #375
lambda_returns = advantages + value

assert np.allclose(buffer.advantages.flatten(), advantages)
assert np.allclose(buffer.returns.flatten(), lambda_returns)

def _on_step(self):
return True


class CustomPolicy(ActorCriticPolicy):
"""Custom Policy with a constant value function"""

def __init__(self, *args, **kwargs):
super(CustomPolicy, self).__init__(*args, **kwargs)
self.constant_value = 0.0

def forward(self, obs, deterministic=False):
actions, values, log_prob = super().forward(obs, deterministic)
# Overwrite values with the constant value
values = th.ones_like(values) * self.constant_value
return actions, values, log_prob


@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("gae_lambda", [1.0, 0.9])
@pytest.mark.parametrize("gamma", [1.0, 0.99])
@pytest.mark.parametrize("num_episodes", [1, 3])
def test_gae_computation(model_class, gae_lambda, gamma, num_episodes):
env = CustomEnv(max_steps=64)
rollout_size = 64 * num_episodes
model = model_class(
CustomPolicy,
env,
seed=1,
gamma=gamma,
n_steps=rollout_size,
gae_lambda=gae_lambda,
)
model.learn(rollout_size, callback=CheckGAECallback())

# Change constant value so advantage != returns
model.policy.constant_value = 1.0
model.learn(rollout_size, callback=CheckGAECallback())
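
As a sanity check on the callback's expectation (an assumed closed form, not asserted in the test itself): with a constant value of 0 and a single reward of 1.0 at the last step of an episode of length ``T``, the recursion collapses to ``advantage[t] = (gamma * gae_lambda) ** (T - 1 - t)`` within the episode, and the lambda-return equals the advantage.

```python
import numpy as np

# Assumed closed form for the zero-value case (first learn() call above):
# with V(s) = 0 and a single reward of 1.0 at the final step of an episode
# of length T, GAE gives advantage[t] = (gamma * gae_lambda) ** (T - 1 - t).
gamma, gae_lambda, max_steps = 0.99, 0.9, 8
expected = (gamma * gae_lambda) ** np.arange(max_steps - 1, -1, -1)

# Same recursion as RolloutBuffer.compute_returns_and_advantage,
# specialized to one episode that terminates at the last step.
rewards = np.zeros(max_steps)
rewards[-1] = 1.0
advantages = np.zeros(max_steps)
last_gae_lam = 0.0
for step in reversed(range(max_steps)):
    next_non_terminal = 0.0 if step == max_steps - 1 else 1.0
    delta = rewards[step] + gamma * 0.0 * next_non_terminal - 0.0  # constant value V = 0
    last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
    advantages[step] = last_gae_lam

assert np.allclose(advantages, expected)
# With V = 0, the lambda-return equals the advantage.
```
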