Utils

sb3_contrib.common.utils.conjugate_gradient_solver(matrix_vector_dot_fn, b, max_iter=10, residual_tol=1e-10)

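Finds an approximate solution to a set of linear equations Ax = b using the conjugate gradient method. Conjugate gradient is guaranteed to converge when A is symmetric positive-definite. A is only accessed through matrix-vector products, so it never has to be formed explicitly.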

Sources:
Reference:
Parameters
  • matrix_vector_dot_fn (Callable[[Tensor], Tensor]) – a function that takes a vector v and returns the matrix-vector product Av

  • b – the right-hand side of the linear system Ax = b

  • max_iter – the maximum number of iterations (default is 10)

  • residual_tol – residual tolerance below which the solver stops early (default is 1e-10)

Returns

the approximate solution x to the system of equations defined by matrix_vector_dot_fn and b

Return type

Tensor
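Because the solver only needs products Av, it can work with operators that are never materialized, such as the Hessian-vector or Fisher-vector products used in TRPO-style updates. The following is a minimal textbook conjugate gradient sketch in PyTorch that follows the documented interface. The name conjugate_gradient_sketch is hypothetical, it starts from x0 = 0, and the library's exact initialization and stopping checks may differ.

```python
import torch


def conjugate_gradient_sketch(matrix_vector_dot_fn, b, max_iter=10, residual_tol=1e-10):
    """Hypothetical textbook CG sketch (not the library implementation).

    Assumes A is symmetric positive-definite and is available only through
    matrix_vector_dot_fn(v) == A @ v.
    """
    x = torch.zeros_like(b)                    # assumption: start from x0 = 0
    residual = b - matrix_vector_dot_fn(x)     # r0 = b - A x0
    direction = residual.clone()               # first search direction p0 = r0
    residual_norm_sq = residual.dot(residual)
    for _ in range(max_iter):
        if residual_norm_sq < residual_tol:
            break                              # squared residual norm small enough: stop early
        a_dot_p = matrix_vector_dot_fn(direction)
        alpha = residual_norm_sq / direction.dot(a_dot_p)   # step size along p
        x = x + alpha * direction
        residual = residual - alpha * a_dot_p
        new_residual_norm_sq = residual.dot(residual)
        beta = new_residual_norm_sq / residual_norm_sq      # keeps directions A-conjugate
        direction = residual + beta * direction
        residual_norm_sq = new_residual_norm_sq
    return x


# Usage: a 3x3 SPD system exposed only through a matrix-vector product.
A = torch.tensor([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
b = torch.tensor([1.0, 2.0, 3.0])
x = conjugate_gradient_sketch(lambda v: A @ v, b, max_iter=10)
print(x)  # expected: close to [2/9, 1/9, 13/9]
```

In exact arithmetic CG solves an n-by-n SPD system in at most n iterations, so max_iter=10 is more than enough here; for large operators, max_iter caps the cost at a fixed number of matrix-vector products.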

sb3_contrib.common.utils.flat_grad(output, parameters, create_graph=False, retain_graph=False)

Returns the gradients of output with respect to the passed sequence of parameters, concatenated into a single flat tensor. The order of the parameters is preserved.

Parameters
  • output – the output to differentiate (for example, a scalar loss)

  • parameters (Sequence[Parameter]) – the parameters to compute the gradient with respect to

  • create_graph (bool) – if True, the graph of the derivative is constructed, allowing higher-order derivative products to be computed. Default: False.

  • retain_graph (bool) – if False, the graph used to compute the gradient is freed. Default: False.

Return type

Tensor

Returns

Tensor containing the flattened gradients
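A minimal sketch of the computation, assuming flat_grad wraps torch.autograd.grad and concatenates the raveled per-parameter gradients. The name flat_grad_sketch is hypothetical, and the library's handling of unused parameters may differ. The usage below checks the result against the .grad fields that backward() fills in.

```python
import torch
from torch import nn


def flat_grad_sketch(output, parameters, create_graph=False, retain_graph=False):
    """Hypothetical sketch: gradients of `output` w.r.t. `parameters`, flattened."""
    grads = torch.autograd.grad(
        output, parameters, create_graph=create_graph, retain_graph=retain_graph
    )
    # Ravel each gradient and concatenate them in parameter order into one 1-D tensor.
    return torch.cat([grad.reshape(-1) for grad in grads])


torch.manual_seed(0)
model = nn.Linear(3, 2)               # weight: 2x3 (6 values), bias: 2 values
params = list(model.parameters())     # order: weight, then bias
x_in = torch.randn(4, 3)

loss = model(x_in).pow(2).mean()
g = flat_grad_sketch(loss, params)
print(g.shape)                        # torch.Size([8])

# Cross-check: recompute the same loss and let backward() populate p.grad.
model(x_in).pow(2).mean().backward()
reference = torch.cat([p.grad.reshape(-1) for p in params])
print(torch.allclose(g, reference))   # True
```

A flat vector like this is convenient for algorithms that treat all policy parameters as a single vector, for example when feeding a gradient into conjugate_gradient_solver.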

sb3_contrib.common.utils.quantile_huber_loss(current_quantiles, target_quantiles, cum_prob=None, sum_over_quantiles=True)

The quantile-regression loss, as described in the QR-DQN and TQC papers. Partially taken from https://github.com/bayesgroup/tqc_pytorch.

Parameters
  • current_quantiles (Tensor) – current quantile estimates, of shape either (batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles)

  • target_quantiles (Tensor) – target quantiles, of shape either (batch_size, n_target_quantiles), (batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles)

  • cum_prob (Optional[Tensor]) – cumulative probabilities at which the quantiles are estimated (called midpoints in the QR-DQN paper), of shape either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles). If None, the uniform midpoints (i + 0.5) / n_quantiles, for i = 0, …, n_quantiles - 1, are used.

  • sum_over_quantiles (bool) – whether to sum the loss over the quantile dimension (otherwise it is averaged)

Return type

Tensor

Returns

the loss, as a scalar tensor
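For a pairwise error u = target_j - current_i, the quantile Huber loss weights a Huber penalty (threshold 1) by |tau_i - 1{u < 0}|, so overestimates and underestimates are penalized asymmetrically according to each quantile's cumulative probability. The sketch below implements that formula for the 2-D QR-DQN shapes only, with the default midpoint cum_prob. The name quantile_huber_loss_sketch is hypothetical, and the library version additionally handles the 3-D TQC shapes and a user-supplied cum_prob.

```python
import torch


def quantile_huber_loss_sketch(current_quantiles, target_quantiles, sum_over_quantiles=True):
    """Hypothetical sketch of the quantile Huber loss (threshold 1), 2-D case only.

    current_quantiles: (batch_size, n_quantiles)
    target_quantiles:  (batch_size, n_target_quantiles)
    """
    n_quantiles = current_quantiles.shape[-1]
    # Midpoints tau_i = (i + 0.5) / n_quantiles, shaped (1, n_quantiles, 1) for broadcasting.
    cum_prob = ((torch.arange(n_quantiles, dtype=torch.float32) + 0.5) / n_quantiles).view(1, -1, 1)
    # Pairwise errors u = target_j - current_i: (batch_size, n_quantiles, n_target_quantiles).
    pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
    abs_delta = pairwise_delta.abs()
    # Huber penalty: quadratic for |u| <= 1, linear beyond.
    huber = torch.where(abs_delta > 1, abs_delta - 0.5, 0.5 * pairwise_delta ** 2)
    # Asymmetric quantile weight |tau_i - 1{u < 0}|; the indicator carries no gradient.
    weight = (cum_prob - (pairwise_delta.detach() < 0).float()).abs()
    loss = weight * huber
    if sum_over_quantiles:
        return loss.sum(dim=-2).mean()  # sum over current quantiles, average the rest
    return loss.mean()


current = torch.tensor([[0.0, 1.0]])  # batch_size=1, n_quantiles=2 -> tau = [0.25, 0.75]
target = torch.tensor([[0.5]])        # n_target_quantiles=1
loss = quantile_huber_loss_sketch(current, target)
print(loss.item())  # 0.0625
```

Here both pairwise errors are ±0.5, so each Huber term is 0.125 and each weight is 0.25 (|0.25 - 0| and |0.75 - 1|); summing the two products gives 0.0625, while averaging them (sum_over_quantiles=False) gives 0.03125.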