from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import gym
import numpy as np
import torch as th
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update

from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy

# Type variable bound to QRDQN. It annotates ``self`` and the return type of
# ``QRDQN.learn`` so that calling ``learn`` on a subclass is typed as returning that subclass.
QRDQNSelf = TypeVar("QRDQNSelf", bound="QRDQN")

class QRDQN(OffPolicyAlgorithm):
    """
    Quantile Regression Deep Q-Network (QR-DQN)
    Paper: https://arxiv.org/abs/1710.10044
    Default hyperparameters are taken from the paper and are tuned for Atari games.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps. When several environments are used, each call to ``env.step()``
        counts as ``n_envs`` environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically (Only available when passing string for the environment).
        Caution, this parameter is deprecated and will be removed in the future.
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[QRDQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 5e-5,
        buffer_size: int = 1000000,  # 1e6
        learning_starts: int = 50000,
        batch_size: Optional[int] = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.005,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.01,
        max_grad_norm: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=None,  # No action noise
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            sde_support=False,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(gym.spaces.Discrete,),
            support_multi_env=True,
        )

        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction
        self.target_update_interval = target_update_interval
        self.max_grad_norm = max_grad_norm
        # "epsilon" for the epsilon-greedy exploration
        self.exploration_rate = 0.0
        # Linear schedule will be defined in `_setup_model()`
        self.exploration_schedule = None
        self.quantile_net, self.quantile_net_target = None, None

        if "optimizer_class" not in self.policy_kwargs:
            self.policy_kwargs["optimizer_class"] = th.optim.Adam
            # Proposed in the QR-DQN paper where `batch_size = 32`
            self.policy_kwargs["optimizer_kwargs"] = dict(eps=0.01 / batch_size)

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Build the policy/buffer (via the parent), create aliases and the exploration schedule."""
        super()._setup_model()
        self._create_aliases()
        # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
        self.batch_norm_stats = get_parameters_by_name(self.quantile_net, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(self.quantile_net_target, ["running_"])
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
        )

    def _create_aliases(self) -> None:
        """Expose the policy's online/target quantile networks and quantile count on the algorithm."""
        self.quantile_net = self.policy.quantile_net
        self.quantile_net_target = self.policy.quantile_net_target
        self.n_quantiles = self.policy.n_quantiles

    def _on_step(self) -> None:
        """
        Update the exploration rate and target network if needed.
        This method is called in ``collect_rollouts()`` after each step in the environment.
        """
        # ``num_timesteps`` has already been advanced by ``n_envs`` for this step.
        # With several envs, a plain ``num_timesteps % target_update_interval == 0``
        # test can jump over the multiple and delay the update by up to a factor
        # of ``n_envs``. Instead, update whenever a multiple of the interval was
        # crossed during this step. For a single env this is identical to the
        # modulo test.
        interval = self.target_update_interval
        previous_timesteps = self.num_timesteps - self.n_envs
        if self.num_timesteps // interval > previous_timesteps // interval:
            polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
            # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
        self.logger.record("rollout/exploration_rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Run ``gradient_steps`` updates of the quantile network on replay-buffer samples.

        :param gradient_steps: number of gradient updates to perform
        :param batch_size: number of transitions sampled per update
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with th.no_grad():
                # Compute the quantiles of next observation.
                # Presumably shaped (batch, n_quantiles, n_actions), given the reductions below.
                next_quantiles = self.quantile_net_target(replay_data.next_observations)
                # Compute the greedy actions which maximize the next Q values
                next_greedy_actions = next_quantiles.mean(dim=1, keepdim=True).argmax(dim=2, keepdim=True)
                # Make "n_quantiles" copies of actions, shape (batch, n_quantiles, 1).
                # -1 keeps the sampled batch dimension instead of hard-coding batch_size.
                next_greedy_actions = next_greedy_actions.expand(-1, self.n_quantiles, 1)
                # Follow greedy policy: use the one with the highest Q values
                next_quantiles = next_quantiles.gather(dim=2, index=next_greedy_actions).squeeze(dim=2)
                # 1-step TD target
                target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles

            # Get current quantile estimates
            current_quantiles = self.quantile_net(replay_data.observations)

            # Make "n_quantiles" copies of actions, shape (batch, n_quantiles, 1).
            actions = replay_data.actions[..., None].long().expand(-1, self.n_quantiles, 1)
            # Retrieve the quantiles for the actions from the replay buffer
            current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2)

            # Compute Quantile Huber loss, summing over a quantile dimension as in the paper.
            loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            if self.max_grad_norm is not None:
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))

    def predict(
        self,
        observation: np.ndarray,
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        When ``deterministic`` is False, a uniformly random action is returned with
        probability ``self.exploration_rate`` (epsilon-greedy).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
            this correspond to beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        if not deterministic and np.random.rand() < self.exploration_rate:
            if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
                if isinstance(self.observation_space, gym.spaces.Dict):
                    # Batch size is read from the first entry of the dict observation
                    n_batch = observation[list(observation.keys())[0]].shape[0]
                else:
                    n_batch = observation.shape[0]
                action = np.array([self.action_space.sample() for _ in range(n_batch)])
            else:
                action = np.array(self.action_space.sample())
        else:
            action, state = self.policy.predict(observation, state, episode_start, deterministic)
        return action, state

    def learn(
        self: QRDQNSelf,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "QRDQN",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> QRDQNSelf:
        """Train the agent; forwards every argument to ``OffPolicyAlgorithm.learn`` and returns ``self``'s result."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> List[str]:
        # The aliased networks are saved as part of the policy state dict.
        return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        state_dicts = ["policy", "policy.optimizer"]

        return state_dicts, []