TRPO¶
Trust Region Policy Optimization (TRPO) is an iterative approach for optimizing policies with guaranteed monotonic improvement.
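At each update, TRPO maximizes a surrogate advantage objective under a constraint on the average KL divergence between the new and old policies (Schulman et al., 2015). In sketch form:

```latex
\max_{\theta}\;
\mathbb{E}_{s,a \sim \pi_{\theta_{\mathrm{old}}}}
\left[ \frac{\pi_{\theta}(a \mid s)}{\pi_{\theta_{\mathrm{old}}}(a \mid s)}\, A^{\pi_{\theta_{\mathrm{old}}}}(s, a) \right]
\quad \text{s.t.} \quad
\mathbb{E}_{s \sim \pi_{\theta_{\mathrm{old}}}}
\left[ D_{\mathrm{KL}}\!\left( \pi_{\theta_{\mathrm{old}}}(\cdot \mid s) \,\Vert\, \pi_{\theta}(\cdot \mid s) \right) \right] \le \delta
```

In the paper, this constrained step is solved approximately with the conjugate gradient method followed by a backtracking line search.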
Available Policies¶

| Policy | Description |
|---|---|
| MlpPolicy | alias of ActorCriticPolicy |
| CnnPolicy | alias of ActorCriticCnnPolicy |
| MultiInputPolicy | alias of MultiInputActorCriticPolicy |
Notes¶
Original paper: https://arxiv.org/abs/1502.05477
OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/
Can I use?¶
Recurrent policies: ❌
Multi processing: ✔️
Gym spaces:
| Space | Action | Observation |
|---|---|---|
| Discrete | ✔️ | ✔️ |
| Box | ✔️ | ✔️ |
| MultiDiscrete | ✔️ | ✔️ |
| MultiBinary | ✔️ | ✔️ |
| Dict | ❌ | ✔️ |
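For instance, discrete action spaces from the table above work out of the box; a minimal sketch (CartPole-v1 chosen arbitrarily as a Discrete-action environment):

```python
import gymnasium as gym

from sb3_contrib import TRPO

# CartPole-v1: Box observation space, Discrete action space
env = gym.make("CartPole-v1")
model = TRPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
```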
Example¶
```python
import gymnasium as gym

from sb3_contrib import TRPO

env = gym.make("Pendulum-v1", render_mode="human")

model = TRPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("trpo_pendulum")

del model  # remove to demonstrate saving and loading

model = TRPO.load("trpo_pendulum")

obs, _ = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated or truncated:
        obs, _ = env.reset()
```
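Since multiprocessing is supported, the same model can also be trained on several environments running in parallel; a minimal sketch using make_vec_env and SubprocVecEnv from stable_baselines3 (4 workers chosen arbitrarily):

```python
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

from sb3_contrib import TRPO

if __name__ == "__main__":
    # 4 Pendulum environments, each running in its own process
    vec_env = make_vec_env("Pendulum-v1", n_envs=4, vec_env_cls=SubprocVecEnv)
    model = TRPO("MlpPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=10_000)
```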
Results¶
Results on the MuJoCo benchmark (1M steps on -v3 envs with MuJoCo v2.1.0) using 3 seeds.
The complete learning curves are available in the associated PR.
| Environments | TRPO |
|---|---|
| HalfCheetah | 1803 +/- 46 |
| Ant | 3554 +/- 591 |
| Hopper | 3372 +/- 215 |
| Walker2d | 4502 +/- 234 |
| Swimmer | 359 +/- 2 |
How to replicate the results?¶
Clone RL-Zoo and checkout the branch feat/trpo:

```
git clone https://github.com/cyprienc/rl-baselines3-zoo
cd rl-baselines3-zoo/
git checkout feat/trpo
```
Run the benchmark (replace $ENV_ID by the envs mentioned above):

```
python train.py --algo trpo --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000
```
Plot the results:

```
python scripts/all_plots.py -a trpo -e HalfCheetah Ant Hopper Walker2d Swimmer -f logs/ -o logs/trpo_results
python scripts/plot_from_file.py -i logs/trpo_results.pkl -latex -l TRPO
```
Parameters¶
TRPO Policies¶
- class stable_baselines3.common.policies.ActorCriticPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)[source]
Policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the likes.
- Parameters:
  - observation_space (Space) – Observation space
  - action_space (Space) – Action space
  - lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
  - net_arch (Union[List[int], Dict[str, List[int]], None]) – The specification of the policy and value networks.
  - activation_fn (Type[Module]) – Activation function
  - ortho_init (bool) – Whether or not to use orthogonal initialization
  - use_sde (bool) – Whether to use State Dependent Exploration or not
  - log_std_init (float) – Initial value for the log standard deviation
  - full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE
  - use_expln (bool) – Use expln() instead of exp() to ensure a positive standard deviation (cf paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
  - squash_output (bool) – Whether to squash the output using a tanh function; this ensures action boundaries when using gSDE.
  - features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
  - features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.
  - share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.
  - normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
  - optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
  - optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
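In practice these arguments are rarely passed directly; with TRPO they are usually forwarded through policy_kwargs. A minimal sketch (network sizes and activation chosen arbitrarily for illustration):

```python
import torch as th

from sb3_contrib import TRPO

# Separate 64-64 networks for the policy (pi) and the value function (vf),
# using ReLU instead of the default Tanh activation
policy_kwargs = dict(
    net_arch=dict(pi=[64, 64], vf=[64, 64]),
    activation_fn=th.nn.ReLU,
)

model = TRPO("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10_000)
```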
- evaluate_actions(obs, actions)[source]
Evaluate actions according to the current policy, given the observations.
- Parameters:
  - obs (Tensor) – Observation
  - actions (Tensor) – Actions
- Return type:
  Tuple[Tensor, Tensor, Optional[Tensor]]
- Returns:
estimated value, log likelihood of taking those actions and entropy of the action distribution.
- extract_features(obs)[source]
Preprocess the observation if needed and extract features.
- Parameters:
  - obs (Tensor) – Observation
- Return type:
  Union[Tensor, Tuple[Tensor, Tensor]]
- Returns:
the output of the features extractor(s)
- forward(obs, deterministic=False)[source]
Forward pass in all the networks (actor and critic)
- Parameters:
  - obs (Tensor) – Observation
  - deterministic (bool) – Whether to sample or use deterministic actions
- Return type:
  Tuple[Tensor, Tensor, Tensor]
- Returns:
action, value and log probability of the action
- get_distribution(obs)[source]
Get the current policy distribution given the observations.
- Parameters:
  - obs (Tensor) –
- Return type:
  Distribution
- Returns:
the action distribution.
- predict_values(obs)[source]
Get the estimated values according to the current policy given the observations.
- Parameters:
  - obs (Tensor) – Observation
- Return type:
  Tensor
- Returns:
the estimated values.
- reset_noise(n_envs=1)[source]
Sample new weights for the exploration matrix.
- Parameters:
  - n_envs (int) –
- Return type:
  None
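A small sketch of how the methods above can be queried on a trained model (obs_to_tensor converts a single NumPy observation into a batched tensor on the policy's device; the Pendulum model is reused purely for illustration):

```python
import gymnasium as gym
import torch as th

from sb3_contrib import TRPO

env = gym.make("Pendulum-v1")
model = TRPO("MlpPolicy", env)

obs, _ = env.reset()
obs_tensor, _ = model.policy.obs_to_tensor(obs)

with th.no_grad():
    # Value estimate for the current observation
    values = model.policy.predict_values(obs_tensor)
    # Action distribution for the current observation
    dist = model.policy.get_distribution(obs_tensor)
    action = dist.get_actions(deterministic=True)

print(values.shape, action.shape)
```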
- class stable_baselines3.common.policies.ActorCriticCnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)[source]
CNN policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the likes.
- Parameters:
  - observation_space (Space) – Observation space
  - action_space (Space) – Action space
  - lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
  - net_arch (Union[List[int], Dict[str, List[int]], None]) – The specification of the policy and value networks.
  - activation_fn (Type[Module]) – Activation function
  - ortho_init (bool) – Whether or not to use orthogonal initialization
  - use_sde (bool) – Whether to use State Dependent Exploration or not
  - log_std_init (float) – Initial value for the log standard deviation
  - full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE
  - use_expln (bool) – Use expln() instead of exp() to ensure a positive standard deviation (cf paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
  - squash_output (bool) – Whether to squash the output using a tanh function; this ensures action boundaries when using gSDE.
  - features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
  - features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.
  - share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.
  - normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
  - optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
  - optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
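The CnnPolicy alias maps to this class and expects image observations; a minimal sketch (CarRacing is just one example of an image-observation environment, requires the gymnasium box2d extra, and the exact env id may differ across gymnasium versions):

```python
import gymnasium as gym

from sb3_contrib import TRPO

# CarRacing provides (96, 96, 3) uint8 image observations and a Box action space
env = gym.make("CarRacing-v2")
model = TRPO("CnnPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
```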
- class stable_baselines3.common.policies.MultiInputActorCriticPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.CombinedExtractor'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)[source]
Multi-input policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the likes.
- Parameters:
  - observation_space (Dict) – Observation space
  - action_space (Space) – Action space
  - lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
  - net_arch (Union[List[int], Dict[str, List[int]], None]) – The specification of the policy and value networks.
  - activation_fn (Type[Module]) – Activation function
  - ortho_init (bool) – Whether or not to use orthogonal initialization
  - use_sde (bool) – Whether to use State Dependent Exploration or not
  - log_std_init (float) – Initial value for the log standard deviation
  - full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE
  - use_expln (bool) – Use expln() instead of exp() to ensure a positive standard deviation (cf paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
  - squash_output (bool) – Whether to squash the output using a tanh function; this ensures action boundaries when using gSDE.
  - features_extractor_class (Type[BaseFeaturesExtractor]) – Uses the CombinedExtractor
  - features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.
  - share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.
  - normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
  - optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
  - optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
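Per the spaces table above, Dict observation spaces are handled through the MultiInputPolicy alias, which maps to this class. A minimal sketch with a toy Dict-observation environment (the custom env below is purely illustrative):

```python
import gymnasium as gym
import numpy as np
from gymnasium import spaces

from sb3_contrib import TRPO


class ToyDictObsEnv(gym.Env):
    """Toy environment with a Dict observation space and a Box action space."""

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Dict(
            {
                "position": spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32),
                "velocity": spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32),
            }
        )
        self.action_space = spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self._steps = 0
        return self.observation_space.sample(), {}

    def step(self, action):
        self._steps += 1
        obs = self.observation_space.sample()
        reward = -float(np.abs(action).sum())
        terminated = False
        truncated = self._steps >= 100
        return obs, reward, terminated, truncated, {}


model = TRPO("MultiInputPolicy", ToyDictObsEnv(), verbose=1)
model.learn(total_timesteps=1_000)
```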