Source code for sb3_contrib.ppo_mask.ppo_mask

import sys
import time
import warnings
from collections import deque
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
from torch.nn import functional as F

from sb3_contrib.common.maskable.buffers import MaskableDictRolloutBuffer, MaskableRolloutBuffer
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy

MaskablePPOSelf = TypeVar("MaskablePPOSelf", bound="MaskablePPO")


class MaskablePPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization algorithm (PPO) (clip version) with Invalid Action Masking.

    Based on the original Stable Baselines 3 implementation.

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
    Background on Invalid Action Masking: https://arxiv.org/abs/2006.14171

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param batch_size: Minibatch size
    :param n_epochs: Number of epochs when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current
        progress remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether or not to normalize the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large updates
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically (Only available when passing string for the environment).
        Caution, this parameter is deprecated and will be removed in the future.
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[MaskableActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: Optional[int] = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=False,
            sde_sample_freq=-1,
            tensorboard_log=tensorboard_log,
            create_eval_env=create_eval_env,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        buffer_cls = (
            MaskableDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else MaskableRolloutBuffer
        )

        self.policy = self.policy_class(
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            **self.policy_kwargs,  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

        if not isinstance(self.policy, MaskableActorCriticPolicy):
            raise ValueError("Policy must subclass MaskableActorCriticPolicy")

        self.rollout_buffer = buffer_cls(
            self.n_steps,
            self.observation_space,
            self.action_space,
            self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
        )

        # Initialize schedules for policy/value clipping
        self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def _init_callback(
        self,
        callback: MaybeCallback,
        eval_env: Optional[VecEnv] = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
        use_masking: bool = True,
        progress_bar: bool = False,
    ) -> BaseCallback:
        """
        :param callback: Callback(s) called at every step with state of the algorithm.
        :param eval_env: Environment to use for evaluation.
            Caution, this parameter is deprecated and will be removed in the future.
            Please use `MaskableEvalCallback` or a custom Callback instead.
        :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
            Caution, this parameter is deprecated and will be removed in the future.
            Please use `MaskableEvalCallback` or a custom Callback instead.
        :param n_eval_episodes: How many episodes to play per evaluation
        :param log_path: Path to a folder where the evaluations will be saved
        :param use_masking: Whether or not to use invalid action masks during evaluation
        :param progress_bar: Display a progress bar using tqdm and rich.
        :return: A hybrid callback calling `callback` and performing evaluation.
        """
        # Convert a list of callbacks into a callback
        if isinstance(callback, list):
            callback = CallbackList(callback)

        # Convert functional callback to object
        if not isinstance(callback, BaseCallback):
            callback = ConvertCallback(callback)

        # Add progress bar callback
        if progress_bar:
            callback = CallbackList([callback, ProgressBarCallback()])

        # Create eval callback in charge of the evaluation
        if eval_env is not None:
            # Avoid circular import error
            from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback

            eval_callback = MaskableEvalCallback(
                eval_env,
                best_model_save_path=log_path,
                log_path=log_path,
                eval_freq=eval_freq,
                n_eval_episodes=n_eval_episodes,
                use_masking=use_masking,
                verbose=self.verbose,
            )
            callback = CallbackList([callback, eval_callback])

        callback.init_callback(self)
        return callback

    def _setup_learn(
        self,
        total_timesteps: int,
        eval_env: Optional[GymEnv],
        callback: MaybeCallback = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        tb_log_name: str = "run",
        use_masking: bool = True,
        progress_bar: bool = False,
    ) -> Tuple[int, BaseCallback]:
        """
        Initialize different variables needed for training.

        :param total_timesteps: The total number of samples (env steps) to train on
        :param eval_env: Environment to use for evaluation.
            Caution, this parameter is deprecated and will be removed in the future.
            Please use `MaskableEvalCallback` or a custom Callback instead.
        :param callback: Callback(s) called at every step with state of the algorithm.
        :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
            Caution, this parameter is deprecated and will be removed in the future.
            Please use `MaskableEvalCallback` or a custom Callback instead.
        :param n_eval_episodes: How many episodes to play per evaluation
        :param log_path: Path to a folder where the evaluations will be saved
        :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
        :param tb_log_name: the name of the run for tensorboard log
        :param use_masking: Whether or not to use invalid action masks during training
        :param progress_bar: Display a progress bar using tqdm and rich.
        :return:
        """
        if eval_env is not None or eval_freq != -1:
            warnings.warn(
                "Parameters `eval_env` and `eval_freq` are deprecated and will be removed in the future. "
                "Please use `MaskableEvalCallback` or a custom Callback instead.",
                DeprecationWarning,
                # By setting the `stacklevel` we refer to the initial caller of the deprecated feature.
                # This causes the `DeprecationWarning` not to be ignored and to be shown to the user. See
                # https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details.
                stacklevel=4,
            )

        self.start_time = time.time_ns()

        if self.ep_info_buffer is None or reset_num_timesteps:
            # Initialize buffers if they don't exist, or reinitialize if resetting counters
            self.ep_info_buffer = deque(maxlen=100)
            self.ep_success_buffer = deque(maxlen=100)

        if reset_num_timesteps:
            self.num_timesteps = 0
            self._episode_num = 0
        else:
            # Make sure training timesteps are ahead of the internal counter
            total_timesteps += self.num_timesteps
        self._total_timesteps = total_timesteps
        self._num_timesteps_at_start = self.num_timesteps

        # Avoid resetting the environment when calling ``.learn()`` consecutive times
        if reset_num_timesteps or self._last_obs is None:
            self._last_obs = self.env.reset()
            self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
            # Retrieve unnormalized observation for saving into the buffer
            if self._vec_normalize_env is not None:
                self._last_original_obs = self._vec_normalize_env.get_original_obs()

        if eval_env is not None and self.seed is not None:
            eval_env.seed(self.seed)

        eval_env = self._get_eval_env(eval_env)

        # Configure logger's outputs if no logger was passed
        if not self._custom_logger:
            self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)

        # Create eval callback if needed
        callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking, progress_bar)

        return total_timesteps, callback
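
    # Illustrative sketch (not part of the original source): rather than passing the
    # deprecated ``eval_env``/``eval_freq`` arguments, periodic evaluation can be set up
    # with a ``MaskableEvalCallback`` (imported lazily above) and handed to ``learn()``.
    # The paths and frequencies below are placeholder values.
    #
    #     from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
    #
    #     eval_callback = MaskableEvalCallback(
    #         eval_env,
    #         eval_freq=10_000,
    #         n_eval_episodes=5,
    #         best_model_save_path="./logs/",
    #         log_path="./logs/",
    #         use_masking=True,
    #     )
    #     model.learn(total_timesteps=100_000, callback=eval_callback)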

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
        use_masking: bool = True,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        This method is largely identical to the implementation found in the parent class.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :param use_masking: Whether or not to use invalid action masks during training
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        """

        assert isinstance(
            rollout_buffer, (MaskableRolloutBuffer, MaskableDictRolloutBuffer)
        ), "RolloutBuffer doesn't support action masking"
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        action_masks = None
        rollout_buffer.reset()

        if use_masking and not is_masking_supported(env):
            raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)

                # This is the only change related to invalid action masking
                if use_masking:
                    action_masks = get_action_masks(env)

                actions, values, log_probs = self.policy(obs_tensor, action_masks=action_masks)

            actions = actions.cpu().numpy()
            new_obs, rewards, dones, infos = env.step(actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos)
            n_steps += 1

            if isinstance(self.action_space, spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstrapping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(terminal_obs)[0]
                    rewards[idx] += self.gamma * terminal_value

            rollout_buffer.add(
                self._last_obs,
                actions,
                rewards,
                self._last_episode_starts,
                values,
                log_probs,
                action_masks=action_masks,
            )
            self._last_obs = new_obs
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            # Masking is not needed here, the choice of action doesn't matter.
            # We only want the value of the current observation.
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.on_rollout_end()

        return True
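
    # Illustrative sketch (not part of the original source): ``is_masking_supported`` and
    # ``get_action_masks`` expect the (vectorized) environment to expose action masks,
    # e.g. via the ``ActionMasker`` wrapper. ``valid_action_mask`` below is a hypothetical
    # helper on the user's env; any function returning a boolean array of shape
    # (n_actions,) works.
    #
    #     import numpy as np
    #     from sb3_contrib.common.wrappers import ActionMasker
    #
    #     def mask_fn(env) -> np.ndarray:
    #         # True = action allowed in the current state
    #         return env.valid_action_mask()
    #
    #     env = ActionMasker(env, mask_fn)  # collect_rollouts can now query masks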

    def predict(
        self,
        observation: np.ndarray,
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
        action_masks: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :param action_masks: Action masks to apply to the action distribution
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        return self.policy.predict(observation, state, episode_start, deterministic, action_masks=action_masks)
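
    # Illustrative sketch (not part of the original source): at inference time the mask
    # for the current state is typically fetched with ``get_action_masks`` (imported at
    # the top of this module) and passed explicitly to ``predict``.
    #
    #     action_masks = get_action_masks(env)
    #     action, _ = model.predict(obs, action_masks=action_masks, deterministic=True)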

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                values, log_prob, entropy = self.policy.evaluate_actions(
                    rollout_data.observations,
                    actions,
                    action_masks=rollout_data.action_masks,
                )

                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                if self.normalize_advantage:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss to favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            if not continue_training:
                break

        self._n_updates += self.n_epochs
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)
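
    # For reference (not part of the original source), the clipped surrogate objective
    # maximized in ``train`` above is
    #
    #     L^CLIP(theta) = E_t[ min( r_t(theta) * A_t,
    #                               clip(r_t(theta), 1 - clip_range, 1 + clip_range) * A_t ) ]
    #
    # where r_t(theta) = exp(log_prob - old_log_prob) is the probability ratio and A_t the
    # (optionally normalized) advantage estimate; ``policy_loss`` is the negated sample
    # mean of this quantity, so gradient descent on ``loss`` ascends the objective.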

    def learn(
        self: MaskablePPOSelf,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "PPO",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        use_masking: bool = True,
        progress_bar: bool = False,
    ) -> MaskablePPOSelf:
        iteration = 0

        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            eval_log_path,
            reset_num_timesteps,
            tb_log_name,
            use_masking,
            progress_bar,
        )

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)

            if continue_training is False:
                break

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
                fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                self.logger.record("time/iterations", iteration, exclude="tensorboard")
                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                self.logger.record("time/fps", fps)
                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
                self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                self.logger.dump(step=self.num_timesteps)

            self.train()

        callback.on_training_end()

        return self
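
# Illustrative end-to-end usage sketch (not part of the original module), closely
# following the sb3-contrib documentation example. ``InvalidActionEnvDiscrete`` is a
# toy environment shipped with sb3-contrib; any env exposing action masks works.
#
#     from sb3_contrib import MaskablePPO
#     from sb3_contrib.common.envs import InvalidActionEnvDiscrete
#     from sb3_contrib.common.maskable.utils import get_action_masks
#
#     env = InvalidActionEnvDiscrete(dim=80, n_dim=20)
#     model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, verbose=1)
#     model.learn(5_000)
#
#     obs = env.reset()
#     done = False
#     while not done:
#         action_masks = get_action_masks(env)
#         action, _ = model.predict(obs, action_masks=action_masks)
#         obs, reward, done, info = env.step(action)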