Utils
- sb3_contrib.common.utils.conjugate_gradient_solver(matrix_vector_dot_fn, b, max_iter=10, residual_tol=1e-10)[source]
Finds an approximate solution to a set of linear equations Ax = b
- Sources:
- Reference:
- Parameters:
matrix_vector_dot_fn (Callable[[Tensor], Tensor]) – a function that right multiplies a matrix A by a vector v
b – the right-hand side of the set of linear equations Ax = b
max_iter – the maximum number of iterations (default is 10)
residual_tol – residual tolerance for stopping the solver early (default is 1e-10)
- Return x:
the approximate solution to the system of equations defined by matrix_vector_dot_fn and b
- Return type:
Tensor
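The solver only needs matrix-vector products, so A never has to be built explicitly. The following is a minimal sketch of a textbook conjugate gradient loop with the same argument names. It is not the sb3_contrib implementation: `cg_sketch` is a hypothetical name, and the zero initial guess and the comparison of the squared residual norm against `residual_tol` are assumptions made for illustration.

```python
import torch


def cg_sketch(matrix_vector_dot_fn, b, max_iter=10, residual_tol=1e-10):
    """Illustrative conjugate gradient loop (hypothetical helper)."""
    x = torch.zeros_like(b)        # assumed initial guess: the zero vector
    residual = b.clone()           # r = b - A x, with x = 0
    direction = residual.clone()   # first search direction is the residual
    rs_old = torch.dot(residual, residual)
    for _ in range(max_iter):
        if rs_old < residual_tol:  # early stop on the squared residual norm
            break
        a_dot_p = matrix_vector_dot_fn(direction)
        alpha = rs_old / torch.dot(direction, a_dot_p)  # step size
        x = x + alpha * direction
        residual = residual - alpha * a_dot_p
        rs_new = torch.dot(residual, residual)
        # New direction is conjugate to the previous ones with respect to A
        direction = residual + (rs_new / rs_old) * direction
        rs_old = rs_new
    return x


# Symmetric positive-definite system, accessed only through A @ v
A = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
b = torch.tensor([1.0, 2.0], dtype=torch.float64)
x = cg_sketch(lambda v: A @ v, b)
```

Conjugate gradient assumes A is symmetric positive-definite. In exact arithmetic it converges in at most as many iterations as b has entries, so `max_iter` mainly caps the cost on large systems.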
- sb3_contrib.common.utils.flat_grad(output, parameters, create_graph=False, retain_graph=False)[source]
Computes the gradients of output with respect to the passed sequence of parameters and returns them as a single flat tensor. The order of parameters is preserved.
- Parameters:
output – functional output to compute the gradient for
parameters (Sequence[Parameter]) – sequence of Parameter
retain_graph (bool) – If False, the graph used to compute the grad will be freed. Defaults to the value of create_graph.
create_graph (bool) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: False.
- Returns:
Tensor containing the flattened gradients
- Return type:
Tensor
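To make "flat gradient" concrete, here is a minimal sketch built on `torch.autograd.grad`. The helper name `flat_grad_sketch`, the `nn.Linear` model, and the squared-output loss are hypothetical and chosen only for illustration; the sketch is not the library's code.

```python
import torch
from torch import nn


def flat_grad_sketch(output, parameters, create_graph=False):
    """Hypothetical stand-in: differentiate, then concatenate in parameter order."""
    grads = torch.autograd.grad(output, parameters, create_graph=create_graph)
    # Flatten each gradient and join them, keeping the order of `parameters`
    return torch.cat([g.reshape(-1) for g in grads])


model = nn.Linear(3, 2)               # weight has 2 * 3 = 6 entries, bias has 2
inputs = torch.ones(4, 3)
loss = model(inputs).pow(2).sum()     # scalar output to differentiate
params = list(model.parameters())     # order: weight, then bias
flat = flat_grad_sketch(loss, params)
# flat is one-dimensional with 6 + 2 = 8 entries
```

Passing `create_graph=True` keeps the result differentiable, which is what makes it possible to differentiate the flat gradient again, for example to form Hessian-vector products.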
- sb3_contrib.common.utils.quantile_huber_loss(current_quantiles, target_quantiles, cum_prob=None, sum_over_quantiles=True)[source]
The quantile-regression loss, as described in the QR-DQN and TQC papers. Partially taken from https://github.com/bayesgroup/tqc_pytorch.
- Parameters:
current_quantiles (Tensor) – current estimate of quantiles, must be either (batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles)
target_quantiles (Tensor) – target of quantiles, must be either (batch_size, n_target_quantiles), (batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles)
cum_prob (Tensor | None) – cumulative probabilities to calculate quantiles (also called midpoints in the QR-DQN paper), must be either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles). If None, unit quantiles are calculated.
sum_over_quantiles (bool) – whether to sum over the quantile dimension
- Returns:
the loss
- Return type:
Tensor
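As a rough illustration of the loss for the simplest shapes, (batch_size, n_quantiles) against (batch_size, n_target_quantiles), the sketch below computes a quantile Huber loss by hand. It rests on several assumptions that the text above does not confirm: a Huber threshold of 1, default midpoints of (i + 0.5) / n_quantiles, and a mean over the batch and target dimensions. The name `quantile_huber_sketch` is hypothetical.

```python
import torch


def quantile_huber_sketch(current, target, sum_over_quantiles=True):
    """Hypothetical 2D-only sketch of a quantile Huber loss."""
    n_quantiles = current.shape[-1]
    # Assumed default midpoints: (i + 0.5) / n_quantiles
    idx = torch.arange(n_quantiles, dtype=current.dtype, device=current.device)
    cum_prob = ((idx + 0.5) / n_quantiles).view(1, -1, 1)
    # Pairwise differences, shape (batch, n_quantiles, n_target_quantiles)
    delta = target.unsqueeze(-2) - current.unsqueeze(-1)
    abs_delta = delta.abs()
    # Huber loss with an assumed threshold of 1
    huber = torch.where(abs_delta > 1, abs_delta - 0.5, 0.5 * delta**2)
    # Asymmetric weight |tau - 1{delta < 0}|; the indicator carries no gradient
    weight = (cum_prob - (delta.detach() < 0).to(current.dtype)).abs()
    loss = weight * huber
    if sum_over_quantiles:
        return loss.sum(dim=-2).mean()  # sum over current quantiles, then average
    return loss.mean()


current = torch.zeros(2, 3, requires_grad=True)  # (batch_size, n_quantiles)
target = torch.randn(2, 5)                       # (batch_size, n_target_quantiles)
loss = quantile_huber_sketch(current, target)
loss.backward()                                  # gradients reach `current`
```

The asymmetric weight is what turns a plain Huber loss into a quantile regression loss: overestimates and underestimates are penalized differently depending on the quantile fraction.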