ARS¶
Augmented Random Search (ARS) is a simple reinforcement learning algorithm that performs a direct random search over policy parameters. It can be surprisingly effective compared to more sophisticated algorithms. In the original paper, the authors showed that linear policies trained with ARS were competitive with deep reinforcement learning on the MuJoCo locomotion tasks.
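To make the idea of a random search over policy parameters concrete, here is a minimal pure-Python sketch of one update in the style of the basic variant from the original paper (no observation normalization). It is an illustration under stated assumptions, not SB3's implementation: `ars_step` and `toy_return` are hypothetical names, the policy is reduced to a flat list of weights, and the argument names merely reuse those of the constructor documented under Parameters.

```python
import random
import statistics


def ars_step(weights, evaluate, n_delta=8, n_top=None,
             learning_rate=0.02, delta_std=0.05, rng=random):
    """One ARS-style update of a flat weight list (hypothetical helper)."""
    n_top = n_delta if n_top is None else n_top
    scored = []
    for _ in range(n_delta):
        # Random search direction: one Gaussian sample per parameter.
        delta = [rng.gauss(0.0, 1.0) for _ in weights]
        plus = [w + delta_std * d for w, d in zip(weights, delta)]
        minus = [w - delta_std * d for w, d in zip(weights, delta)]
        scored.append((evaluate(plus), evaluate(minus), delta))
    # Keep the n_top directions whose better side scored highest.
    scored.sort(key=lambda item: max(item[0], item[1]), reverse=True)
    top = scored[:n_top]
    # Scale the step by the spread of the returns that are actually used.
    used_returns = [r for r_plus, r_minus, _ in top for r in (r_plus, r_minus)]
    return_std = statistics.pstdev(used_returns) or 1.0
    step = learning_rate / (n_top * return_std)
    new_weights = list(weights)
    for r_plus, r_minus, delta in top:
        for i, d in enumerate(delta):
            new_weights[i] += step * (r_plus - r_minus) * d
    return new_weights


def toy_return(weights, target=(1.0, -2.0)):
    """Hypothetical stand-in for an episode return: higher is better."""
    return -sum((w - t) ** 2 for w, t in zip(weights, target))


rng = random.Random(0)
weights = [0.0, 0.0]
initial = toy_return(weights)
for _ in range(200):
    weights = ars_step(weights, toy_return, rng=rng)
final = toy_return(weights)
```

In a real task, `evaluate` would run one or more episodes with the perturbed policy and return the episodic return; the toy objective only stands in for that rollout.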
SB3's implementation allows for linear policies without bias or squashing function. It also allows for training MLP policies, which include linear policies with bias and squashing functions as a special case.
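The difference between the two kinds of policy can be sketched in a few lines of plain Python. This is only an illustration of the maths, assuming a single layer and a continuous action; `linear_action` and `affine_tanh_action` are hypothetical names, not functions of the library.

```python
import math


def linear_action(weights, obs):
    """Linear policy without bias or squashing: action = W @ obs."""
    return [sum(w * o for w, o in zip(row, obs)) for row in weights]


def affine_tanh_action(weights, bias, obs):
    """Single-layer policy with an additive bias and tanh squashing."""
    pre_activation = linear_action(weights, obs)
    return [math.tanh(p + b) for p, b in zip(pre_activation, bias)]


weights = [[0.5, -1.0, 2.0]]  # one action, three observation features
obs = [1.0, 2.0, 3.0]
raw = linear_action(weights, obs)
squashed = affine_tanh_action(weights, [0.1], obs)
```

The first function is unbounded in its output, while the second always stays strictly inside (-1, 1) because of the tanh.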
Normally, one wants to train ARS with several seeds to evaluate it properly.
Warning
ARS multi-processing is different from the classic Stable-Baselines3 multi-processing: it runs n environments in parallel but asynchronously. This asynchronous multi-processing is considered experimental and does not fully support callbacks: the on_step() event is called artificially after the evaluation episodes are over.
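The consequence for callbacks can be pictured with a toy model. The sketch below is not SB3's code: it uses threads instead of processes, `run_episode` and `CountingCallback` are hypothetical stand-ins, and it does not try to reproduce how many times the real implementation fires `on_step()`. It only shows the ordering: nothing is reported while the workers are running, and the step events are replayed once all episodes are over.

```python
from concurrent.futures import ThreadPoolExecutor


def run_episode(candidate_id):
    """Hypothetical rollout: returns (episode_return, episode_length)."""
    return float(candidate_id), 10 + candidate_id


class CountingCallback:
    """Stand-in for a callback; it only counts on_step() calls."""

    def __init__(self):
        self.n_calls = 0

    def on_step(self):
        self.n_calls += 1
        return True


def evaluate_async(candidate_ids, callback, n_workers=2):
    # Workers run whole episodes; no callback fires while they are busy.
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        results = list(pool.map(run_episode, candidate_ids))
    # Only now, with every episode finished, are the step events replayed.
    total_steps = sum(length for _, length in results)
    for _ in range(total_steps):
        callback.on_step()
    return [episode_return for episode_return, _ in results]


callback = CountingCallback()
returns = evaluate_async([0, 1, 2, 3], callback)
```

A callback that needs to react inside an episode (for example, to stop it early) therefore cannot rely on `on_step()` in this mode.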
Available Policies
LinearPolicy | alias of ARSLinearPolicy |
MlpPolicy | alias of ARSPolicy |
Notes¶
Original paper: https://arxiv.org/abs/1803.07055
Original Implementation: https://github.com/modestyachts/ARS
Can I use?¶
Recurrent policies: ❌
Multi processing: ✔️ (cf. example)
Gym spaces:
Space | Action | Observation |
---|---|---|
Discrete | ✔️ | ✔️ |
Box | ✔️ | ✔️ |
MultiDiscrete | ❌ | ✔️ |
MultiBinary | ❌ | ✔️ |
Dict | ❌ | ❌ |
Example¶
from sb3_contrib import ARS
# Policy can be LinearPolicy or MlpPolicy
model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("ars_pendulum")
With experimental asynchronous multi-processing:
from sb3_contrib import ARS
from sb3_contrib.common.vec_env import AsyncEval
from stable_baselines3.common.env_util import make_vec_env
env_id = "CartPole-v1"
n_envs = 2
model = ARS("LinearPolicy", env_id, n_delta=2, n_top=1, verbose=1)
# Create env for asynchronous evaluation (run in different processes)
async_eval = AsyncEval([lambda: make_vec_env(env_id) for _ in range(n_envs)], model.policy)
model.learn(total_timesteps=200_000, log_interval=4, async_eval=async_eval)
Results¶
These results replicate the experiments of the original paper, which used the MuJoCo benchmarks. They use the same parameters as the original paper, with 8 seeds.
Environments | ARS |
---|---|
HalfCheetah | 4398 +/- 320 |
Swimmer | 241 +/- 51 |
Hopper | 3320 +/- 120 |
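A "mean +/- spread" entry of this kind can be produced from per-seed scores as sketched below. The page does not say whether the spread is a standard deviation or a standard error, nor how the per-seed score was measured, so the sample standard deviation used here is an assumption, `summarize` is a hypothetical helper, and the numbers are made up for illustration.

```python
import statistics


def summarize(returns_per_seed):
    """Format per-seed scores as 'mean +/- sample standard deviation'."""
    mean = statistics.mean(returns_per_seed)
    std = statistics.stdev(returns_per_seed)
    return f"{mean:.0f} +/- {std:.0f}"


# Hypothetical final returns for 8 seeds (not taken from the benchmark).
hypothetical_returns = [100.0, 110.0, 90.0, 105.0, 95.0, 100.0, 108.0, 92.0]
summary = summarize(hypothetical_returns)
```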
How to replicate the results?¶
Clone RL-Zoo and check out the branch feat/ars:
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
git checkout feat/ars
Run the benchmark. The following code snippet trains 8 seeds in parallel for each environment:
for ENV_ID in Swimmer-v3 HalfCheetah-v3 Hopper-v3
do
for SEED_NUM in {1..8}
do
SEED=$RANDOM
python train.py --algo ars --env $ENV_ID --eval-episodes 10 --eval-freq 10000 -n 20000000 --seed $SEED &
sleep 1
done
wait
done
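For readers who prefer to drive the runs from Python, the shell loop above could be mirrored roughly as follows. This is a sketch, not part of RL-Zoo: `build_commands` and `run_batch` are hypothetical helpers, the command-line flags are copied from the snippet above, and the seed range simply imitates bash's `$RANDOM` (0 to 32767).

```python
import random
import subprocess


def build_commands(env_id, n_seeds, rng):
    """Build one train.py command line per seed for a single environment."""
    commands = []
    for _ in range(n_seeds):
        seed = rng.randrange(32768)  # same range as bash $RANDOM
        commands.append([
            "python", "train.py", "--algo", "ars", "--env", env_id,
            "--eval-episodes", "10", "--eval-freq", "10000",
            "-n", "20000000", "--seed", str(seed),
        ])
    return commands


def run_batch(commands):
    """Start every command at once, then block until all have exited."""
    processes = [subprocess.Popen(command) for command in commands]
    return [process.wait() for process in processes]


rng = random.Random(0)
commands_by_env = {
    env_id: build_commands(env_id, 8, rng)
    for env_id in ["Swimmer-v3", "HalfCheetah-v3", "Hopper-v3"]
}
```

Calling `run_batch(commands_by_env[env_id])` for each environment in turn would correspond to the `&` and `wait` pattern of the shell version; nothing is launched by the block itself.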
Plot the results:
python scripts/all_plots.py -a ars -e HalfCheetah Swimmer Hopper -f logs/ -o logs/ars_results -max 20000000
python scripts/plot_from_file.py -i logs/ars_results.pkl -l ARS
Parameters¶
- class sb3_contrib.ars.ARS(policy, env, n_delta=8, n_top=None, learning_rate=0.02, delta_std=0.05, zero_policy=True, alive_bonus_offset=0, n_eval_episodes=1, policy_kwargs=None, tensorboard_log=None, seed=None, verbose=0, device='cpu', _init_setup_model=True)[source]¶
Augmented Random Search: https://arxiv.org/abs/1803.07055
Original implementation: https://github.com/modestyachts/ARS
C++/CUDA implementation: https://github.com/google-research/tiny-differentiable-simulator/
150 LOC NumPy implementation: https://github.com/alexis-jacq/numpy_ARS/blob/master/asr.py
- Parameters:
policy (Union[str, Type[ARSPolicy]]) – The policy to train, either an ARSPolicy class or a string from [“LinearPolicy”, “MlpPolicy”]
env (Union[Env, VecEnv, str]) – The environment to train on, may be a string if registered with gym
n_delta (int) – How many random perturbations of the policy to try at each update step.
n_top (Optional[int]) – How many of the top delta to use in each update step. Default is n_delta
learning_rate (Union[float, Callable[[float], float]]) – Float or schedule for the step size
delta_std (Union[float, Callable[[float], float]]) – Float or schedule for the exploration noise
zero_policy (bool) – Boolean determining if the passed policy should have its weights zeroed before training.
alive_bonus_offset (float) – Constant added to the reward at each step, used to cancel out alive bonuses.
n_eval_episodes (int) – Number of episodes to evaluate each candidate.
policy_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the policy on creation
tensorboard_log (Optional[str]) – String with the directory to put tensorboard logs
seed (Optional[int]) – Random seed for the training
verbose (int) – Verbosity level: 0 no output, 1 info, 2 debug
device (Union[device, str]) – Torch device to use for training, defaults to “cpu”
_init_setup_model (bool) – Whether or not to build the network at the creation of the instance
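Two of the parameters above are easier to grasp with a small example: a schedule for learning_rate or delta_std, and the effect of alive_bonus_offset on an episode's return. The sketch below is plain Python and makes two assumptions that this page does not state: that a schedule receives the remaining training progress (going from 1 at the start to 0 at the end, the usual Stable-Baselines3 convention), and that the offset is applied once per step of the episode. `linear_schedule` and `adjusted_return` are hypothetical helper names.

```python
def linear_schedule(initial_value):
    """Return a schedule that decays linearly from initial_value to 0."""

    def schedule(progress_remaining):
        # Assumed convention: progress_remaining goes from 1.0 down to 0.0.
        return progress_remaining * initial_value

    return schedule


def adjusted_return(rewards, alive_bonus_offset=0.0):
    """Episode return after adding a constant offset to every step reward."""
    return sum(rewards) + alive_bonus_offset * len(rewards)


learning_rate = linear_schedule(0.02)
halfway = learning_rate(0.5)

# Hypothetical 3-step episode in an environment paying +1 per step alive:
# an offset of -1 removes that bonus from the return.
shifted_return = adjusted_return([1.5, 1.2, 0.9], alive_bonus_offset=-1.0)
```

Under those assumptions, a callable such as `linear_schedule(0.02)` could be passed wherever the signature accepts `Callable[[float], float]`.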
- evaluate_candidates(candidate_weights, callback, async_eval)[source]¶
Evaluate each candidate.
- Parameters:
candidate_weights (Tensor) – The candidate weights to be evaluated.
callback (BaseCallback) – Callback that will be called at each step (or after evaluation in the multiprocess version)
async_eval (Optional[AsyncEval]) – The object for asynchronous evaluation of candidates.
- Return type:
Tensor
- Returns:
The episodic return for each candidate.
- get_env()¶
Returns the current environment (can be None if not defined).
- Return type:
Optional[VecEnv]
- Returns:
The current environment
- get_parameters()¶
Return the parameters of the agent. This includes parameters from different networks, e.g. critics (value functions) and policies (pi functions).
- Return type:
Dict[str, Dict]
- Returns:
Mapping from names of the objects to PyTorch state-dicts.
- get_vec_normalize_env()¶
Return the VecNormalize wrapper of the training env if it exists.
- Return type:
Optional[VecNormalize]
- Returns:
The VecNormalize env.
- learn(total_timesteps, callback=None, log_interval=1, tb_log_name='ARS', reset_num_timesteps=True, async_eval=None, progress_bar=False)[source]¶
Return a trained model.
- Parameters:
total_timesteps (int) – The total number of samples (env steps) to train on
callback (Union[None, Callable, List[BaseCallback], BaseCallback]) – callback(s) called at every step with state of the algorithm.
log_interval (int) – The number of timesteps before logging.
tb_log_name (str) – the name of the run for TensorBoard logging
reset_num_timesteps (bool) – whether or not to reset the current timestep number (used in logging)
async_eval (Optional[AsyncEval]) – The object for asynchronous evaluation of candidates.
progress_bar (bool) – Display a progress bar using tqdm and rich.
- Return type:
TypeVar(SelfARS, bound= ARS)
- Returns:
the trained model
- classmethod load(path, env=None, device='auto', custom_objects=None, print_system_info=False, force_reset=True, **kwargs)¶
Load the model from a zip-file. Warning: load re-creates the model from scratch, it does not update it in-place! For an in-place load use set_parameters instead.
- Parameters:
path (Union[str, Path, BufferedIOBase]) – path to the file (or a file-like) where to load the agent from
env (Union[Env, VecEnv, None]) – the new environment to run the loaded model on (can be None if you only need prediction from a trained model); has priority over any saved environment
device (Union[device, str]) – Device on which the code should run.
custom_objects (Optional[Dict[str, Any]]) – Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in keras.models.load_model. Useful when you have an object in file that can not be deserialized.
print_system_info (bool) – Whether to print system info from the saved model and the current system info (useful to debug loading issues)
force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597
kwargs – extra arguments to change the model when loading
- Return type:
TypeVar(SelfBaseAlgorithm, bound= BaseAlgorithm)
- Returns:
new model instance with loaded parameters
- property logger: Logger¶
Getter for the logger object.
- predict(observation, state=None, episode_start=None, deterministic=False)¶
Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images).
- Parameters:
observation (Union[ndarray, Dict[str, ndarray]]) – the input observation
state (Optional[Tuple[ndarray, ...]]) – The last hidden states (can be None, used in recurrent policies)
episode_start (Optional[ndarray]) – The last masks (can be None, used in recurrent policies). These correspond to the beginning of episodes, where the hidden states of the RNN must be reset.
deterministic (bool) – Whether or not to return deterministic actions.
- Return type:
Tuple[ndarray, Optional[Tuple[ndarray, ...]]]
- Returns:
the model’s action and the next hidden state (used in recurrent policies)
- save(path, exclude=None, include=None)¶
Save all the attributes of the object and the model parameters in a zip-file.
- Parameters:
path (Union[str, Path, BufferedIOBase]) – path to the file where the rl agent should be saved
exclude (Optional[Iterable[str]]) – name of parameters that should be excluded in addition to the default ones
include (Optional[Iterable[str]]) – name of parameters that might be excluded but should be included anyway
- Return type:
None
- set_env(env, force_reset=True)¶
Checks the validity of the environment and, if it is coherent, sets it as the current environment. Furthermore, wraps any non-vectorized env into a vectorized one. Checked parameters: observation_space, action_space.
- Parameters:
env (Union[Env, VecEnv]) – The environment for learning a policy
force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See issue https://github.com/DLR-RM/stable-baselines3/issues/597
- Return type:
None
- set_logger(logger)¶
Setter for the logger object.
- Return type:
None
Warning
When passing a custom logger object, this will overwrite tensorboard_log and verbose settings passed to the constructor.
- set_parameters(load_path_or_dict, exact_match=True, device='auto')[source]¶
Load parameters from a given zip-file or a nested dictionary containing parameters for different modules (see get_parameters).
- Parameters:
load_path_or_dict – Location of the saved data (path or file-like, see save), or a nested dictionary containing nn.Module parameters used by the policy. The dictionary maps object names to a state-dictionary returned by torch.nn.Module.state_dict().
exact_match (bool) – If True, the given parameters should include parameters for each module and each of their parameters, otherwise raises an Exception. If set to False, this can be used to update only specific parameters.
device (Union[device, str]) – Device on which the code should run.
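The meaning of exact_match can be illustrated with plain dictionaries. The toy function below is not the library's code and works on nested dicts of lists rather than PyTorch state-dicts; `apply_parameters` is a hypothetical name, and the choice of `ValueError` is an assumption (the description above only says that an exception is raised).

```python
def apply_parameters(current, new, exact_match=True):
    """Toy in-place update of a {module: {param: value}} mapping."""
    if exact_match:
        # Every parameter of every module must be provided.
        for module, params in current.items():
            missing = set(params) - set(new.get(module, {}))
            if missing:
                raise ValueError(
                    f"missing parameters for {module!r}: {sorted(missing)}"
                )
    for module, params in new.items():
        current[module].update(params)


params = {"policy": {"weight": [0.0, 0.0], "bias": [0.0]}}

# Partial update: only the bias is replaced, the weight is left untouched.
apply_parameters(params, {"policy": {"bias": [1.0]}}, exact_match=False)

# The same partial dictionary is rejected when an exact match is required.
try:
    apply_parameters(params, {"policy": {"bias": [2.0]}}, exact_match=True)
    strict_failed = False
except ValueError:
    strict_failed = True
```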
- Return type:
None
- set_random_seed(seed=None)¶
Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space)
- Parameters:
seed (Optional[int])
- Return type:
None
ARS Policies¶
- class sb3_contrib.ars.policies.ARSPolicy(observation_space, action_space, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, with_bias=True, squash_output=True)[source]
Policy network for ARS.
- Parameters:
observation_space (Space) – The observation space of the environment
action_space (Space) – The action space of the environment
net_arch (Optional[List[int]]) – Network architecture, defaults to a 2-layer MLP with 64 hidden nodes.
activation_fn (Type[Module]) – Activation function
with_bias (bool) – If set to False, the layers will not learn an additive bias
squash_output (bool) – For continuous actions, whether the output is squashed or not using a tanh() function. If not squashed with tanh the output will instead be clipped.
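The two ways of keeping a continuous action in range, as described for squash_output, behave differently near the limits. The sketch below contrasts them in plain Python; the clipping bounds of -2 and 2 are hypothetical (the text above does not say which bounds are used), and `squash` and `clip` are illustrative helper names.

```python
import math


def squash(raw_action):
    """Smoothly map each component into the open interval (-1, 1)."""
    return [math.tanh(a) for a in raw_action]


def clip(raw_action, low, high):
    """Leave in-range components untouched and saturate the others."""
    return [min(max(a, low), high) for a in raw_action]


raw = [-3.0, 0.5, 4.0]
squashed = squash(raw)
clipped = clip(raw, low=-2.0, high=2.0)  # -> [-2.0, 0.5, 2.0]
```

Squashing changes every component, including those already in range, whereas clipping only alters the components that exceed the bounds.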
- forward(obs)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- sb3_contrib.ars.LinearPolicy¶
alias of ARSLinearPolicy
- sb3_contrib.ars.MlpPolicy¶
alias of ARSPolicy