java 求标准差代码实现

package com.qingyuan.huake;

/**
* reference: <[http://my.oschina.net/BreathL/blog/41063]>
* 对样本进行统计, 计算样本的方差&标准差
* Sn = Sn-1 + (xn - Mn-1)*(xn - Mn)
* Mn = Mn-1+ (xn - Mn-1)/n
*/
public final class StandardDeviation
{
/** 样本的个数*/
private int count;

/** 平均值*/
private double averageVar;

/** sn样本方差*/
private double standardDeviationSum;

/** 样本标准差*/
private double standard_Deviation;

public StandardDeviation()
{
this(0, 0.0, 0.0);
}

public StandardDeviation(int count, double standardDeviationSum, double averageVar)
{
this.count = count;
this.standardDeviationSum = standardDeviationSum;
this.averageVar = averageVar;
recomputerstandard_Deviation();
}

public synchronized int getCount()
{
return count;
}

private void recomputerstandard_Deviation()
{
int count = getCount();
standard_Deviation = count > 1 ? Math.sqrt(standardDeviationSum / (count -1)) : Double.NaN;
}

/**
*  获取运行时样本的方差
*  @return double
*/
public synchronized double getRunningVariance()
{
return standard_Deviation;
}

/**
* 增加一个样本时重新计算
* @param sample void
*/
{
if (++count == 1)
{
averageVar = sample;
standardDeviationSum = 0.0;
}
else
{
double oldaverageVar = averageVar;
double diff = sample - oldaverageVar;
averageVar += diff / count;
standardDeviationSum += diff * (sample - averageVar);
}

recomputerstandard_Deviation();
}

/**
* 移除一个样本时重新计算
* @param sample void
*/
public synchronized void removeSample(double sample)
{
int oldCount = getCount();
double oldaverageVar = averageVar;

if (oldCount == 0)
{
throw new IllegalStateException();
}

if (--count == 0)
{
averageVar = Double.NaN;
standardDeviationSum = Double.NaN;
}
else
{
averageVar = (oldCount * oldaverageVar - sample) / (oldCount - 1);
standardDeviationSum = (sample - averageVar) * (sample - oldaverageVar);
}

recomputerstandard_Deviation();
}

}

package com.qingyuan.huake;

import static org.junit.Assert.assertEquals;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import org.junit.Test;

public class TestSD
{
@Test
public void test()
{
StandardDeviation s = new StandardDeviation();

int[] arr = {2, 4, 4, 5, 5, 6, 2, 3, 3, 6};
int sum = 0, average =0 ;
double sn  = 0.0;
double standardDeviation = 0.0;

for (int i = 0, len = arr.length; i < len; i++)
{
sum = sum + arr[i];
}

average = sum / 10;
System.out.println("average value is = " + average);

for (int j = 0, len = arr.length; j < len; j++)
{
sn = sn + Math.pow((arr[j] - average), 2);
}

// 测试数据计算
standardDeviation = Math.sqrt(sn / arr.length);
System.out.println(standardDeviation);

// 通过函数计算
double result = s.getRunningVariance();
System.out.println(result);

// 截取小数点后两位判断是否相等
NumberFormat formatter = new DecimalFormat("#0.");
String res = formatter.format(result);
String std = formatter.format(standardDeviation);

assertEquals(std, res);
}

}