statistic_tools

class xuance.common.statistic_tools.RunningMeanStd(shape: Sequence[int] | dict, epsilon=0.0001, comm=None, use_mpi=False)[source]

Bases: object

Maintains a running mean and standard deviation.

shape

Shape of the input data.

Type:

Union[Sequence[int], dict]

epsilon

Small value to prevent division by zero.

Type:

float

comm

MPI communicator for distributed computation.

Type:

MPI.Comm

use_mpi

Whether to use MPI for computation.

Type:

bool

property std

Compute the standard deviation.

Returns:

The standard deviation of the running statistics.

Return type:

Union[dict, ndarray]

update(x)[source]

Update the running mean and standard deviation with new data.

Parameters:

x (Union[dict, ndarray]) – New data to update the statistics.

update_from_moments(batch_mean, batch_var, batch_count)[source]

Update the running mean, variance, and count using new statistics.

This method updates the current statistics by combining them with batch-level statistics, supporting both dictionary and array inputs.

Parameters:
  • batch_mean (Union[dict, ndarray]) – Mean of the new batch.

  • batch_var (Union[dict, ndarray]) – Variance of the new batch.

  • batch_count (Union[dict, int]) – Number of samples in the new batch.

Updates:

  • self.mean (Union[dict, ndarray]) – Updated running mean.

  • self.var (Union[dict, ndarray]) – Updated running variance.

  • self.count (Union[dict, float]) – Updated sample count.
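
The combination rule itself is not spelled out here. The following is a minimal sketch, assuming the standard parallel-variance merge and population (not sample) variances for array inputs. `combine_moments` is a hypothetical helper written for illustration and is not part of xuance.

```python
import numpy as np

def combine_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """Merge running (mean, var, count) with batch-level statistics.

    Hypothetical helper: assumes population variances and array inputs.
    """
    delta = batch_mean - mean
    total = count + batch_count
    # Shift the running mean toward the batch mean, weighted by batch size.
    new_mean = mean + delta * batch_count / total
    # Sum of squared deviations of both groups plus a between-group correction.
    m2 = var * count + batch_var * batch_count + delta ** 2 * count * batch_count / total
    return new_mean, m2 / total, total

rng = np.random.default_rng(0)
a = rng.normal(size=(50, 3))            # data already seen
b = rng.normal(loc=2.0, size=(30, 3))   # new batch

mean, var, count = combine_moments(
    a.mean(axis=0), a.var(axis=0), a.shape[0],
    b.mean(axis=0), b.var(axis=0), b.shape[0],
)
full = np.concatenate([a, b])
```

Under these assumptions the merged `mean` and `var` match the statistics of `full` computed in one pass, so batches can be folded in without storing earlier data.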

xuance.common.statistic_tools.mpi_mean(x, axis=0, comm=None, keepdims=False)[source]

Compute the mean across all MPI processes.

Parameters:
  • x (array-like) – Input array.

  • axis (int) – Axis along which the mean is computed.

  • comm (MPI.Comm, optional) – MPI communicator. Defaults to MPI.COMM_WORLD.

  • keepdims (bool) – Whether to keep the dimensions of the result. Defaults to False.

Returns:

A tuple containing:
  • mean (ndarray): Mean of the input array.

  • count (int): Total count used for the mean computation.

Return type:

tuple
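
To illustrate the idea of a mean taken across processes, here is a sketch that simulates the reduction with plain NumPy. Each entry of `shards` stands in for one process's local array, and the summation over shards stands in for an MPI all-reduce. `global_mean` is a hypothetical helper. The way xuance actually performs the reduction is not shown on this page, so this is an assumption about the general technique rather than a description of the implementation.

```python
import numpy as np

def global_mean(shards, axis=0, keepdims=False):
    """Mean over the concatenation of all shards, built from local sums and counts."""
    shards = [np.asarray(s, dtype=float) for s in shards]
    # Per-"process" partial results: a local sum and a local sample count.
    local_sums = [s.sum(axis=axis, keepdims=keepdims) for s in shards]
    local_counts = [s.shape[axis] for s in shards]
    # Stand-in for an all-reduce with a SUM operation.
    total_sum = np.sum(local_sums, axis=0)
    total_count = int(sum(local_counts))
    return total_sum / total_count, total_count

rng = np.random.default_rng(1)
shards = [rng.normal(size=(4, 2)), rng.normal(size=(6, 2))]  # unequal local sizes
g_mean, g_count = global_mean(shards)
```

Dividing the summed local sums by the summed local counts gives the correct result even when processes hold different numbers of samples, which a plain average of per-process means would not.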

xuance.common.statistic_tools.mpi_moments(x, axis=0, comm=None, keepdims=False)[source]

Compute the mean and standard deviation across MPI processes.

Parameters:
  • x (array-like) – Input array for which to calculate moments.

  • axis (int, optional) – Axis along which to compute the mean and standard deviation. Defaults to 0.

  • comm (MPI.Comm, optional) – MPI communicator for distributed computation. If None, defaults to MPI.COMM_WORLD.

  • keepdims (bool, optional) – Whether to retain reduced dimensions in the output. Defaults to False.

Returns:

A tuple containing:
  • mean (ndarray): Mean of the input array.

  • std (ndarray): Standard deviation of the input array.

  • count (int): Total count used for the computation.

Return type:

tuple
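
A mean, standard deviation, and count spanning several processes can be obtained in two passes: first the global mean, then the global mean of squared deviations from it. The sketch below simulates this with NumPy, again using a list of arrays in place of per-process data. `global_moments` is a hypothetical helper, and the two-pass scheme with a population standard deviation is an assumption, not a statement about xuance's internals.

```python
import numpy as np

def global_moments(shards, axis=0):
    """Return (mean, std, count) over the concatenation of all shards."""
    shards = [np.asarray(s, dtype=float) for s in shards]
    count = sum(s.shape[axis] for s in shards)
    # Pass 1: global mean from summed local sums.
    mean = sum(s.sum(axis=axis) for s in shards) / count
    # Pass 2: global mean of squared deviations from that mean.
    centered = np.expand_dims(mean, axis)
    sq_dev = sum(((s - centered) ** 2).sum(axis=axis) for s in shards)
    std = np.sqrt(sq_dev / count)
    return mean, std, count

rng = np.random.default_rng(2)
parts = [rng.normal(size=(5, 3)), rng.normal(loc=1.0, size=(7, 3))]
m, s, n = global_moments(parts)
```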