TRPO

Trust Region Policy Optimization (TRPO) is an iterative approach for optimizing policies with guaranteed monotonic improvement.
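
For reference, each policy update maximizes a surrogate advantage objective under a KL-divergence trust-region constraint (a standard statement of the TRPO step, following the original paper; $\delta$ is the trust-region size and $A$ the advantage function):

\max_{\theta} \;\; \mathbb{E}_{s, a \sim \pi_{\theta_{\mathrm{old}}}} \left[ \frac{\pi_{\theta}(a \mid s)}{\pi_{\theta_{\mathrm{old}}}(a \mid s)} \, A^{\pi_{\theta_{\mathrm{old}}}}(s, a) \right]
\quad \text{subject to} \quad
\mathbb{E}_{s \sim \pi_{\theta_{\mathrm{old}}}} \left[ D_{\mathrm{KL}}\!\left( \pi_{\theta_{\mathrm{old}}}(\cdot \mid s) \,\|\, \pi_{\theta}(\cdot \mid s) \right) \right] \le \delta

In practice, the constrained step is approximated with a conjugate-gradient estimate of the natural gradient direction, followed by a backtracking line search.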

Available Policies

MlpPolicy           alias of ActorCriticPolicy
CnnPolicy           alias of ActorCriticCnnPolicy
MultiInputPolicy    alias of MultiInputActorCriticPolicy

Notes

  • Original paper: https://arxiv.org/abs/1502.05477

Can I use?

  • Recurrent policies: ❌

  • Multi processing: ✔️

  • Gym spaces:

Space          Action   Observation

Discrete       ✔️        ✔️
Box            ✔️        ✔️
MultiDiscrete  ✔️        ✔️
MultiBinary    ✔️        ✔️
Dict           ❌        ✔️

Example

import gym

from sb3_contrib import TRPO

env = gym.make("Pendulum-v1")

model = TRPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("trpo_pendulum")

del model # remove to demonstrate saving and loading

model = TRPO.load("trpo_pendulum")

obs = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()
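
Since multiprocessing is supported, training can also be run on several environment copies in parallel. Below is a minimal sketch using SB3's make_vec_env helper with SubprocVecEnv; the number of environments and the timestep budget are illustrative, not tuned:

from sb3_contrib import TRPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":
    # SubprocVecEnv spawns worker processes, so guard the script entry point
    vec_env = make_vec_env("Pendulum-v1", n_envs=4, vec_env_cls=SubprocVecEnv)

    model = TRPO("MlpPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=10000)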

Results

Results on the MuJoCo benchmark (1M steps on -v3 envs with MuJoCo v2.1.0), averaged over 3 seeds. The complete learning curves are available in the associated PR.

Environments   TRPO

HalfCheetah    1803 +/- 46
Ant            3554 +/- 591
Hopper         3372 +/- 215
Walker2d       4502 +/- 234
Swimmer        359 +/- 2

How to replicate the results?

Clone RL-Zoo and check out the feat/trpo branch:

git clone https://github.com/cyprienc/rl-baselines3-zoo
cd rl-baselines3-zoo/
git checkout feat/trpo

Run the benchmark (replace $ENV_ID with each of the environments listed above):

python train.py --algo trpo --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000

Plot the results:

python scripts/all_plots.py -a trpo -e HalfCheetah Ant Hopper Walker2d Swimmer -f logs/ -o logs/trpo_results
python scripts/plot_from_file.py -i logs/trpo_results.pkl -latex -l TRPO

Parameters

TRPO Policies

class stable_baselines3.common.policies.ActorCriticPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)[source]

Policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the like.

Parameters:
  • observation_space (Space) – Observation space

  • action_space (Space) – Action space

  • lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)

  • net_arch (Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None]) – The specification of the policy and value networks.

  • activation_fn (Type[Module]) – Activation function

  • ortho_init (bool) – Whether or not to use orthogonal initialization

  • use_sde (bool) – Whether to use State Dependent Exploration or not

  • log_std_init (float) – Initial value for the log standard deviation

  • full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE

  • use_expln (bool) – Use the expln() function instead of exp() to ensure a positive standard deviation (cf. paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.

  • squash_output (bool) – Whether to squash the output using a tanh function; this ensures bounded actions when using gSDE.

  • features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.

  • features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.

  • share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.

  • normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)

  • optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default

  • optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer

evaluate_actions(obs, actions)[source]

Evaluate actions according to the current policy, given the observations.

Parameters:
  • obs (Tensor) – Observation

  • actions (Tensor) – Actions

Return type:

Tuple[Tensor, Tensor, Optional[Tensor]]

Returns:

estimated value, log likelihood of taking those actions and entropy of the action distribution.

extract_features(obs)[source]

Preprocess the observation if needed and extract features.

Parameters:

obs (Tensor) – Observation

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

the output of the features extractor(s)

forward(obs, deterministic=False)[source]

Forward pass in all the networks (actor and critic)

Parameters:
  • obs (Tensor) – Observation

  • deterministic (bool) – Whether to sample or use deterministic actions

Return type:

Tuple[Tensor, Tensor, Tensor]

Returns:

action, value and log probability of the action

get_distribution(obs)[source]

Get the current policy distribution given the observations.

Parameters:

obs (Tensor) – Observation

Return type:

Distribution

Returns:

the action distribution.

predict_values(obs)[source]

Get the estimated values according to the current policy given the observations.

Parameters:

obs (Tensor) – Observation

Return type:

Tensor

Returns:

the estimated values.

reset_noise(n_envs=1)[source]

Sample new weights for the exploration matrix.

Parameters:

n_envs (int) – Number of environments

Return type:

None
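
As an illustration of the methods above, the following sketch queries a trained model's policy directly. It assumes the Pendulum setup from the example section and uses the SB3 obs_to_tensor helper to turn a NumPy observation into a batched tensor:

import gym
import torch as th

from sb3_contrib import TRPO

env = gym.make("Pendulum-v1")
model = TRPO("MlpPolicy", env)
policy = model.policy  # an ActorCriticPolicy instance

obs = env.reset()
obs_tensor, _ = policy.obs_to_tensor(obs)  # NumPy observation -> batched tensor on the policy device

with th.no_grad():
    actions, values, log_probs = policy(obs_tensor)   # forward(): action, value and log-probability
    dist = policy.get_distribution(obs_tensor)        # current action distribution
    value_only = policy.predict_values(obs_tensor)    # critic estimate only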

class stable_baselines3.common.policies.ActorCriticCnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)[source]

CNN policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the like.

Parameters:
  • observation_space (Space) – Observation space

  • action_space (Space) – Action space

  • lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)

  • net_arch (Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None]) – The specification of the policy and value networks.

  • activation_fn (Type[Module]) – Activation function

  • ortho_init (bool) – Whether or not to use orthogonal initialization

  • use_sde (bool) – Whether to use State Dependent Exploration or not

  • log_std_init (float) – Initial value for the log standard deviation

  • full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE

  • use_expln (bool) – Use the expln() function instead of exp() to ensure a positive standard deviation (cf. paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.

  • squash_output (bool) – Whether to squash the output using a tanh function; this ensures bounded actions when using gSDE.

  • features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.

  • features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.

  • share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.

  • normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)

  • optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default

  • optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
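
CnnPolicy is intended for image observations. A minimal sketch using an Atari environment (this assumes the Atari extras and ROMs are installed; the environment ID and timestep budget are purely illustrative):

from sb3_contrib import TRPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# Image-based environment with standard Atari preprocessing and frame stacking
env = make_atari_env("PongNoFrameskip-v4", n_envs=1)
env = VecFrameStack(env, n_stack=4)

model = TRPO("CnnPolicy", env, verbose=1)
model.learn(total_timesteps=100000)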

class stable_baselines3.common.policies.MultiInputActorCriticPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.CombinedExtractor'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)[source]

Multi-input policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the like.

Parameters:
  • observation_space (Dict) – Observation space

  • action_space (Space) – Action space

  • lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)

  • net_arch (Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None]) – The specification of the policy and value networks.

  • activation_fn (Type[Module]) – Activation function

  • ortho_init (bool) – Whether or not to use orthogonal initialization

  • use_sde (bool) – Whether to use State Dependent Exploration or not

  • log_std_init (float) – Initial value for the log standard deviation

  • full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE

  • use_expln (bool) – Use the expln() function instead of exp() to ensure a positive standard deviation (cf. paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.

  • squash_output (bool) – Whether to squash the output using a tanh function; this ensures bounded actions when using gSDE.

  • features_extractor_class (Type[BaseFeaturesExtractor]) – Uses the CombinedExtractor

  • features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.

  • share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.

  • normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)

  • optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default

  • optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
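
MultiInputPolicy is the alias to use with Dict observation spaces (see the table at the top of this page). A minimal sketch, assuming SB3's SimpleMultiObsEnv toy environment (which returns a Dict observation with an image and a vector part) is available:

from sb3_contrib import TRPO
from stable_baselines3.common.envs import SimpleMultiObsEnv

# Toy environment with a Dict observation space
env = SimpleMultiObsEnv(random_start=False)

model = TRPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=10000)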