from functools import partial
from typing import Any
import torch as th
from gymnasium import spaces
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
FlattenExtractor,
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from torch import nn
from sb3_contrib.common.torch_layers import BatchRenorm1d
# Bounds used to clamp the actor's predicted log standard deviation when gSDE
# is not used (see ``Actor.get_action_dist_params``): the resulting std lies in
# [exp(LOG_STD_MIN), exp(LOG_STD_MAX)].
LOG_STD_MIN, LOG_STD_MAX = -20, 2
class Actor(BasePolicy):
    """
    Actor network (policy) for CrossQ.

    It contains BatchRenorm layers to stabilize and accelerate training.

    :param observation_space: Observation space
    :param action_space: Action space
    :param net_arch: Network architecture
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Number of features
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
        The mean is clipped to ``[-clip_mean, clip_mean]``; a value ``<= 0`` disables clipping.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param batch_norm: Whether to use Batch Renorm layers (default=True)
    :param batch_norm_momentum: The rate of convergence for the batch renormalization statistics
    :param batch_norm_eps: A small value added to the variance to prevent division by zero
    :param renorm_warmup_steps: Number of steps to warm up BatchRenorm statistics before the running statistics
        are used for normalization.
    """

    action_space: spaces.Box

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        net_arch: list[int],
        features_extractor: nn.Module,
        features_dim: int,
        activation_fn: type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        full_std: bool = True,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        normalize_images: bool = True,
        batch_norm: bool = True,
        batch_norm_momentum: float = 0.01,
        batch_norm_eps: float = 0.001,
        renorm_warmup_steps: int = 100_000,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
            squash_output=True,
        )

        # Save arguments to re-create object at loading
        self.use_sde = use_sde
        self.sde_features_extractor = None
        self.net_arch = net_arch
        self.features_dim = features_dim
        self.activation_fn = activation_fn
        self.log_std_init = log_std_init
        self.use_expln = use_expln
        self.full_std = full_std
        self.clip_mean = clip_mean
        # The batch renorm settings are saved as well, so that
        # ``_get_constructor_parameters`` re-creates an actor with the same
        # normalization layers instead of falling back to the defaults.
        self.batch_norm_params = {
            "batch_norm": batch_norm,
            "batch_norm_momentum": batch_norm_momentum,
            "batch_norm_eps": batch_norm_eps,
            "renorm_warmup_steps": renorm_warmup_steps,
        }

        action_dim = get_action_dim(self.action_space)

        # Layer factories that ``create_mlp`` inserts before each linear layer
        # (empty when batch renormalization is disabled).
        pre_linear_modules = []
        if batch_norm:
            pre_linear_modules = [
                partial(
                    BatchRenorm1d,
                    momentum=batch_norm_momentum,
                    eps=batch_norm_eps,
                    warmup_steps=renorm_warmup_steps,
                )
            ]

        # output_dim=-1: no output layer, only the hidden layers are built here
        latent_pi_net = create_mlp(
            features_dim,
            -1,
            net_arch,
            activation_fn,
            pre_linear_modules=pre_linear_modules,  # type: ignore[arg-type]
        )
        # Also normalize the last hidden layer output, which feeds the mu/log_std heads
        if batch_norm and net_arch:
            latent_pi_net.append(pre_linear_modules[0](net_arch[-1]))
        self.latent_pi = nn.Sequential(*latent_pi_net)
        last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim

        if self.use_sde:
            self.action_dist = StateDependentNoiseDistribution(
                action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
            )
            self.mu, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
            )
            # Avoid numerical issues by limiting the mean of the Gaussian
            # to be in [-clip_mean, clip_mean]
            if clip_mean > 0.0:
                self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
        else:
            self.action_dist = SquashedDiagGaussianDistribution(action_dim)  # type: ignore[assignment]
            self.mu = nn.Linear(last_layer_dim, action_dim)
            self.log_std = nn.Linear(last_layer_dim, action_dim)  # type: ignore[assignment]

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """
        Collect the keyword arguments needed to re-create this actor.

        :return: Constructor keyword arguments, including the batch renorm settings.
        """
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                net_arch=self.net_arch,
                features_dim=self.features_dim,
                activation_fn=self.activation_fn,
                use_sde=self.use_sde,
                log_std_init=self.log_std_init,
                full_std=self.full_std,
                use_expln=self.use_expln,
                features_extractor=self.features_extractor,
                clip_mean=self.clip_mean,
                **self.batch_norm_params,
            )
        )
        return data

    def get_std(self) -> th.Tensor:
        """
        Retrieve the standard deviation of the action distribution.
        Only useful when using gSDE.
        It corresponds to ``th.exp(log_std)`` in the normal case,
        but is slightly different when using ``expln`` function
        (cf StateDependentNoiseDistribution doc).

        :return: Standard deviation of the action dist when available.
        """
        msg = "get_std() is only available when using gSDE"
        assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
        return self.action_dist.get_std(self.log_std)

    def reset_noise(self, batch_size: int = 1) -> None:
        """
        Sample new weights for the exploration matrix, when using gSDE.

        :param batch_size: Number of noise matrices to sample.
        """
        msg = "reset_noise() is only available when using gSDE"
        assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
        self.action_dist.sample_weights(self.log_std, batch_size=batch_size)

    def get_action_dist_params(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor, dict[str, th.Tensor]]:
        """
        Get the parameters for the action distribution.

        :param obs: Observation
        :return:
            Mean, standard deviation and optional keyword arguments.
            With gSDE, the log std is the learned parameter and the latent
            features are passed as ``latent_sde``; otherwise the log std is
            predicted from the latent features and clamped to
            ``[LOG_STD_MIN, LOG_STD_MAX]``.
        """
        features = self.extract_features(obs, self.features_extractor)
        latent_pi = self.latent_pi(features)
        mean_actions = self.mu(latent_pi)

        if self.use_sde:
            return mean_actions, self.log_std, dict(latent_sde=latent_pi)
        # Unstructured exploration (Original implementation)
        log_std = self.log_std(latent_pi)  # type: ignore[operator]
        # Original Implementation to cap the standard deviation
        log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return mean_actions, log_std, {}

    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        """
        Compute (squashed) actions for the given observations.

        :param obs: Observation
        :param deterministic: Whether to return the mode instead of a sample
        :return: Actions
        """
        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
        # Note: the action is squashed
        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)

    def action_log_prob(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor]:
        """
        Sample actions and compute their log probability.

        :param obs: Observation
        :return: Sampled actions and the associated log probability
        """
        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
        # return action and associated log prob
        return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)

    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        return self(observation, deterministic)

    def set_bn_training_mode(self, mode: bool) -> None:
        """
        Set the training mode of the BatchRenorm layers.
        When training is True, the running statistics are updated.

        :param mode: Whether to set the layers in training mode or not
        """
        for module in self.modules():
            if isinstance(module, BatchRenorm1d):
                module.train(mode)
class CrossQCritic(BaseModel):
    """
    Critic network(s) for CrossQ.

    Compared with the plain SAC/TD3 critics, every Q-network built here may
    place a BatchRenorm layer in front of each of its linear layers.
    Several Q-networks (two by default) are built so that clipped
    double Q-learning (cf TD3 paper) can be used to reduce overestimation.

    :param observation_space: Observation space
    :param action_space: Action space
    :param net_arch: Network architecture (hidden layer sizes of each Q-network)
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Number of features
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether the features extractor is shared or not
        between the actor and the critic (this saves computation time)
    :param batch_norm: Whether to use Batch Renorm layers (default=True)
    :param batch_norm_momentum: The rate of convergence for the batch renormalization statistics
    :param batch_norm_eps: A small value added to the variance to prevent division by zero
    :param renorm_warmup_steps: Number of steps to warm up BatchRenorm statistics before the running statistics
        are used for normalization.
    """

    features_extractor: BaseFeaturesExtractor

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        net_arch: list[int],
        features_extractor: BaseFeaturesExtractor,
        features_dim: int,
        activation_fn: type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
        n_critics: int = 2,
        share_features_extractor: bool = True,
        batch_norm: bool = True,
        batch_norm_momentum: float = 0.01,
        batch_norm_eps: float = 0.001,
        renorm_warmup_steps: int = 100_000,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
        )

        # Each Q-network consumes the extracted features concatenated with the action.
        q_input_dim = features_dim + get_action_dim(self.action_space)

        # Factories that ``create_mlp`` calls with the input size of every linear
        # layer; empty when batch renormalization is turned off.
        bn_factories = (
            [
                partial(
                    BatchRenorm1d,
                    momentum=batch_norm_momentum,
                    eps=batch_norm_eps,
                    warmup_steps=renorm_warmup_steps,
                )
            ]
            if batch_norm
            else []
        )

        self.share_features_extractor = share_features_extractor
        self.n_critics = n_critics
        self.q_networks: list[nn.Module] = []
        for critic_idx in range(n_critics):
            layers = create_mlp(
                q_input_dim,
                1,
                net_arch,
                activation_fn,
                pre_linear_modules=bn_factories,  # type: ignore[arg-type]
            )
            q_net = nn.Sequential(*layers)
            # A plain list does not register submodules, so register each one explicitly.
            self.add_module(f"qf{critic_idx}", q_net)
            self.q_networks.append(q_net)

    def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tensor, ...]:
        """
        Evaluate every Q-network on the given observation/action pairs.

        :param obs: Observation
        :param actions: Actions, concatenated with the features along dim 1
        :return: One Q-value tensor per critic network
        """
        # A shared features extractor is trained through the actor loss only,
        # so gradients are disabled for it here.
        extractor_needs_grad = not self.share_features_extractor
        with th.set_grad_enabled(extractor_needs_grad):
            features = self.extract_features(obs, self.features_extractor)
        q_input = th.cat([features, actions], dim=1)
        return tuple([critic(q_input) for critic in self.q_networks])

    def set_bn_training_mode(self, mode: bool) -> None:
        """
        Switch only the BatchRenorm layers between training and evaluation mode.
        In training mode their running statistics are updated.

        :param mode: True for training mode, False for evaluation mode
        """
        renorm_layers = (m for m in self.modules() if isinstance(m, BatchRenorm1d))
        for layer in renorm_layers:
            layer.train(mode)
class CrossQPolicy(BasePolicy):
    """
    Policy class (with both actor and critic) for CrossQ.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
        Defaults to ``{"pi": [256, 256], "qf": [1024, 1024]}``.
    :param activation_fn: Activation function
    :param batch_norm: Whether to use Batch Renorm layers (default=True)
    :param batch_norm_momentum: The rate of convergence for the batch renormalization statistics
    :param batch_norm_eps: A small value added to the variance to prevent division by zero
    :param renorm_warmup_steps: Number of steps to warm up BatchRenorm statistics before the running statistics
        are used for normalization.
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer.
        When omitted and the optimizer is Adam or AdamW, ``betas=(0.5, 0.999)`` is used.
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    actor: Actor
    critic: CrossQCritic

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        batch_norm: bool = True,
        batch_norm_momentum: float = 0.01,  # Note: Jax implementation is 1 - momentum = 0.99
        batch_norm_eps: float = 0.001,
        renorm_warmup_steps: int = 100_000,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        if optimizer_kwargs is None:
            # The Adam default for b1 is 0.9; the original CrossQ implementation
            # uses b1=0.5 (reported to give only a small overall improvement),
            # which is what is applied here unless the user passes their own kwargs.
            optimizer_kwargs = {}
            if optimizer_class in [th.optim.Adam, th.optim.AdamW]:
                optimizer_kwargs["betas"] = (0.5, 0.999)

        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=True,
            normalize_images=normalize_images,
        )

        if net_arch is None:
            # While CrossQ already works with a [256,256] critic network,
            # the authors found that a wider network significantly improves performance.
            # We use a slightly smaller net for faster computation, [1024, 1024] instead of [2048, 2048] in the paper.
            net_arch = {"pi": [256, 256], "qf": [1024, 1024]}

        actor_arch, critic_arch = get_actor_critic_arch(net_arch)

        self.batch_norm_params = {
            "batch_norm": batch_norm,
            "batch_norm_momentum": batch_norm_momentum,
            "batch_norm_eps": batch_norm_eps,
            "renorm_warmup_steps": renorm_warmup_steps,
        }

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        # Arguments common to the actor and the critic
        self.net_args = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "net_arch": actor_arch,
            "activation_fn": self.activation_fn,
            "normalize_images": normalize_images,
            **self.batch_norm_params,
        }
        self.actor_kwargs = self.net_args.copy()

        sde_kwargs = {
            "use_sde": use_sde,
            "log_std_init": log_std_init,
            "use_expln": use_expln,
            "clip_mean": clip_mean,
        }
        self.actor_kwargs.update(sde_kwargs)
        self.critic_kwargs = self.net_args.copy()
        self.critic_kwargs.update(
            {
                "n_critics": n_critics,
                "net_arch": critic_arch,
                "share_features_extractor": share_features_extractor,
            }
        )

        self.share_features_extractor = share_features_extractor

        self._build(lr_schedule)

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the actor and the critic together with their optimizers.

        :param lr_schedule: Learning rate schedule; its value at 1 is the initial learning rate.
        """
        self.actor = self.make_actor()
        self.actor.optimizer = self.optimizer_class(
            self.actor.parameters(),
            lr=lr_schedule(1),  # type: ignore[call-arg]
            **self.optimizer_kwargs,
        )

        if self.share_features_extractor:
            self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
            # Do not optimize the shared features extractor with the critic loss
            # otherwise, there are gradient computation issues
            critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
        else:
            # Create a separate features extractor for the critic
            # this requires more memory and computation
            self.critic = self.make_critic(features_extractor=None)
            critic_parameters = list(self.critic.parameters())

        self.critic.optimizer = self.optimizer_class(
            critic_parameters,
            lr=lr_schedule(1),  # type: ignore[call-arg]
            **self.optimizer_kwargs,
        )

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """
        Collect the keyword arguments needed to re-create this policy.

        :return: Constructor keyword arguments (with a dummy learning rate schedule).
        """
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.net_args["activation_fn"],
                use_sde=self.actor_kwargs["use_sde"],
                log_std_init=self.actor_kwargs["log_std_init"],
                use_expln=self.actor_kwargs["use_expln"],
                clip_mean=self.actor_kwargs["clip_mean"],
                n_critics=self.critic_kwargs["n_critics"],
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
                # Needed so a reloaded policy keeps the same actor/critic extractor layout
                share_features_extractor=self.share_features_extractor,
                **self.batch_norm_params,
            )
        )
        return data

    def reset_noise(self, batch_size: int = 1) -> None:
        """
        Sample new weights for the exploration matrix, when using gSDE.

        :param batch_size: Number of noise matrices to sample.
        """
        self.actor.reset_noise(batch_size=batch_size)

    def make_actor(self, features_extractor: BaseFeaturesExtractor | None = None) -> Actor:
        """
        Build an actor on the policy device.

        :param features_extractor: Features extractor to use; passing ``None`` has
            ``_update_features_extractor`` supply one.
        :return: The actor network
        """
        actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
        return Actor(**actor_kwargs).to(self.device)

    def make_critic(self, features_extractor: BaseFeaturesExtractor | None = None) -> CrossQCritic:
        """
        Build the critic on the policy device.

        :param features_extractor: Features extractor to use; passing ``None`` has
            ``_update_features_extractor`` supply one.
        :return: The critic network(s)
        """
        critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
        return CrossQCritic(**critic_kwargs).to(self.device)

    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        return self.actor(observation, deterministic)

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.actor.set_training_mode(mode)
        self.critic.set_training_mode(mode)
        self.training = mode
# Alias: ``MlpPolicy`` refers to the same class as ``CrossQPolicy``.
MlpPolicy = CrossQPolicy