Torch Layers
- class sb3_contrib.common.torch_layers.BatchRenorm(num_features, eps=0.001, momentum=0.01, affine=True, warmup_steps=100000)
BatchRenorm Module (https://arxiv.org/abs/1702.03275). Adapted to PyTorch from https://github.com/araffin/sbx/blob/master/sbx/common/jax_layers.py
BatchRenorm is an improved version of vanilla BatchNorm. Unlike BatchNorm, BatchRenorm uses the running statistics to normalize the batches once a warmup phase is over. This makes it less sensitive to “outlier” batches, which can occur during very long training runs, and therefore more robust in that setting.
During the warmup phase, it behaves exactly like a BatchNorm layer. After the warmup phase, the running statistics are used for normalization. The running statistics are updated during training mode. During evaluation mode, the running statistics are used for normalization but not updated.
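The NumPy sketch below illustrates the computation described above, following the formulation of the Batch Renormalization paper. It is not the sb3_contrib implementation: the class name `BatchRenormSketch` is hypothetical, the clipping bounds `r_max=3.0` and `d_max=5.0` are assumptions (they are not documented parameters of this class), the affine transform is omitted, and because NumPy has no autograd, the paper's rule of treating `r` and `d` as constants during backpropagation is not represented.

```python
import numpy as np


class BatchRenormSketch:
    """Hypothetical NumPy sketch of Batch Renormalization (no affine transform)."""

    def __init__(self, num_features, eps=1e-3, momentum=0.01,
                 warmup_steps=100_000, r_max=3.0, d_max=5.0):
        self.eps = eps
        self.momentum = momentum
        self.warmup_steps = warmup_steps
        self.r_max = r_max  # assumed clipping bound for r
        self.d_max = d_max  # assumed clipping bound for d
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.num_batches_tracked = 0
        self.training = True

    def __call__(self, x):
        """Normalize a (batch, num_features) array."""
        if not self.training:
            # Evaluation: normalize with running statistics, do not update them
            return (x - self.running_mean) / np.sqrt(self.running_var + self.eps)

        batch_mean = x.mean(axis=0)
        batch_var = x.var(axis=0)  # biased (population) variance
        batch_std = np.sqrt(batch_var + self.eps)
        x_hat = (x - batch_mean) / batch_std  # plain BatchNorm output

        if self.num_batches_tracked >= self.warmup_steps:
            # After warmup: correct the batch statistics toward the running ones
            running_std = np.sqrt(self.running_var + self.eps)
            r = np.clip(batch_std / running_std, 1.0 / self.r_max, self.r_max)
            d = np.clip((batch_mean - self.running_mean) / running_std,
                        -self.d_max, self.d_max)
            x_hat = x_hat * r + d

        # Update the running averages (training mode only)
        self.running_mean = self.running_mean + self.momentum * (batch_mean - self.running_mean)
        self.running_var = self.running_var + self.momentum * (batch_var - self.running_var)
        self.num_batches_tracked += 1
        return x_hat
```

When `r` and `d` are not clipped, `(x - batch_mean) / batch_std * r + d` simplifies algebraically to `(x - running_mean) / running_std`, which is normalization with the running statistics. The clipping only takes effect for batches whose statistics stray far from the running ones.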
- Parameters:
num_features (int) – Number of features in the input tensor.
eps (float) – A value added to the variance for numerical stability.
momentum (float) – The value used to update the running averages of the mean and variance. It controls how quickly the batch renormalization statistics converge.
affine (bool) – If True, this module has learnable affine parameters. Default: True
warmup_steps (int) – Number of warmup steps performed before the running statistics are used for normalization. During the warmup phase, the batch statistics are used.
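To get a feel for how `momentum` and `warmup_steps` interact, the sketch below assumes the running statistics follow the update `running <- running + momentum * (batch - running)`. This is an assumption about this class, not taken from its documentation. Under that rule, the weight of the initial value decays as `(1 - momentum) ** n` after `n` updates. The helper names are hypothetical.

```python
import math


def steps_to_converge(momentum: float, tol: float = 0.01) -> int:
    """Updates needed until the initial value's weight (1 - momentum)**n drops below tol."""
    return math.ceil(math.log(tol) / math.log(1.0 - momentum))


def run_running_average(init: float, target: float, momentum: float, steps: int) -> float:
    """Apply running <- running + momentum * (target - running) `steps` times."""
    running = init
    for _ in range(steps):
        running += momentum * (target - running)
    return running


momentum = 0.01  # default value in the signature above
n = steps_to_converge(momentum)
final = run_running_average(init=0.0, target=1.0, momentum=momentum, steps=n)
print(n, final)
```

With `momentum=0.01`, this comes to 459 updates for a 1% tolerance, far fewer than the default `warmup_steps=100000`. Under this assumption, the running statistics have long adapted to the batch statistics by the time they start being used for normalization.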