Java 求标准差的代码实现

package com.qingyuan.huake;

/**
 * Running (online) statistics over a stream of samples: tracks the sample count, the mean and
 * the sum of squared deviations from the mean, and derives the sample standard deviation.
 *
 * <p>Reference: <a href="http://my.oschina.net/BreathL/blog/41063">http://my.oschina.net/BreathL/blog/41063</a>
 *
 * <p>Adding a sample uses Welford's update, where {@code M(n)} is the mean and {@code S(n)} the
 * sum of squared deviations after {@code n} samples:
 * <pre>
 *   M(n) = M(n-1) + (x(n) - M(n-1)) / n
 *   S(n) = S(n-1) + (x(n) - M(n-1)) * (x(n) - M(n))
 * </pre>
 * Removing a sample {@code x} applies the inverse of both steps:
 * <pre>
 *   M(n-1) = (n * M(n) - x) / (n - 1)
 *   S(n-1) = S(n) - (x - M(n-1)) * (x - M(n))
 * </pre>
 *
 * <p>All public instance methods are {@code synchronized} on this object.
 */
public final class StandardDeviation
{
    /** Number of samples currently included. */
    private int count;

    /** Running mean of the included samples (NaN after the last sample is removed). */
    private double averageVar;

    /** Running sum of squared deviations from the mean ({@code S(n)} above); not the variance itself. */
    private double standardDeviationSum;

    /** Cached sample standard deviation {@code sqrt(S(n) / (n - 1))}; NaN when {@code count <= 1}. */
    private double standardDeviation;

    /** Creates an empty accumulator with no samples. */
    public StandardDeviation()
    {
        this(0, 0.0, 0.0);
    }

    /**
     * Creates an accumulator seeded with previously computed state.
     *
     * @param count number of samples already accumulated
     * @param standardDeviationSum sum of squared deviations from the mean for those samples
     * @param averageVar mean of those samples
     */
    public StandardDeviation(int count, double standardDeviationSum, double averageVar)
    {
        this.count = count;
        this.standardDeviationSum = standardDeviationSum;
        this.averageVar = averageVar;
        recomputeStandardDeviation();
    }

    /**
     * Returns the number of samples currently included.
     *
     * @return the sample count
     */
    public synchronized int getCount()
    {
        return count;
    }

    /** Refreshes the cached standard deviation from the current count and sum of squares. */
    private void recomputeStandardDeviation()
    {
        // The sample (Bessel-corrected) standard deviation needs at least two samples.
        standardDeviation = count > 1 ? Math.sqrt(standardDeviationSum / (count - 1)) : Double.NaN;
    }

    /**
     * Returns the current sample standard deviation (n - 1 denominator).
     *
     * <p>Despite its name, this is the standard deviation, i.e. the square root of the sample
     * variance. The name is kept for compatibility with existing callers.
     *
     * @return the sample standard deviation, or NaN when fewer than two samples are included
     */
    public synchronized double getRunningVariance()
    {
        return standardDeviation;
    }

    /**
     * Adds one sample and updates the running statistics.
     *
     * @param sample the value to add
     */
    public synchronized void addSample(double sample)
    {
        if (++count == 1)
        {
            // First sample: the mean is the sample itself and there is no spread yet.
            averageVar = sample;
            standardDeviationSum = 0.0;
        }
        else
        {
            double oldAverageVar = averageVar;
            double diff = sample - oldAverageVar;
            averageVar += diff / count;
            standardDeviationSum += diff * (sample - averageVar);
        }

        recomputeStandardDeviation();
    }

    /**
     * Removes one previously added sample and updates the running statistics.
     *
     * <p>The value is not checked against what was actually added. Passing a value that was
     * never added yields meaningless statistics.
     *
     * @param sample the value to remove
     * @throws IllegalStateException if no samples are currently included
     */
    public synchronized void removeSample(double sample)
    {
        if (count == 0)
        {
            throw new IllegalStateException("cannot remove a sample: no samples have been added");
        }

        int oldCount = count;
        double oldAverageVar = averageVar;

        if (--count == 0)
        {
            averageVar = Double.NaN;
            standardDeviationSum = Double.NaN;
        }
        else
        {
            averageVar = (oldCount * oldAverageVar - sample) / (oldCount - 1);
            // Undo this sample's contribution to the sum of squares (inverse of addSample).
            // Clamp at 0 so floating-point rounding cannot make it negative, which would make sqrt NaN.
            standardDeviationSum = Math.max(0.0,
                    standardDeviationSum - (sample - averageVar) * (sample - oldAverageVar));
        }

        recomputeStandardDeviation();
    }
}

package com.qingyuan.huake;

import static org.junit.Assert.assertEquals;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import org.junit.Test;

public class TestSD
{
    /** Absolute tolerance for comparing the incremental result with the two-pass reference. */
    private static final double EPSILON = 1e-9;

    /**
     * Adds every element of a fixed data set and checks the incremental sample standard
     * deviation against a direct two-pass computation.
     */
    @Test
    public void test()
    {
        StandardDeviation s = new StandardDeviation();

        int[] arr = {2, 4, 4, 5, 5, 6, 2, 3, 3, 6};
        for (int value : arr)
        {
            s.addSample(value);
        }

        // Reference value from the test data.
        double standardDeviation = sampleStandardDeviation(arr);
        System.out.println(standardDeviation);

        // Value computed incrementally by StandardDeviation.
        double result = s.getRunningVariance();
        System.out.println(result);

        // Compare numerically. Comparing rounded strings can hide real differences,
        // e.g. between the population and sample formulas.
        assertEquals(standardDeviation, result, EPSILON);
    }

    /**
     * Removes a sample after adding several and checks that the result matches the
     * statistics of the remaining samples.
     */
    @Test
    public void testRemoveSample()
    {
        StandardDeviation s = new StandardDeviation();
        for (int value : new int[] {1, 2, 3, 4})
        {
            s.addSample(value);
        }

        s.removeSample(4);

        assertEquals(3, s.getCount());
        assertEquals(sampleStandardDeviation(new int[] {1, 2, 3}), s.getRunningVariance(), EPSILON);
    }

    /**
     * Computes the sample standard deviation (n - 1 denominator) with the two-pass formula.
     * This is the quantity StandardDeviation reports.
     *
     * @param values the samples; at least two are needed for a finite result
     * @return the sample standard deviation
     */
    private static double sampleStandardDeviation(int[] values)
    {
        double sum = 0.0;
        for (int value : values)
        {
            sum += value;
        }
        // Floating-point division: integer division would truncate a non-integral mean.
        double average = sum / values.length;

        double sn = 0.0;
        for (int value : values)
        {
            double diff = value - average;
            sn += diff * diff;
        }
        return Math.sqrt(sn / (values.length - 1));
    }
}

阅读更多
想对作者说点什么?

博主推荐

换一批

没有更多推荐了,返回首页