statistic_tools
- class xuance.common.statistic_tools.RunningMeanStd(shape: Sequence[int] | dict, epsilon=0.0001, comm=None, use_mpi=False)
Bases: object
Maintains a running mean and standard deviation.
- shape
Shape of the input data.
- Type:
Union[Sequence[int], dict]
- epsilon
Small value to prevent division by zero.
- Type:
float
- comm
MPI communicator for distributed computation.
- Type:
MPI.Comm
- use_mpi
Whether to use MPI for computation.
- Type:
bool
- property std
Compute the standard deviation.
- Returns:
The standard deviation of the running statistics.
- Return type:
Union[dict, ndarray]
- update(x)
Update the running mean and standard deviation with new data.
- Parameters:
x (Union[dict, ndarray]) – New data to update the statistics.
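As an illustration of the kind of batch-level statistics that `update` would need before combining them with the running values, here is a minimal NumPy sketch. The helper name `batch_moments` is hypothetical and not part of xuance, and the sketch assumes that samples are stacked along axis 0. For a dict input it returns one triple per key, which is only one possible layout.

```python
import numpy as np


def batch_moments(x):
    """Hypothetical helper: per-feature mean, variance, and sample count.

    Assumes samples are stacked along axis 0. Dict inputs are handled
    key by key, returning a dict of (mean, var, count) triples.
    """
    if isinstance(x, dict):
        return {key: batch_moments(value) for key, value in x.items()}
    x = np.asarray(x, dtype=np.float64)
    return x.mean(axis=0), x.var(axis=0), x.shape[0]


# Two samples with two features each.
obs = np.array([[1.0, 10.0], [3.0, 30.0]])
mean, var, count = batch_moments(obs)

# The same data wrapped in a dict, as with dict-style observations.
per_key = batch_moments({"a": obs})
```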
- update_from_moments(batch_mean, batch_var, batch_count)
Update the running mean, variance, and count using new statistics.
This method updates the current statistics by combining them with batch-level statistics, supporting both dictionary and array inputs.
- Parameters:
batch_mean (Union[dict, ndarray]) – Mean of the new batch.
batch_var (Union[dict, ndarray]) – Variance of the new batch.
batch_count (Union[dict, int]) – Number of samples in the new batch.
- Updates:
self.mean (Union[dict, ndarray]): Updated running mean.
self.var (Union[dict, ndarray]): Updated running variance.
self.count (Union[dict, float]): Updated sample count.
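Combining running statistics with batch statistics is commonly done with the standard parallel-variance formula. The following is a minimal NumPy sketch of that formula for the array case only. The function name `combine_moments` is hypothetical, and the actual xuance implementation may differ in detail.

```python
import numpy as np


def combine_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """Merge running (mean, var, count) with batch-level statistics.

    Hypothetical helper, not part of xuance. It applies the standard
    parallel-variance combination to population variances.
    """
    delta = batch_mean - mean
    total = count + batch_count
    new_mean = mean + delta * batch_count / total
    # Add the sums of squared deviations of both groups, plus a
    # correction term for the gap between the two means.
    m2 = var * count + batch_var * batch_count + delta ** 2 * count * batch_count / total
    return new_mean, m2 / total, total


rng = np.random.default_rng(0)
first = rng.normal(size=(100, 3))
second = rng.normal(loc=2.0, size=(50, 3))

# Start from the statistics of the first batch, then merge the second.
mean, var, count = first.mean(axis=0), first.var(axis=0), first.shape[0]
mean, var, count = combine_moments(
    mean, var, count, second.mean(axis=0), second.var(axis=0), second.shape[0]
)

# Reference values computed directly on the concatenated data.
everything = np.concatenate([first, second])
```

Up to floating-point error, the merged `mean` and `var` equal `everything.mean(axis=0)` and `everything.var(axis=0)`, so the full data set never has to be kept in memory.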
- xuance.common.statistic_tools.mpi_mean(x, axis=0, comm=None, keepdims=False)
Compute the mean across all MPI processes.
- Parameters:
x (array-like) – Input array.
axis (int, optional) – Axis along which the mean is computed. Defaults to 0.
comm (MPI.Comm, optional) – MPI communicator. Defaults to MPI.COMM_WORLD.
keepdims (bool, optional) – Whether to keep the reduced dimensions in the result. Defaults to False.
- Returns:
- A tuple containing:
mean (ndarray): Mean of the input array.
count (int): Total count used for the mean computation.
- Return type:
tuple
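A mean across processes can be obtained by summing per-process sums and counts, then dividing. The sketch below imitates this without MPI by treating each array in a list as the data held by one rank. The name `simulated_mpi_mean` is hypothetical, and it is an assumption that `mpi_mean` follows this sum-and-count scheme internally.

```python
import numpy as np


def simulated_mpi_mean(local_arrays, axis=0, keepdims=False):
    """Hypothetical stand-in: each array in the list is one rank's data."""
    arrays = [np.asarray(a, dtype=np.float64) for a in local_arrays]
    # Each rank contributes its local sum and local count along `axis`.
    sums = [a.sum(axis=axis, keepdims=keepdims) for a in arrays]
    counts = [a.shape[axis] for a in arrays]
    # Adding the contributions plays the role of an all-reduce with SUM.
    total_sum = np.sum(sums, axis=0)
    total_count = int(np.sum(counts))
    return total_sum / total_count, total_count


rank0 = [[1.0, 2.0], [3.0, 4.0]]  # two samples on the first rank
rank1 = [[5.0, 6.0]]              # one sample on the second rank
mean, count = simulated_mpi_mean([rank0, rank1])
```

Because the division uses the total count rather than averaging the per-rank means, ranks holding different numbers of samples are weighted correctly.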
- xuance.common.statistic_tools.mpi_moments(x, axis=0, comm=None, keepdims=False)
Compute the mean and standard deviation across MPI processes.
- Parameters:
x (array-like) – Input array for which to calculate moments.
axis (int, optional) – Axis along which to compute the mean and standard deviation. Defaults to 0.
comm (MPI.Comm, optional) – MPI communicator for distributed computation. If None, defaults to MPI.COMM_WORLD.
keepdims (bool, optional) – Whether to retain reduced dimensions in the output. Defaults to False.
- Returns:
- A tuple containing:
mean (ndarray): Mean of the input array.
std (ndarray): Standard deviation of the input array.
count (int): Total count used for the computation.
- Return type:
tuple
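One way to obtain a global mean and standard deviation is a two-pass scheme: compute the global mean first, then the global mean of squared deviations from it. The sketch below imitates this without MPI by treating each array in a list as one rank's data. The name `simulated_mpi_moments` is hypothetical, and the two-pass scheme is an assumption about how `mpi_moments` could work, not a statement of its actual implementation.

```python
import numpy as np


def simulated_mpi_moments(local_arrays, axis=0):
    """Hypothetical stand-in: each array in the list is one rank's data."""
    arrays = [np.asarray(a, dtype=np.float64) for a in local_arrays]
    count = sum(a.shape[axis] for a in arrays)
    # Pass 1: global mean from the per-rank sums.
    mean = sum(a.sum(axis=axis) for a in arrays) / count
    # Pass 2: global mean of squared deviations from that mean.
    # np.expand_dims restores the reduced axis so broadcasting works.
    centered = np.expand_dims(mean, axis)
    sq_dev = sum(((a - centered) ** 2).sum(axis=axis) for a in arrays) / count
    return mean, np.sqrt(sq_dev), count


rank0 = [[1.0, 2.0], [3.0, 4.0]]  # two samples on the first rank
rank1 = [[5.0, 6.0]]              # one sample on the second rank
mean, std, count = simulated_mpi_moments([rank0, rank1])

# Reference values computed on the pooled data.
pooled = np.concatenate([np.asarray(rank0), np.asarray(rank1)])
```

The result is the population standard deviation of the pooled data, matching `pooled.std(axis=0)`.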