TQC
Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics (TQC). Truncated Quantile Critics (TQC) builds on SAC, TD3 and QR-DQN. It uses quantile regression to predict a distribution for the value function instead of a single mean value. To control overestimation bias, it pools the quantiles predicted by several critic networks and truncates (drops) the largest ones, somewhat like TD3 takes the minimum over its two critics.
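The truncation step can be illustrated with a short NumPy sketch. It is a re-derivation under stated assumptions, not sb3_contrib's code: the function names `truncated_target_quantiles` and `soft_td_targets` are hypothetical, and the critics' next-state quantiles are assumed to arrive as an array of shape `(batch_size, n_critics, n_quantiles)`. All atoms are pooled into one mixture and sorted, and the `top_quantiles_to_drop_per_net * n_critics` largest ones are discarded before a SAC-style entropy-regularized TD target is formed.

```python
import numpy as np


def truncated_target_quantiles(next_quantiles: np.ndarray, top_quantiles_to_drop_per_net: int) -> np.ndarray:
    """Pool the atoms of all critics, sort them and drop the largest ones.

    next_quantiles: shape (batch_size, n_critics, n_quantiles)
    returns: shape (batch_size, n_critics * (n_quantiles - top_quantiles_to_drop_per_net))
    """
    batch_size, n_critics, n_quantiles = next_quantiles.shape
    # Mix the quantiles of every critic into a single distribution per sample
    pooled = next_quantiles.reshape(batch_size, n_critics * n_quantiles)
    # Sort ascending so that the most optimistic atoms end up last
    pooled = np.sort(pooled, axis=1)
    # Drop top_quantiles_to_drop_per_net atoms per critic, i.e. d * n_critics in total
    n_kept = n_critics * n_quantiles - top_quantiles_to_drop_per_net * n_critics
    return pooled[:, :n_kept]


def soft_td_targets(truncated, rewards, dones, next_log_prob, ent_coef, gamma=0.99):
    """Entropy-regularized TD target applied to every kept atom (assumed SAC-style form)."""
    # Subtract the entropy term alpha * log pi(a'|s') from each atom
    soft = truncated - ent_coef * next_log_prob.reshape(-1, 1)
    # No bootstrapping for terminal transitions
    return rewards.reshape(-1, 1) + (1.0 - dones.reshape(-1, 1)) * gamma * soft


# Toy batch: 2 samples, 2 critics, 5 quantiles each
next_quantiles = np.arange(20, dtype=float).reshape(2, 2, 5)
kept = truncated_target_quantiles(next_quantiles, top_quantiles_to_drop_per_net=2)
print(kept[0])  # [0. 1. 2. 3. 4. 5.]
```

With `n_critics=2`, `n_quantiles=25` and `top_quantiles_to_drop_per_net=2`, as in the example further down, 50 atoms are pooled and 46 are kept.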
Available Policies
Policy | Description |
---|---|
MlpPolicy | alias of TQCPolicy |
CnnPolicy | Policy class (with both actor and critic) for TQC. |
MultiInputPolicy | Policy class (with both actor and critic) for TQC. |
Notes
Original paper: https://arxiv.org/abs/2005.04269
Original Implementation: https://github.com/bayesgroup/tqc_pytorch
Can I use?
Recurrent policies: ❌
Multi processing: ✔️
Gym spaces:
Space | Action | Observation |
---|---|---|
Discrete | ❌ | ✔️ |
Box | ✔️ | ✔️ |
MultiDiscrete | ❌ | ✔️ |
MultiBinary | ❌ | ✔️ |
Dict | ❌ | ✔️ |
Example
```python
import gymnasium as gym
import numpy as np

from sb3_contrib import TQC

# "human" render mode opens a window and renders on every step, including during training
env = gym.make("Pendulum-v1", render_mode="human")

# 2 critics with 25 quantiles each; the 4 largest atoms (2 per critic) of the pooled mixture are dropped
policy_kwargs = dict(n_critics=2, n_quantiles=25)
model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("tqc_pendulum")

del model  # remove to demonstrate saving and loading

model = TQC.load("tqc_pendulum")

obs, _ = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated or truncated:
        obs, _ = env.reset()
```
Results
Results on the PyBullet benchmark (1M steps) and on BipedalWalkerHardcore-v3 (2M steps), using 3 seeds. The complete learning curves are available in the associated PR.
TQC differs most from SAC on the harder environments (BipedalWalkerHardcore, Walker2D).
Note
Hyperparameters from the gSDE paper were used, as they are tuned for SAC on PyBullet envs. This includes using gSDE for exploration instead of unstructured Gaussian noise, which should not affect results in simulation.
Note
We use the open-source PyBullet environments rather than the MuJoCo simulator used in the original paper. A complete benchmark on PyBullet envs is available in the gSDE paper if you want to compare TQC results to those of A2C/PPO/SAC/TD3.
Environments | SAC | TQC |
---|---|---|
 | gSDE | gSDE |
HalfCheetah | 2984 +/- 202 | 3041 +/- 157 |
Ant | 3102 +/- 37 | 3700 +/- 37 |
Hopper | 2262 +/- 1 | 2401 +/- 62* |
Walker2D | 2136 +/- 67 | 2535 +/- 94 |
BipedalWalkerHardcore | 13 +/- 18 | 228 +/- 18 |
* with the tuned value of the hyperparameter top_quantiles_to_drop_per_net taken from the original paper
How to replicate the results?
Clone RL-Zoo and check out the branch feat/tqc:
```bash
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
git checkout feat/tqc
```
Run the benchmark (replace $ENV_ID with each of the envs mentioned above):
```bash
python train.py --algo tqc --env $ENV_ID --eval-episodes 10 --eval-freq 10000
```
Plot the results:
```bash
python scripts/all_plots.py -a tqc -e HalfCheetah Ant Hopper Walker2D BipedalWalkerHardcore -f logs/ -o logs/tqc_results
python scripts/plot_from_file.py -i logs/tqc_results.pkl -latex -l TQC
```
Parameters
- class sb3_contrib.tqc.TQC(policy, env, learning_rate=0.0003, buffer_size=1000000, learning_starts=100, batch_size=256, tau=0.005, gamma=0.99, train_freq=1, gradient_steps=1, action_noise=None, replay_buffer_class=None, replay_buffer_kwargs=None, optimize_memory_usage=False, ent_coef='auto', target_update_interval=1, target_entropy='auto', top_quantiles_to_drop_per_net=2, use_sde=False, sde_sample_freq=-1, use_sde_at_warmup=False, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)
Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics (paper: https://arxiv.org/abs/2005.04269). This implementation uses the SB3 SAC implementation as its base.
- Parameters:
policy (TQCPolicy) – The policy model to use (MlpPolicy, CnnPolicy, …)
env (Env | VecEnv | str) – The environment to learn from (if registered in Gym, can be str)
learning_rate (float | Callable) – learning rate for the Adam optimizer. The same learning rate is used for all networks (Q-values, actor and value function). It can be a function of the current progress remaining (from 1 to 0).
buffer_size (int) – size of the replay buffer
learning_starts (int) – how many steps of the model to collect transitions for before learning starts
batch_size (int) – Minibatch size for each gradient update
tau (float) – the soft update coefficient (“Polyak update”, between 0 and 1)
gamma (float) – the discount factor
train_freq (int | Tuple[int, str]) – Update the model every train_freq steps. Alternatively, pass a tuple of frequency and unit like (5, "step") or (2, "episode").
gradient_steps (int) – How many gradient updates to perform after each rollout (see train_freq)
action_noise (ActionNoise | None) – the action noise type (None by default); this can help with hard exploration problems. See common.noise for the different action noise types.
replay_buffer_class (Type[ReplayBuffer] | None) – Replay buffer class to use (for instance HerReplayBuffer). If None, it will be automatically selected.
replay_buffer_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the replay buffer on creation.
optimize_memory_usage (bool) – Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
ent_coef (str | float) – Entropy regularization coefficient (equivalent to the inverse of the reward scale in the original SAC paper). It controls the exploration/exploitation trade-off. Set it to ‘auto’ to learn it automatically (and ‘auto_0.1’ to use 0.1 as the initial value).
target_update_interval (int) – update the target network every target_update_interval gradient steps.
target_entropy (str | float) – target entropy when learning ent_coef (ent_coef = 'auto')
top_quantiles_to_drop_per_net (int) – Number of quantiles to drop per network
use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)
sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)
use_sde_at_warmup (bool) – Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts)
stats_window_size (int) – Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over
tensorboard_log (str | None) – the log location for tensorboard (if None, no logging)
policy_kwargs (Dict[str, Any] | None) – additional arguments to be passed to the policy on creation
verbose (int) – the verbosity level: 0 no output, 1 info, 2 debug
seed (int | None) – Seed for the pseudo random generators
device (device | str) – Device (cpu, cuda, …) on which the code should be run. If set to auto, the code will run on the GPU if possible.
_init_setup_model (bool) – Whether or not to build the network at the creation of the instance
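Two of the parameters above, `learning_rate` given as a callable and the soft update controlled by `tau`, can be illustrated with plain Python and NumPy. This is a hedged sketch: `linear_schedule`, `polyak_update` and `auto_target_entropy` are illustrative helpers written here, not sb3_contrib functions (SB3 ships its own `polyak_update` in `stable_baselines3.common.utils`, which operates on torch parameters). `auto_target_entropy` assumes the usual SAC heuristic of minus the action dimension for `target_entropy='auto'`.

```python
from typing import Callable, List

import numpy as np


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Learning rate that decays linearly with progress_remaining (1.0 at start, 0.0 at the end)."""

    def schedule(progress_remaining: float) -> float:
        return progress_remaining * initial_value

    return schedule


def polyak_update(params: List[np.ndarray], target_params: List[np.ndarray], tau: float) -> None:
    """Soft update in place: target <- tau * param + (1 - tau) * target."""
    for param, target in zip(params, target_params):
        target *= 1.0 - tau
        target += tau * param


def auto_target_entropy(action_shape) -> float:
    """Assumed SAC-style heuristic for target_entropy='auto': minus the action dimension."""
    return -float(np.prod(action_shape))


schedule = linear_schedule(3e-4)
online, target = [np.ones(3)], [np.zeros(3)]
polyak_update(online, target, tau=0.005)  # target moves 0.5% of the way to the online weights
```

Assuming sb3_contrib is installed, a schedule like this could be passed as `TQC("MlpPolicy", env, learning_rate=linear_schedule(3e-4))`, since `learning_rate` accepts a function of the progress remaining. A small `tau` such as the default 0.005 makes the target critics trail the online critics slowly, which stabilizes the bootstrapped targets.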
- collect_rollouts(env, callback, train_freq, replay_buffer, action_noise=None, learning_starts=0, log_interval=None)
Collect experiences and store them into a ReplayBuffer.
- Parameters:
env (VecEnv) – The training environment
callback (BaseCallback) – Callback that will be called at each step (and at the beginning and end of the rollout)
train_freq (TrainFreq) – How much experience to collect by doing rollouts of the current policy. Either TrainFreq(<n>, TrainFrequencyUnit.STEP) or TrainFreq(<n>, TrainFrequencyUnit.EPISODE) with <n> being an integer greater than 0.
action_noise (ActionNoise | None) – Action noise that will be used for exploration. Required for deterministic policies (e.g. TD3). This can also be used in addition to the stochastic policy for SAC.
learning_starts (int) – Number of steps before learning for the warm-up phase.
replay_buffer (ReplayBuffer) –
log_interval (int | None) – Log data every log_interval episodes
- Returns:
- Return type:
RolloutReturn
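The stopping rule behind `train_freq` can be sketched in pure Python. This is an illustrative model under stated assumptions, not sb3_contrib's code: `FrequencyUnit`, `Frequency`, `should_collect_more` and `simulate_rollout` are hypothetical stand-ins for `TrainFrequencyUnit`, `TrainFreq` and the loop inside `collect_rollouts`, and a single (non-vectorized) environment with fixed-length episodes is assumed.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Tuple


class FrequencyUnit(Enum):  # hypothetical stand-in for TrainFrequencyUnit
    STEP = "step"
    EPISODE = "episode"


@dataclass
class Frequency:  # hypothetical stand-in for TrainFreq
    frequency: int
    unit: FrequencyUnit


def should_collect_more(freq: Frequency, n_steps: int, n_episodes: int) -> bool:
    """Keep collecting until enough steps, or enough finished episodes, are gathered."""
    if freq.unit == FrequencyUnit.STEP:
        return n_steps < freq.frequency
    return n_episodes < freq.frequency


def simulate_rollout(freq: Frequency, episode_length: int) -> Tuple[int, int]:
    """Return (steps, finished episodes) collected by one rollout in a toy env.

    Assumes a single environment, starting from a fresh episode, whose episodes
    always last episode_length steps.
    """
    n_steps, n_episodes, t = 0, 0, 0
    while should_collect_more(freq, n_steps, n_episodes):
        n_steps += 1  # one environment step, stored in the replay buffer
        t += 1
        if t == episode_length:  # episode ends; the toy env resets
            n_episodes += 1
            t = 0
    return n_steps, n_episodes


print(simulate_rollout(Frequency(5, FrequencyUnit.STEP), episode_length=3))  # (5, 1)
print(simulate_rollout(Frequency(2, FrequencyUnit.EPISODE), episode_length=3))  # (6, 2)
```

With episodes of length 3, a step-based frequency of 5 stops mid-episode, whereas an episode-based frequency of 2 always collects whole episodes; this is the practical difference between `train_freq=(5, "step")` and `train_freq=(2, "episode")`.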
- get_env()
Returns the current environment (can be None if not defined).
- Returns:
The current environment
- Return type:
VecEnv | None
- get_parameters()
Return the parameters of the agent. This includes parameters from different networks, e.g. critics (value functions) and policies (pi functions).
- Returns:
Mapping from names of the objects to PyTorch state-dicts.
- Return type:
Dict[str, Dict]
- get_vec_normalize_env()
Return the VecNormalize wrapper of the training env if it exists.
- Returns:
The VecNormalize env.
- Return type:
VecNormalize | None
- learn(total_timesteps, callback=None, log_interval=4, tb_log_name='TQC', reset_num_timesteps=True, progress_bar=False)
Return a trained model.
- Parameters:
total_timesteps (int) – The total number of samples (env steps) to train on
callback (None | Callable | List[BaseCallback] | BaseCallback) – callback(s) called at every step with state of the algorithm.
log_interval (int) – for on-policy algos (e.g., PPO, A2C, …) this is the number of training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging; for off-policy algos (e.g., TD3, SAC, …) this is the number of episodes before logging.
tb_log_name (str) – the name of the run for TensorBoard logging
reset_num_timesteps (bool) – whether or not to reset the current timestep number (used in logging)
progress_bar (bool) – Display a progress bar using tqdm and rich.
self (SelfTQC) –
- Returns:
the trained model
- Return type:
SelfTQC
- classmethod load(path, env=None, device='auto', custom_objects=None, print_system_info=False, force_reset=True, **kwargs)
Load the model from a zip-file. Warning: load re-creates the model from scratch; it does not update it in-place! For an in-place load, use set_parameters instead.
- Parameters:
path (str | Path | BufferedIOBase) – path to the file (or a file-like) where to load the agent from
env (Env | VecEnv | None) – the new environment to run the loaded model on (can be None if you only need prediction from a trained model). It has priority over any saved environment.
device (device | str) – Device on which the code should run.
custom_objects (Dict[str, Any] | None) – Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in keras.models.load_model. Useful when you have an object in the file that cannot be deserialized.
print_system_info (bool) – Whether to print system info from the saved model and the current system info (useful to debug loading issues)
force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597
kwargs – extra arguments to change the model when loading
- Returns:
new model instance with loaded parameters
- Return type:
SelfBaseAlgorithm
- load_replay_buffer(path, truncate_last_traj=True)
Load a replay buffer from a pickle file.
- Parameters:
path (str | Path | BufferedIOBase) – Path to the pickled replay buffer.
truncate_last_traj (bool) – When using HerReplayBuffer with online sampling: if set to True, we assume that the last trajectory in the replay buffer was finished (and truncate it). If set to False, we assume that we continue the same trajectory (same episode).
- Return type:
None
- property logger: Logger
Getter for the logger object.
- predict(observation, state=None, episode_start=None, deterministic=False)
Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images).
- Parameters:
observation (ndarray | Dict[str, ndarray]) – the input observation
state (Tuple[ndarray, ...] | None) – The last hidden states (can be None, used in recurrent policies)
episode_start (ndarray | None) – The last masks (can be None, used in recurrent policies). This corresponds to the beginning of episodes, where the hidden states of the RNN must be reset.
deterministic (bool) – Whether or not to return deterministic actions.
- Returns:
the model’s action and the next hidden state (used in recurrent policies)
- Return type:
Tuple[ndarray, Tuple[ndarray, …] | None]
- save(path, exclude=None, include=None)
Save all the attributes of the object and the model parameters in a zip-file.
- Parameters:
path (str | Path | BufferedIOBase) – path to the file where the rl agent should be saved
exclude (Iterable[str] | None) – names of parameters that should be excluded in addition to the default ones
include (Iterable[str] | None) – names of parameters that might be excluded but should be included anyway
- Return type:
None
- save_replay_buffer(path)
Save the replay buffer as a pickle file.
- Parameters:
path (str | Path | BufferedIOBase) – Path to the file where the replay buffer should be saved. If path is a str or pathlib.Path, the path is automatically created if necessary.
- Return type:
None
- set_env(env, force_reset=True)
Checks the validity of the environment and, if it is coherent, sets it as the current environment. Furthermore, wraps any non-vectorized env into a vectorized one. Checked parameters: observation_space, action_space.
- Parameters:
env (Env | VecEnv) – The environment for learning a policy
force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See issue https://github.com/DLR-RM/stable-baselines3/issues/597
- Return type:
None
- set_logger(logger)
Setter for the logger object.
Warning
When passing a custom logger object, this will overwrite the tensorboard_log and verbose settings passed to the constructor.
- Parameters:
logger (Logger) –
- Return type:
None
- set_parameters(load_path_or_dict, exact_match=True, device='auto')
Load parameters from a given zip-file or a nested dictionary containing parameters for different modules (see get_parameters).
- Parameters:
load_path_or_dict (str | Dict[str, Tensor]) – Location of the saved data (path or file-like, see save), or a nested dictionary containing nn.Module parameters used by the policy. The dictionary maps object names to a state-dictionary returned by torch.nn.Module.state_dict().
exact_match (bool) – If True, the given parameters should include parameters for each module and each of their parameters, otherwise raises an Exception. If set to False, this can be used to update only specific parameters.
device (device | str) – Device on which the code should run.
- Return type:
None
- set_random_seed(seed=None)
Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space)
- Parameters:
seed (int | None) –
- Return type:
None
TQC Policies
- sb3_contrib.tqc.MlpPolicy
alias of TQCPolicy
- class sb3_contrib.tqc.policies.TQCPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, use_sde=False, log_std_init=-3, use_expln=False, clip_mean=2.0, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_quantiles=25, n_critics=2, share_features_extractor=False)
Policy class (with both actor and critic) for TQC.
- Parameters:
observation_space (Space) – Observation space
action_space (Box) – Action space
lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
net_arch (List[int] | Dict[str, List[int]] | None) – The specification of the policy and value networks.
activation_fn (Type[Module]) – Activation function
use_sde (bool) – Whether to use State Dependent Exploration or not
log_std_init (float) – Initial value for the log standard deviation
use_expln (bool) – Use the expln() function instead of exp() when using gSDE to ensure a positive standard deviation (cf paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
clip_mean (float) – Clip the mean output when using gSDE to avoid numerical instability.
features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
features_extractor_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the feature extractor.
normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
optimizer_kwargs (Dict[str, Any] | None) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
n_quantiles (int) – Number of quantiles for the critic.
n_critics (int) – Number of critic networks to create.
share_features_extractor (bool) – Whether to share or not the features extractor between the actor and the critic (this saves computation time)
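The `net_arch`, `n_quantiles` and `n_critics` arguments are usually set through `policy_kwargs`. The sketch below is a hypothetical configuration, assuming that the dict form of `net_arch` uses separate `pi` (actor) and `qf` (critic) keys as in SB3's SAC; it also computes how many atoms the pooled mixture holds and how many survive truncation.

```python
# Hypothetical configuration: separate actor/critic sizes, 5 critics of 25 quantiles each
policy_kwargs = dict(
    net_arch=dict(pi=[256, 256], qf=[512, 512, 512]),
    n_critics=5,
    n_quantiles=25,
)
top_quantiles_to_drop_per_net = 2

# Atoms in the pooled mixture of all critics
total_atoms = policy_kwargs["n_critics"] * policy_kwargs["n_quantiles"]
# Atoms kept for the target after dropping d per critic
kept_atoms = total_atoms - top_quantiles_to_drop_per_net * policy_kwargs["n_critics"]
# top_quantiles_to_drop_per_net must stay below n_quantiles, otherwise nothing is kept
assert 0 < kept_atoms <= total_atoms
print(total_atoms, kept_atoms)  # 125 115
```

These values would then be passed as `TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, policy_kwargs=policy_kwargs)`. More and larger critics increase the compute spent per gradient step.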
- forward(obs, deterministic=False)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
obs (Tensor | Dict[str, Tensor]) –
deterministic (bool) –
- Return type:
Tensor
- reset_noise(batch_size=1)
Sample new weights for the exploration matrix, when using gSDE.
- Parameters:
batch_size (int) –
- Return type:
None
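The idea behind `reset_noise` and `use_expln` can be sketched in NumPy. This is a simplified illustration under stated assumptions, not sb3_contrib's implementation: `expln`, `sample_exploration_matrix` and `gsde_noise` are hypothetical helpers, and a full (per latent feature, per action) standard deviation is assumed. gSDE samples an exploration matrix once; the noise is then a deterministic function of the state's latent features until the matrix is resampled.

```python
import numpy as np


def expln(log_std: np.ndarray) -> np.ndarray:
    """exp(x) for x <= 0 and log1p(x) + 1 for x > 0: positive, continuous at 0, grows slowly."""
    below = np.exp(np.minimum(log_std, 0.0))
    above = np.log1p(np.maximum(log_std, 0.0)) + 1.0
    return np.where(log_std <= 0.0, below, above)


def sample_exploration_matrix(log_std: np.ndarray, rng: np.random.Generator, use_expln: bool = False) -> np.ndarray:
    """Draw theta_eps ~ N(0, std^2), one entry per (latent feature, action) pair."""
    std = expln(log_std) if use_expln else np.exp(log_std)
    return rng.normal(0.0, std)


def gsde_noise(latent_features: np.ndarray, exploration_matrix: np.ndarray) -> np.ndarray:
    """State-dependent noise eps(s) = z(s) @ theta_eps, shape (batch_size, action_dim)."""
    return latent_features @ exploration_matrix


rng = np.random.default_rng(0)
latent_dim, action_dim = 4, 2
# log_std_init=-3 gives a small initial standard deviation (exp(-3) ~ 0.05)
log_std = np.full((latent_dim, action_dim), -3.0)
theta = sample_exploration_matrix(log_std, rng)
z = np.ones((3, latent_dim))  # three identical states
noise = gsde_noise(z, theta)
# Same state and same matrix give the same noise until the matrix is resampled
assert np.allclose(noise[0], noise[1])
```

Calling `sample_exploration_matrix` again corresponds conceptually to what `reset_noise` does, and `sde_sample_freq` controls how often that happens during a rollout.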
- set_training_mode(mode)
Put the policy in either training or evaluation mode. This affects certain modules, such as batch normalisation and dropout.
- Parameters:
mode (bool) – if True, set to training mode, else set to evaluation mode
- Return type:
None
- class sb3_contrib.tqc.CnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, use_sde=False, log_std_init=-3, use_expln=False, clip_mean=2.0, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_quantiles=25, n_critics=2, share_features_extractor=False)
Policy class (with both actor and critic) for TQC.
- Parameters:
observation_space (Space) – Observation space
action_space (Box) – Action space
lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
net_arch (List[int] | Dict[str, List[int]] | None) – The specification of the policy and value networks.
activation_fn (Type[Module]) – Activation function
use_sde (bool) – Whether to use State Dependent Exploration or not
log_std_init (float) – Initial value for the log standard deviation
use_expln (bool) – Use the expln() function instead of exp() when using gSDE to ensure a positive standard deviation (cf paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
clip_mean (float) – Clip the mean output when using gSDE to avoid numerical instability.
features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
optimizer_kwargs (Dict[str, Any] | None) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
n_quantiles (int) – Number of quantiles for the critic.
n_critics (int) – Number of critic networks to create.
share_features_extractor (bool) – Whether to share or not the features extractor between the actor and the critic (this saves computation time)
features_extractor_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the features extractor.
Comments
This implementation is based on the SB3 SAC implementation and uses the code from the original TQC implementation for the quantile Huber loss.
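For reference, the quantile Huber loss can be written out in NumPy. This is a hedged re-derivation rather than the library's PyTorch code: `quantile_huber_loss` is a hypothetical name here, the threshold `kappa` defaults to 1, the quantile fractions are assumed to be the midpoints (i + 0.5) / N, and the truncated target atoms are assumed to be shared by all critics.

```python
import numpy as np


def quantile_huber_loss(current_quantiles: np.ndarray, target_quantiles: np.ndarray, kappa: float = 1.0) -> float:
    """
    current_quantiles: (batch_size, n_critics, n_quantiles)
    target_quantiles: (batch_size, n_target_quantiles), shared by all critics
    """
    n_quantiles = current_quantiles.shape[-1]
    # Quantile midpoints tau_i = (i + 0.5) / N, broadcastable to (batch, critic, quantile, target)
    tau = ((np.arange(n_quantiles) + 0.5) / n_quantiles).reshape(1, 1, -1, 1)
    # Pairwise errors u = target_j - current_i, shape (batch, n_critics, n_quantiles, n_target)
    delta = target_quantiles[:, None, None, :] - current_quantiles[..., None]
    abs_delta = np.abs(delta)
    # Huber loss with threshold kappa
    huber = np.where(abs_delta > kappa, kappa * (abs_delta - 0.5 * kappa), 0.5 * delta**2)
    # Asymmetric weight |tau - 1{u < 0}| turns the Huber loss into a quantile regression loss
    loss = np.abs(tau - (delta < 0).astype(float)) * huber
    # Sum over the current quantiles, average over batch, critics and target atoms
    return float(loss.sum(axis=-2).mean())


print(quantile_huber_loss(np.zeros((1, 1, 1)), np.array([[1.0]])))  # 0.25
```

Each current quantile is regressed against every target atom. The asymmetric weight makes low quantile fractions penalize overestimates more than underestimates, which spreads the predicted atoms across the return distribution.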