# Source code for sb3_contrib.tqc.tqc

from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update

from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.tqc.policies import Actor, CnnPolicy, Critic, MlpPolicy, MultiInputPolicy, TQCPolicy

# Type variable bound to ``TQC``: lets ``TQC.learn`` be annotated as returning
# the same (sub)class type as the instance it was called on.
SelfTQC = TypeVar("SelfTQC", bound="TQC")


class TQC(OffPolicyAlgorithm):
    """
    Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics.
    Paper: https://arxiv.org/abs/2005.04269
    This implementation uses SB3 SAC implementation as base.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: learning rate for adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function)
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient update after each step
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problem. Cf common.noise for the different action noise type.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param ent_coef: Entropy regularization coefficient. (Equivalent to
        inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off.
        Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
    :param target_update_interval: update the target network every ``target_update_interval``
        gradient steps (the step index is counted within each call to ``train``).
    :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
    :param top_quantiles_to_drop_per_net: Number of quantiles to drop per network
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
        during the warm up phase (before learning starts)
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    # Maps the string names accepted for ``policy`` to the policy classes.
    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    policy: TQCPolicy
    actor: Actor
    critic: Critic
    critic_target: Critic

    def __init__(
        self,
        policy: Union[str, Type[TQCPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Callable] = 3e-4,
        buffer_size: int = 1000000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        top_quantiles_to_drop_per_net: int = 2,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(spaces.Box,),
            support_multi_env=True,
        )

        self.target_entropy = target_entropy
        # Only set (in ``_setup_model``) when the entropy coefficient is learned
        self.log_ent_coef: Optional[th.Tensor] = None
        # Entropy coefficient / Entropy temperature
        # Inverse of the reward scale
        self.ent_coef = ent_coef
        self.target_update_interval = target_update_interval
        self.ent_coef_optimizer: Optional[th.optim.Adam] = None
        self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Build the model through the parent class, then set up the TQC specific state:
        actor/critic aliases, batch-norm running statistics, the target entropy and
        either a learnable log entropy coefficient (with its optimizer) or a fixed
        entropy coefficient tensor.
        """
        super()._setup_model()
        self._create_aliases()
        # Running mean and running var
        self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])

        # Target entropy is used when learning the entropy coefficient
        if self.target_entropy == "auto":
            # automatically set target entropy if needed
            # (cast to a Python float so the attribute matches its ``float`` annotation)
            self.target_entropy = float(-np.prod(self.env.action_space.shape).astype(np.float32))  # type: ignore
        else:
            # Force conversion
            # this will also throw an error for unexpected string
            self.target_entropy = float(self.target_entropy)

        # The entropy coefficient or entropy can be learned automatically
        # see Automating Entropy Adjustment for Maximum Entropy RL section
        # of https://arxiv.org/abs/1812.05905
        if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
            # Default initial value of ent_coef when learned
            init_value = 1.0
            if "_" in self.ent_coef:
                init_value = float(self.ent_coef.split("_")[1])
                assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"

            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
            self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
            self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
        else:
            # Force conversion to float
            # this will throw an error if a malformed string (different from 'auto')
            # is passed
            self.ent_coef_tensor = th.tensor(float(self.ent_coef), device=self.device)

    def _create_aliases(self) -> None:
        """Expose the policy's actor, critic and target critic directly on the algorithm."""
        self.actor = self.policy.actor
        self.critic = self.policy.critic
        self.critic_target = self.policy.critic_target

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """
        Run ``gradient_steps`` updates on minibatches sampled from the replay buffer.

        Each step updates, in order: the entropy coefficient (only when it is learned),
        the critic (quantile Huber loss against truncated target quantiles) and the actor;
        the target critic is then updated according to ``target_update_interval``.
        Mean losses and the mean entropy coefficient are recorded to the logger afterwards.

        :param gradient_steps: number of gradient steps to perform
        :param batch_size: number of transitions sampled from the replay buffer per step
        :raises ValueError: if ``top_quantiles_to_drop_per_net`` leaves no target quantile
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizers learning rate
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        # Number of quantiles kept after truncation; it does not depend on the
        # sampled batch, so compute and validate it once, outside the loop.
        n_target_quantiles = self.critic.quantiles_total - self.top_quantiles_to_drop_per_net * self.critic.n_critics
        if n_target_quantiles <= 0:
            # A non-positive count would silently produce empty (or wrongly sliced)
            # targets and NaN losses further down.
            raise ValueError(
                f"top_quantiles_to_drop_per_net={self.top_quantiles_to_drop_per_net} drops all "
                f"{self.critic.quantiles_total} quantiles of the {self.critic.n_critics} critics; "
                "at least one target quantile must remain"
            )

        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Action by the current actor for the sampled state
            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with th.no_grad():
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                # Compute and cut quantiles at the next state
                # batch x nets x quantiles
                next_quantiles = self.critic_target(replay_data.next_observations, next_actions)

                # Sort and drop top k quantiles to control overestimation.
                # Flatten using the tensor's own batch dimension rather than the
                # ``batch_size`` argument, so the reshape cannot disagree with the data.
                next_quantiles, _ = th.sort(next_quantiles.reshape(next_quantiles.shape[0], -1))
                next_quantiles = next_quantiles[:, :n_target_quantiles]

                # td error + entropy term
                target_quantiles = next_quantiles - ent_coef * next_log_prob.reshape(-1, 1)
                target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_quantiles
                # Make target_quantiles broadcastable to (batch_size, n_critics, n_target_quantiles).
                target_quantiles.unsqueeze_(dim=1)

            # Get current Quantile estimates using action from the replay buffer
            current_quantiles = self.critic(replay_data.observations, replay_data.actions)
            # Compute critic loss, not summing over the quantile dimension as in the paper.
            critic_loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=False)
            critic_losses.append(critic_loss.item())

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Compute actor loss
            qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - qf_pi).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            # NOTE(review): ``gradient_step`` is the index inside this call, not a global
            # update counter, so step 0 of every ``train`` call triggers an update.
            # Confirm this is the intended meaning of ``target_update_interval``.
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/ent_coef", np.mean(ent_coefs))
        self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))
        # Only present when the entropy coefficient is learned
        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def learn(
        self: SelfTQC,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "TQC",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfTQC:
        """
        Forward all arguments unchanged to the parent ``learn`` and return its result.

        This override only provides ``"TQC"`` as the default ``tb_log_name`` and a
        return annotation bound to the calling class.
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> List[str]:
        """Return the parent's excluded parameter names plus the actor/critic aliases."""
        # Exclude aliases
        return [*super()._excluded_save_params(), "actor", "critic", "critic_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """
        Return ``(state_dicts, saved_pytorch_variables)``, two lists of attribute names.

        When the entropy coefficient is learned, its optimizer and ``log_ent_coef`` are
        included; otherwise the fixed ``ent_coef_tensor`` is listed instead.
        """
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        if self.ent_coef_optimizer is not None:
            saved_pytorch_variables = ["log_ent_coef"]
            state_dicts.append("ent_coef_optimizer")
        else:
            saved_pytorch_variables = ["ent_coef_tensor"]
        return state_dicts, saved_pytorch_variables