-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Add NStepReplayBuffer #2144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Add NStepReplayBuffer #2144
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Adds a new NStepReplayBuffer
to compute n-step returns and accompanying tests to validate its behavior under various termination conditions.
- Implements
NStepReplayBuffer
subclass with overridden_get_samples
to handle multi-step returns, terminations, and truncations. - Provides unit tests in
tests/test_n_step_replay.py
covering normal sampling, early termination, and truncation. - Integrates
NStepReplayBuffer
into DQN/SAC viatest_run
demonstration.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
tests/test_n_step_replay.py | Added comprehensive tests for n-step buffer behavior |
stable_baselines3/common/buffers.py | Introduced NStepReplayBuffer with custom sampling logic |
Comments suppressed due to low confidence (2)
stable_baselines3/common/buffers.py:843
- Add a class-level docstring for
NStepReplayBuffer
to describe its purpose, parameters (n_steps
,gamma
), and behavior (handling terminations and truncations).
class NStepReplayBuffer(ReplayBuffer):
tests/test_n_step_replay.py:10
- The tests cover single-environment behavior but don’t validate multi-environment buffers. Add a test with
n_envs > 1
to ensure correct sampling and offset handling across parallel envs.
@pytest.mark.parametrize("model_class", [SAC, DQN])
safe_timeouts = self.timeouts.copy() | ||
safe_timeouts[self.pos - 1, :] = np.logical_not(self.dones[self.pos - 1, :]) | ||
|
||
indices = (batch_inds[:, None] + steps) % self.buffer_size # shape: [batch, n_steps] | ||
|
||
# Retrieve sequences of transitions | ||
rewards_seq = self.rewards[indices, env_indices[:, None]] # [batch, n_steps] | ||
dones_seq = self.dones[indices, env_indices[:, None]] # [batch, n_steps] | ||
truncs_seq = safe_timeouts[indices, env_indices[:, None]] # [batch, n_steps] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copying the entire timeouts
array on every sample can be expensive for large buffers. Consider only handling the specific index or using a more targeted approach to avoid full-array duplications.
safe_timeouts = self.timeouts.copy() | |
safe_timeouts[self.pos - 1, :] = np.logical_not(self.dones[self.pos - 1, :]) | |
indices = (batch_inds[:, None] + steps) % self.buffer_size # shape: [batch, n_steps] | |
# Retrieve sequences of transitions | |
rewards_seq = self.rewards[indices, env_indices[:, None]] # [batch, n_steps] | |
dones_seq = self.dones[indices, env_indices[:, None]] # [batch, n_steps] | |
truncs_seq = safe_timeouts[indices, env_indices[:, None]] # [batch, n_steps] | |
# Handle specific index without copying the entire array | |
modified_timeouts = np.logical_not(self.dones[self.pos - 1, :]) | |
indices = (batch_inds[:, None] + steps) % self.buffer_size # shape: [batch, n_steps] | |
# Retrieve sequences of transitions | |
rewards_seq = self.rewards[indices, env_indices[:, None]] # [batch, n_steps] | |
dones_seq = self.dones[indices, env_indices[:, None]] # [batch, n_steps] | |
truncs_seq = self.timeouts[indices, env_indices[:, None]] # [batch, n_steps] | |
# Apply the modified value for the specific index | |
truncs_seq[indices == (self.pos - 1)] = modified_timeouts[env_indices[indices == (self.pos - 1)]] |
Copilot uses AI. Check for mistakes.
Description
Reviving #81
closes #47
Idea based on https://github.com/younggyoseo/FastTD3 implementation with fixes to avoid younggyoseo/FastTD3#6
Mostly tested on IsaacSim so far.
To try it out (I'm thinking about adding a
n_steps
param to off-policy algorithm to make it easier):Motivation and Context
Types of changes
Checklist
make format
(required)make check-codestyle
andmake lint
(required)make pytest
andmake type
both pass. (required)make doc
(required)Note: You can run most of the checks using
make commit-checks
.Note: we are using a maximum length of 127 characters per line