Examples¶
TQC¶
Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
from sb3_contrib import TQC
model = TQC("MlpPolicy", "Pendulum-v1", top_quantiles_to_drop_per_net=2, verbose=1)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("tqc_pendulum")
QR-DQN¶
Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.
from sb3_contrib import QRDQN
policy_kwargs = dict(n_quantiles=50)
model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("qrdqn_cartpole")
MaskablePPO¶
Train a PPO with invalid action masking agent on a toy environment.
from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
model = MaskablePPO("MlpPolicy", env, verbose=1)
model.learn(5000)
model.save("maskable_toy_env")
TRPO¶
Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environment.
from sb3_contrib import TRPO
model = TRPO("MlpPolicy", "Pendulum-v1", gamma=0.9, verbose=1)
model.learn(total_timesteps=100_000, log_interval=4)
model.save("trpo_pendulum")
ARS¶
Train an agent using Augmented Random Search (ARS) agent on the Pendulum environment
from sb3_contrib import ARS
model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("ars_pendulum")
RecurrentPPO¶
Train a PPO agent with a recurrent policy on the CartPole environment.
Note
It is particularly important to pass the lstm_states
and episode_start
argument to the predict()
method,
so the cell and hidden states of the LSTM are correctly updated.
import numpy as np
from sb3_contrib import RecurrentPPO
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(5000)
env = model.get_env()
obs = env.reset()
# cell and hidden state of the LSTM
lstm_states = None
num_envs = 1
# Episode start signals are used to reset the lstm states
episode_starts = np.ones((num_envs,), dtype=bool)
while True:
action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
obs, rewards, dones, info = env.step(action)
episode_starts = dones
env.render()