Source code for sb3_contrib.tqc.tqc

from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import gym
import numpy as np
import torch as th
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update

from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy

TQCSelf = TypeVar("TQCSelf", bound="TQC")

class TQC(OffPolicyAlgorithm):
    """
    Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics.
    Paper: https://arxiv.org/abs/2005.04269
    This implementation uses SB3 SAC implementation as base.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: learning rate for adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function)
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient updates to do after each rollout step
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problems. Cf common.noise for the different action noise types.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param ent_coef: Entropy regularization coefficient. (Equivalent to
        inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off.
        Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
    :param target_update_interval: update the target network every ``target_update_interval``
        gradient steps.
    :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
    :param top_quantiles_to_drop_per_net: Number of quantiles to drop per network
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
        during the warm up phase (before learning starts)
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically (Only available when passing string for the environment).
        Caution, this parameter is deprecated and will be removed in the future.
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[TQCPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Callable] = 3e-4,
        buffer_size: int = 1000000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: int = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[ReplayBuffer] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        top_quantiles_to_drop_per_net: int = 2,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(gym.spaces.Box),
            support_multi_env=True,
        )

        self.target_entropy = target_entropy
        self.log_ent_coef = None  # type: Optional[th.Tensor]
        # Entropy coefficient / Entropy temperature
        # Inverse of the reward scale
        self.ent_coef = ent_coef
        self.target_update_interval = target_update_interval
        self.ent_coef_optimizer = None
        self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        super()._setup_model()
        self._create_aliases()
        # Running mean and running var
        self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])
        # Target entropy is used when learning the entropy coefficient
        if self.target_entropy == "auto":
            # automatically set target entropy if needed
            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
        else:
            # Force conversion
            # this will also throw an error for unexpected string
            self.target_entropy = float(self.target_entropy)

        # The entropy coefficient or entropy can be learned automatically
        # see Automating Entropy Adjustment for Maximum Entropy RL section
        # of https://arxiv.org/abs/1812.05905
        if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
            # Default initial value of ent_coef when learned
            init_value = 1.0
            if "_" in self.ent_coef:
                init_value = float(self.ent_coef.split("_")[1])
                assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"

            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
            self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
            self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
        else:
            # Force conversion to float
            # this will throw an error if a malformed string (different from 'auto')
            # is passed
            self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)

    def _create_aliases(self) -> None:
        self.actor = self.policy.actor
        self.critic = self.policy.critic
        self.critic_target = self.policy.critic_target
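
# Editorial sketch, not part of the original module: a minimal standalone
# illustration of the ``ent_coef="auto_0.1"`` parsing performed in
# ``_setup_model`` above. The log of the coefficient is the trainable
# parameter, so exponentiating it keeps the learned entropy temperature
# strictly positive. All names below are local to this sketch.
#
# import torch as th
#
# ent_coef = "auto_0.1"
# init_value = 1.0
# if "_" in ent_coef:
#     init_value = float(ent_coef.split("_")[1])  # -> 0.1
# log_ent_coef = th.log(th.ones(1) * init_value).requires_grad_(True)
# assert th.isclose(th.exp(log_ent_coef), th.tensor(0.1)).item()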
    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizers learning rate
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Action by the current actor for the sampled state
            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())
            self.replay_buffer.ent_coef = ent_coef.item()

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with th.no_grad():
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                # Compute and cut quantiles at the next state
                # batch x nets x quantiles
                next_quantiles = self.critic_target(replay_data.next_observations, next_actions)

                # Sort and drop top k quantiles to control overestimation.
                n_target_quantiles = self.critic.quantiles_total - self.top_quantiles_to_drop_per_net * self.critic.n_critics
                next_quantiles, _ = th.sort(next_quantiles.reshape(batch_size, -1))
                next_quantiles = next_quantiles[:, :n_target_quantiles]

                # td error + entropy term
                target_quantiles = next_quantiles - ent_coef * next_log_prob.reshape(-1, 1)
                target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_quantiles
                # Make target_quantiles broadcastable to (batch_size, n_critics, n_target_quantiles).
                target_quantiles.unsqueeze_(dim=1)

            # Get current Quantile estimates using action from the replay buffer
            current_quantiles = self.critic(replay_data.observations, replay_data.actions)
            # Compute critic loss, not summing over the quantile dimension as in the paper.
            critic_loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=False)
            critic_losses.append(critic_loss.item())

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Compute actor loss
            qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - qf_pi).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                # Copy running stats, see GH issue #996
                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/ent_coef", np.mean(ent_coefs))
        self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))
        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
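
# Editorial sketch, not part of the original module: the sort-and-truncate step
# from ``train`` above, worked with the paper's default sizes (2 critics with
# 25 quantiles each; these numbers are assumptions for illustration and are
# configurable via ``policy_kwargs`` in the real class). Dropping the k largest
# quantiles per net from the pooled target distribution is what curbs
# overestimation in TQC.
#
# import torch as th
#
# batch_size, n_critics, n_quantiles = 4, 2, 25
# top_quantiles_to_drop_per_net = 2
# quantiles_total = n_critics * n_quantiles  # 50
#
# next_quantiles = th.randn(batch_size, n_critics, n_quantiles)
# # Pool all critics' quantiles, sort ascending, keep 50 - 2 * 2 = 46 of them
# n_target_quantiles = quantiles_total - top_quantiles_to_drop_per_net * n_critics
# sorted_quantiles, _ = th.sort(next_quantiles.reshape(batch_size, -1))
# truncated = sorted_quantiles[:, :n_target_quantiles]
# assert truncated.shape == (batch_size, 46)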
    def learn(
        self: TQCSelf,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "TQC",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> TQCSelf:
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )
    def _excluded_save_params(self) -> List[str]:
        # Exclude aliases
        return super()._excluded_save_params() + ["actor", "critic", "critic_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        if self.ent_coef_optimizer is not None:
            saved_pytorch_variables = ["log_ent_coef"]
            state_dicts.append("ent_coef_optimizer")
        else:
            saved_pytorch_variables = ["ent_coef_tensor"]
        return state_dicts, saved_pytorch_variables
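
# Editorial usage sketch, not part of the original module: training TQC on a
# continuous-control Gym task. The environment id and hyperparameter values are
# illustrative assumptions, not prescriptions from this source file.
#
# from sb3_contrib import TQC
#
# model = TQC(
#     "MlpPolicy",
#     "Pendulum-v1",
#     top_quantiles_to_drop_per_net=2,
#     ent_coef="auto",
#     verbose=1,
# )
# model.learn(total_timesteps=10_000)
# model.save("tqc_pendulum")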