H2O中添加算法-模型训练篇

H2O中添加算法-模型训练篇

准备知识:
1.从githup上下载源代码进行构建。并在本地编译出H2o.jar,运行后能够正常访问;
2.了解H2O启动过程,加载静态资源过程;
3.了解模型构建过程,断点跟踪,初始化模型(ModelBuilder)此类作为所有自建模型的父类。
在这里插入图片描述

  1. 定义模型和算法参数的初始化:
    模型和算法参数定义(前端传输模型参数):在hex.schemas下定义自己的模型和算法参数类,用于算法和模型中的超参数(自定义值)的设置。在这里插入图片描述
    在这里插入图片描述

接收方式:在MyAlgorithmParamters中定义和hex.schames.MyAlgorithmV3中的属性名称前面加上下横线,名称一致即可接收。

在这里插入图片描述
在这里插入图片描述

模型参数输出:
在这里插入图片描述在这里插入图片描述

猜测(源码码没看懂):通过框架调用继承至MRTask类的子类中的map/reduce方法(自动调用)来实现对maxs的初始化,进而将maxs填充入模型的输出参数中,输出到页面中进行展示。
在这里插入图片描述
在这里插入图片描述

页面当中的展示效果:
在这里插入图片描述
在这里插入图片描述

5.使用java代码实现模型算法(已有实现代码,这是基础),这里只展示如何将实现好的代码整合进入H2O中,同时模型的训练过程也要自己实现。
a.初始化模型所需要的框架结构代码,在h2o-algos下,hex包下新建一个包,如my
在这里插入图片描述

新建好包后,新建两个空类,类继承ModelBuilder类。

在这里插入图片描述

package hex.my;

import hex.ModelBuilder;
import hex.ModelCategory;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import java.util.Arrays;
public class MyAlgorithm extends ModelBuilder<MyAlgorithmModel, MyAlgorithmModel.MyAlgorithmParameters, MyAlgorithmModel.MyAlgorithmOutput> {
    //从http请求中调用
    public MyAlgorithm(MyAlgorithmModel.MyAlgorithmParameters parms) { super(parms);init(false); }
    public MyAlgorithm(boolean startup_once) { super(new MyAlgorithmModel.MyAlgorithmParameters(),startup_once); }
    //是否能够下载pojo
    //@Override public boolean havePojo() { return true; }
    //是否能够下载mojo
    //@Override public boolean haveMojo() { return false; }
    /**
     * 初始化ModelBuilder,验证所有参数并准备训练框架。此调用应在子类中重写
     * @param expensive
     */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        //处理响应列,页面中设置的响应列,此算法是否具有响应列
        if (_response != null) {
            //判断是否为分类列,如果不是抛出错误信息
            if (!_response.isCategorical()) error("_response", "Response must be a categorical column");
            //如果列仅包含常量值且不包含NAs,则为True
            else if (_response.isConst()) error("_response", "Response must have at least two unique categorical levels");
        }
        //如果没有错误发生,checkMemoryFootPrint确保最终的模型能够放入内存。注意:不应重写此方法
        // (而是重写checkMemoryFootPrint_impl)。
        //破坏第三方实现并不是“最终”声明。如果有必要的话,将来可能会宣布它是最终的
        if (expensive && error_count() == 0) checkMemoryFootPrint();
    }
    @Override
    protected Driver trainModelImpl() {
        return new MyAlgorithmDriver();
    }
    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[0];
    }
    //是否作为监督学习算法
    @Override
    public boolean isSupervised() {
        return false;
    }
    /* F/J{@link CountedCompleter}上支持优先级的简单包装排队。F/J队列是简单无序的(而且非常轻)排队。
     *然而,我们经常需要优先事项来避免僵局和提高有效吞吐量(例如未能快速响应{@linkTaskGetKey}可以阻塞
     * 整个节点,因为缺少数据)。所以每次尝试做低优先级的F/J工作都是从尝试工作并排出高优先级队列。
     * */
    private class MyAlgorithmDriver extends Driver {
        @Override public void computeImpl() {
            MyAlgorithmModel model = null;
            try {
                init(true);
                // The model to be built  待建模型   _job._result结果键
                model = new MyAlgorithmModel(_job._result, _parms, new MyAlgorithmModel.MyAlgorithmOutput(MyAlgorithm.this));
                /*通过{@code job{u key}编写锁{@code this.{u key},并删除任何先前的映射。
                 *如果密钥已锁定,则抛出IAE。被作业密钥锁定
                 */
                model.delete_and_lock(_job);
            // Run the main Example Loop运行主示例循环,迭代足够后停止
            // Stop after enough iterations
            for( ; model._output._iterations < _parms._max_iterations; model._output._iterations++ ) {
                if( stop_requested() ) break; // Stopped/cancelled
                //如果用键调用doAll,则调用在密钥的主节点上为每个密钥生成
                double[] maxs = new Max(_job._key).doAll(_parms.train())._maxs;
                // Fill in the model 填充模型
                model._output._maxs = maxs;
                model.update(_job);   // Update model in K/V store  K/V存储中的更新模型
                _job.update(1);// One unit of work  一个工作单元,更新为此任务完成的新任务

                StringBuilder sb = new StringBuilder();
                sb.append("Example: iter: ").append(model._output._iterations);
                //Log.info(sb);
            }
        } finally {
            if( model != null ) model.unlock(_job);
        }
    }
}
/**
 * MRTask map/reduce 分布式计算,将数据集分成块来进行map/reduce计算
 * Find max per-column  每列查找最大值
 */
private static class Max extends MRTask<Max> {
    final protected Key<Job> _jobKey;
    // IN
    // OUT
    double[] _maxs;
    private Max(Key<Job> jobKey) {
        _jobKey = jobKey;
    }
    @Override public void map(Chunk[] cs) {
        //Chunk,H2O中的块数据(数据压缩方式),便于集群处理,支持类似于数组的API操作
        //根据传入的数据大小列数创建一个数组
        _maxs = new double[cs.length];
        //初始化_maxs数组,将double类型的最大值填充入数组中
        Arrays.fill(_maxs,-Double.MAX_VALUE);
        //map函数每一列的最大值使用块相对行号加载{@code double}值。返回Double.NaN
        //如果缺少值。
        //返回给定行的双精度值,如果缺少该值,则返回NaN
        for( int col = 0; col < cs.length; col++ )
            for( int row = 0; row < cs[col]._len; row++ )
                _maxs[col] = Math.max(_maxs[col],cs[col].atd(row));
    }

    /**
     *任务
     * @param that
     */
    @Override public void reduce(Max that) {
        for( int col = 0; col < _maxs.length; col++ )
            _maxs[col] = Math.max(_maxs[col],that._maxs[col]);
    }
}

}

MyAlgorithmsModel类:
package hex.my;

import hex.Model;
import hex.ModelMetrics;
import water.Key;

public class MyAlgorithmModel extends Model<MyAlgorithmModel,MyAlgorithmModel.MyAlgorithmParameters,MyAlgorithmModel.MyAlgorithmOutput> {
/**
* Full constructor
*
* @param selfKey
* @param parms
* @param output
*/
public MyAlgorithmModel(Key selfKey, MyAlgorithmParameters parms, MyAlgorithmOutput output) {
super(selfKey, parms, output);
}

@Override
public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    return null;
}

@Override
protected double[] score0(double[] data, double[] preds) {
    return new double[0];
}

public static class MyAlgorithmParameters extends Model.Parameters {
    public String algoName() { return "MyAlgorithm"; }
    public String fullName() { return "MyAlgorithm"; }
    public String javaName() { return MyAlgorithmModel.class.getName(); }
    //控制迭代次数
    @Override public long progressUnits() { return _max_iterations; }
    public int _max_iterations;
}

/**
 * 输出参数
 */
public static class MyAlgorithmOutput extends Model.Output {
    public int _iterations;
    public double[] _maxs;
    public double[] _max=new double[]{20.0,20.2,201.2,2000000.0};
    public MyAlgorithmOutput(MyAlgorithm b) { super(b); }
}

}

b.构建模型和算法所需参数设置和参数输出,在hex.schemas包下新建两个类,需要和上面的类对应。在这里插入图片描述

MyAlgorithmV3:此类用来向页面显示模型和算法所需的超参数。

package hex.schemas;
import hex.my.MyAlgorithm;
import hex.my.MyAlgorithmModel;
import water.api.API;
import water.api.schemas3.ModelParametersSchemaV3;
public class MyAlgorithmV3 extends ModelBuilderSchema<MyAlgorithm, MyAlgorithmV3, MyAlgorithmV3.MyAlgorithmParametersV3> {
  /**
   * 输入参数
   */
  public static final class MyAlgorithmParametersV3 extends ModelParametersSchemaV3<MyAlgorithmModel.MyAlgorithmParameters, MyAlgorithmParametersV3> {
    //在界面中能够对那些参数进行设置
    static public String[] fields = new String[] {
            "training_frame","ignored_columns","max_iterations","hello"
    };
    // 输入参数,在界面上显示对参数的说明
    @API(help="最大迭代次数")  public int max_iterations;
    // 输入参数,在界面上显示对参数的说明
    @API(help="我的参数,世界你好!")  public int hello;
  }
}


MyAlgorithmModelV3:此类用作向页面显示模型计算好的需要展示的参数:
package hex.schemas;
import hex.my.MyAlgorithmModel;
import water.api.API;
import water.api.schemas3.ModelOutputSchemaV3;
import water.api.schemas3.ModelSchemaV3;
import water.api.schemas3.TwoDimTableV3;

public class MyAlgorithmModelV3 extends ModelSchemaV3<MyAlgorithmModel, MyAlgorithmModelV3, MyAlgorithmModel.MyAlgorithmParameters, MyAlgorithmV3.MyAlgorithmParametersV3, MyAlgorithmModel.MyAlgorithmOutput, MyAlgorithmModelV3.MyAlgorithmModelOutputV3> {
  /**
   * 输出参数
   */
  public static final class MyAlgorithmModelOutputV3 extends ModelOutputSchemaV3<MyAlgorithmModel.MyAlgorithmOutput, MyAlgorithmModelOutputV3> {
    // Output fields; input fields are in the parameters list
    //输出字段,在页面显示;输入字段在参数列表中
    @API(help="执行的迭代次数") public int iterations;
    @API(help="每一列最大数集合") public double[] maxs;
    @API(help="我的输出集合") public double[] max;
  }

  // TODO: I think we can implement the following two in ModelSchemaV3, using reflection on the type parameters.
  //我认为我们可以在ModelSchemaV3中使用类型参数上的反射实现以下两个。
  public MyAlgorithmV3.MyAlgorithmParametersV3 createParametersSchema() { return new MyAlgorithmV3.MyAlgorithmParametersV3(); }
  public MyAlgorithmModelOutputV3 createOutputSchema() { return new MyAlgorithmModelOutputV3(); }

//  // Version&Schema-specific filling into the impl
    //impl中特定于版本和模式的填充
//  @Override public MyAlgorithmModel createImpl() {
//    MyAlgorithmModel.MyAlgorithmParameters parms = parameters.createImpl();
//    return new MyAlgorithmModel( model_id.key(), parms, null );
//  }
}

c.上述类建好后,还需要在h2o中注册算法和注册访问的api,在当前项目resource下water.api.Schema中注册api;在hex.api.RegisterAlgos中注册算法。
在这里插入图片描述
在这里插入图片描述

编译启动:此模型算法依照H2O中的example这个例子所构建,算法实现一个很简单的功能,能够得出上传文件中的每一列中的最大值。
在这里插入图片描述
在这里插入图片描述

6.改造当前实现的这个MyAlgorithms算法,这里所实现的是通过神经网络学习识别整数的正负(参考网页:感谢转载自http://www.cppcns.com/ruanjian/java/215188.html),这个感知器神经网络比较简单,是适用于可线性划分的数据,比如一维的话正数和负数,二维的坐标象限分类;对于不可线性划分的数据无法进行正确的分类,如寻找质数等。
在这里插入图片描述

如果稍微对神经网络感兴趣的一定对这张图不陌生,这张图是神经元的结构图
X1Xm表示输入,W1Wm表示突触权值,Σ表示求和结点,Activation function表示激活函数,之后输出一个结果,具体的流程是
神经元接收到输入,每个输入都会与其相对路径上的权值相乘,到了求和结点进行求和,这里把求和结点的结果设为z :
z = X1 * W1 + X2 * W2 + X3 * W3 + … + Xm * Wm
之后将 z 传入到激活函数(这里我们称激活函数为 f)进行二分类模式识别 :
在这里插入图片描述

这里可以看出,如果 f(x) 的值大于阈值,得到分类 y = 1,反之 y = -1
注:相对于生物神经元受到刺激表示的反应,如果刺激在可接受范围之内,则神经元会抑制刺激(y = -1),如果超过范围则会兴奋(y = 1),而这个范围的分水岭就是阈值(e)。学习
我们发现,如果权值和阈值都固定的话,那么这个神经网络就没有存在的意义了,所以我们引入学习的概念,通过学习,让神经网络去修改权值和阈值,从而可以动态的修正模式识别的正确率,这才是机器学习的本质。
那么如何学习呢?当我们在使用之前我们需要提供给此网络一组样本数据(这里采取的是有教师模式学习),样本数据包括输入数据x和正确的识别结果y’。
当我们输入训练数据x得到模式识别y之后进行判断,如果 y != y’ ,则会去调整此网络的权值和阈值,调整请看公式,μ 表示学习率(修正率),update 表示需要修正值:
在这里插入图片描述

当感知器分类结果等于正确分类,update = 0,不调整网络;如果不等于正确分类,则会调整全部的权值(w)与阈值(e)
以上就是我所介绍的感知器最简单的学习流程:
输入数据->求和得到z->通过激活函数等到分类结果->分类结果与正确结果不符则调整网络。
实现代码:

public class Perceptron {
    /**
     * 学习率
     */
    private final float learnRate;
    /**
     * 学习次数
     */
    private final int studyCount;
    /**
     * 阈值
     */
    private static float e;
    /**
     * 权值
     * 因为判断整数正负只需要一条输入,所以这里只有一个权值,多条输入可以设置为数组
     */
    private static float w;
    /**
     * 每次学习的正确率
     */
    private static float[] correctRate;
    /**
     * 构造函数初始化学习率,学习次数,权值、阈值初始化为0
     * @param learnRate 学习率(取值范围 0 < learnRate < 1)
     * @param studyCount 学习次数
     */
    public Perceptron(float learnRate, int studyCount) {
        this.learnRate = learnRate;
        this.studyCount = studyCount;
        this.e = 0;
        this.w = 0;
        this.correctRate = new float[studyCount];
    }
    /**
     * 学习函数,samples 是一个包含输入数据和分类结果的二维数组,
     * samples[][0] 表示输入数据
     * samples[][1] 表示正确的分类结果
     * @param samples 训练数据
     */
    public void fit(int[][] samples) {
        int sampleLength = samples.length;
        for(int i = 0 ; i < studyCount ; i ++) {
            int errorCount = 0;
            for (int[] sample : samples) {
                float update = learnRate * (sample[1]-predict(sample[0]));
            //更新权值、阈值
            w += update * sample[0];
            e += update;

            //计算错误次数
            if (update != 0)
                errorCount++;
        }
        //计算此次学习的正确率
        correctRate[i] = 1 - errorCount * 1.0f / sampleLength;
    }
}
/**
 * 求和函数,模拟求和结点操作 输入数据 * 权值
 * @param num 输入数据
 * @return 求和结果 z
 */
private float sum(int num) {
    return num * w + e;
}
/**
 * 激活函数,通过求和结果 z 和阈值 e 进行判断
 * @param num 输入数据
 * @return 分类结果
 */
public int predict(int num) {
    return sum(num) >= 0 ? 1 : -1;
}
/**
 * 打印正确率
 */
public void printCorrectRate() {
    for (int i = 0 ; i < studyCount ; i ++)
        System.out.printf("第%d次学习的正确率 -> %.2f%%\n",i + 1,correctRate[i] * 100);
}
/**
 * 生成训练数据
 * @return 训练数据
 */
private static int[][] genStudyData() {
    //这里我们取 -100 ~ 100 之间的整数,大于0的设为模式 y = 1,反之为 y = -1
    int[][] data = new int[201][2];
    for(int i = -100 , j = 0; i <= 100 ; i ++ , j ++) {
        data[j][0] = i;
        data[j][1] = i >= 0 ? 1 : -1;
    }
    return data;
}
/**
 * 生成训练数据
 * @return 训练数据
 */
private static int[][] genStudyData2() {
    //这里我们取 1~250 之间的整数,大于125的设为模式 y = 1,反之为 y = -1
    int[][] data = new int[250][2];
    for(int i = 1 , j = 0; i <= 250 ; i ++ , j ++) {
        data[j][0] = i;
        data[j][1] = i >= 125 ? 1 : -1;
    }
    return data;
}

public static void main(String[] args) {
    //这里的学习率和训练次数可以根据情况人为调整
    Perceptron perceptron = new Perceptron(0.4f,5000);
    perceptron.fit(genStudyData());
    perceptron.printCorrectRate();
    System.out.println("e:"+e);
    System.out.println("w:"+w);
    System.out.println("correctRate:"+correctRate[499]);
    System.out.println(perceptron.predict(-1));
    System.out.println(perceptron.predict(126));
}

}

上面的代码实现的功能(通过输入的数据进行学习后,得到权值和阈值,最终判断;一个数是否为正数,正数为1,负数为-1)输出:
在这里插入图片描述

上面的代码很简单,下面我们来对我们的H2O进行改造,在H2O中加入上述算法并进行训练得到权值和阈值:
a.设置模型超参数(学习次数和学习率)
在这里插入图片描述

代码操作一如上面所讲
定义:
在这里插入图片描述

接收:
在这里插入图片描述

b.算法改动:
在这里插入图片描述

全部代码(MyAlgorithm.java):

package hex.my;

import hex.ModelBuilder;
import hex.ModelCategory;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.util.TwoDimTable;

import java.util.Arrays;

public class MyAlgorithm extends ModelBuilder<MyAlgorithmModel, MyAlgorithmModel.MyAlgorithmParameters, MyAlgorithmModel.MyAlgorithmOutput> {
    //从http请求中调用
    public MyAlgorithm(MyAlgorithmModel.MyAlgorithmParameters parms) {
        super(parms);
        init(false);
    }

    public MyAlgorithm(boolean startup_once) {
        super(new MyAlgorithmModel.MyAlgorithmParameters(), startup_once);
    }
//是否能够下载pojo
//@Override public boolean havePojo() { return true; }
//是否能够下载mojo
//@Override public boolean haveMojo() { return false; }

/**
 * 初始化ModelBuilder,验证所有参数并准备训练框架。此调用应在子类中重写
 *
 * @param expensive
 */
@Override
public void init(boolean expensive) {
    super.init(expensive);
    //处理响应列,页面中设置的响应列,此算法是否具有响应列
    if (_response != null) {
        //判断是否为分类列,如果不是抛出错误信息
        if (!_response.isCategorical()) error("_response", "Response must be a categorical column");
            //如果列仅包含常量值且不包含NAs,则为True至少有两个
        else if (_response.isConst())
            error("_response", "Response must have at least two unique categorical levels");
    }
    //如果没有错误发生,checkMemoryFootPrint确保最终的模型能够放入内存。注意:不应重写此方法
    // (而是重写checkMemoryFootPrint_impl)。
    //破坏第三方实现并不是“最终”声明。如果有必要的话,将来可能会宣布它是最终的
    if (expensive && error_count() == 0) checkMemoryFootPrint();
}

@Override
protected Driver trainModelImpl() {
    return new MyAlgorithmDriver();
}

@Override
public ModelCategory[] can_build() {
    return new ModelCategory[0];
}

//是否作为监督学习算法
@Override
public boolean isSupervised() {
    return false;
}

/* F/J{@link CountedCompleter}上支持优先级的简单包装排队。F/J队列是简单无序的(而且非常轻)排队。
 *然而,我们经常需要优先事项来避免僵局和提高有效吞吐量(例如未能快速响应{@linkTaskGetKey}可以阻塞
 * 整个节点,因为缺少数据)。所以每次尝试做低优先级的F/J工作都是从尝试工作并排出高优先级队列。
 * */
private class MyAlgorithmDriver extends Driver {
    @Override
    public void computeImpl() {
        MyAlgorithmModel model = null;
        try {
            init(true);
            // The model to be built  待建模型   _job._result结果键
            model = new MyAlgorithmModel(_job._result, _parms, new MyAlgorithmModel.MyAlgorithmOutput(MyAlgorithm.this));
            /*通过{@code job{u key}编写锁{@code this.{u key},并删除任何先前的映射。
             *如果密钥已锁定,则抛出IAE。被作业密钥锁定
             */
            model.delete_and_lock(_job);

            // Run the main Example Loop运行主示例循环,迭代足够后停止
            // Stop after enough iterations
            for (; model._output._iterations < _parms._max_iterations; model._output._iterations++) {
                if (stop_requested()) break; // Stopped/cancelled
                //从界面得到自定义的学习率
                float learn_rate = _parms._learn_rate;
                //如果用键调用doAll,则调用在密钥的主节点上为每个密钥生成
                Max max = new Max(_job._key, learn_rate).doAll(_parms.train());
                double[] maxs = max._maxs;
                //返回本次迭代的正确率
                float correctRate = max.correctRate;
                //阈值
                float e = max.e;
                //权值
                float w = max.w;
                // Fill in the model 填充模型,并将其返回到界面中
                model._output._maxs = maxs;
                model._output._w = w;
                model._output._e = e;
                model._output._correct_rate = correctRate;
                double[] dom=new double[3];
                dom[0]=(double)w;
                dom[1]=(double)e;
                dom[2]=(double)correctRate;
                /*TwoDimTable的构造函数(R行、C列)
                *@param table header表格标题
                *@param table description表说明
                *@param行标题R-dim行标题数组
                *@param colHeaders列标题的C-dim数组
                *@param colTypes列类型的C-dim数组
                *@param colFormats每个列都有printf格式字符串的C-dim数组
                *@param colHeaderForRowHeaders行标题的列标题
                *@param strCellValues String[R][C]字符串单元格值的数组,可以为空(例如,可以提供String[R][])
                *@param dblCellValues double[R][C]双单元格值数组可以为空(用emptyDuble标记-用double[R][]初始化时发生)
                */
                model._output._dom = dom;

                model.update(_job);   // Update model in K/V store  K/V存储中的更新模型
                _job.update(1);// One unit of work  一个工作单元,更新为此任务完成的新任务

                StringBuilder sb = new StringBuilder();
                sb.append("Example: iter: ").append(model._output._iterations);
                //Log.info(sb);
            }
        } finally {
            if (model != null) model.unlock(_job);
        }
    }
}

/**
 * MRTask map/reduce 分布式计算,将数据集分成块来进行map/reduce计算
 * Find max per-column  每列查找最大值
 */
private static class Max extends MRTask<Max> {
    final protected Key<Job> _jobKey;
    // IN
    // OUT
    double[] _maxs;
    /**
     * 学习率
     */
    private float learnRate;
    /**
     * 阈值
     */
    private float e;
    /**
     * 权值
     * 因为判断整数正负只需要一条输入,所以这里只有一个权值,多条输入可以设置为数组
     */
    private float w;
    /**
     * 每次学习的正确率
     */
    private float correctRate;

    private Max(Key<Job> jobKey, float learn_rate) {
        _jobKey = jobKey;
        learnRate = learn_rate;
    }

    @Override
    public void map(Chunk[] cs) {
        //Chunk,H2O中的块数据(数据压缩方式),便于集群处理,支持类似于数组的API操作
        //根据传入的数据列数创建一个数组
        _maxs = new double[cs.length];
        //初始化_maxs数组,将double类型的最小值填充入数组中
        Arrays.fill(_maxs, -Double.MAX_VALUE);
        //map函数每一列的最大值使用块相对行号加载{@code double}值。返回Double.NaN
        //如果缺少值, 返回给定行的双精度值,如果缺少该值,则返回NaN  
        // cs[col].atd(row)这个可以得到上传的二维表的值,相当于一个二维数组
        for (int col = 0; col < cs.length; col++)
            for (int row = 0; row < cs[col]._len; row++)
                _maxs[col] = Math.max(_maxs[col], cs[col].atd(row));
        /**
         * 得到每行的每一列值,用于更新权值、阈值 
         * 学习函数,cs 是一个包含输入数据和分类结果的二维数组,
         * cs[][0] 表示输入数据
         * cs[][1] 表示正确的分类结果
         * @param cs 训练数据
         */
            int errorCount = 0;
            for (int r = 0; r < cs[0]._len; r++) {
                float update = (float) (learnRate * (cs[1].atd(r) - predict((int) cs[0].atd(r))));
                System.out.println("第一列第"+r+"行:"+cs[0].atd(r));
                System.out.println("第二列第"+r+"行:"+cs[1].atd(r));
                //更新权值、阈值
                w += update * (int) cs[0].atd(r);
                e += update;
                //计算错误次数
                if (update != 0)
                    errorCount++;
            }
            //计算此次学习的正确率
            correctRate = 1 - errorCount * 1.0f / cs[0]._len;
    }
    /**
     * 求和函数,模拟求和结点操作 输入数据 * 权值
     *
     * @param num 输入数据
     * @return 求和结果 z
     */
    private float sum(int num) {
        return num * w + e;
    }
    /**
     * 激活函数,通过求和结果 z 和阈值 e 进行判断
     *
     * @param num 输入数据
     * @return 分类结果
     */
    public int predict(int num) {
        return sum(num) >= 0 ? 1 : -1;
    }
    @Override
    public void reduce(Max that) {
        for (int col = 0; col < _maxs.length; col++)
            _maxs[col] = Math.max(_maxs[col], that._maxs[col]);
    }
}

}

c.数据输出:
在这里插入图片描述

在这里插入图片描述

与这个结果一致
在这里插入图片描述

模型训练过程到此基本结束,接下来讲是设置响应列和使用训练好的模型进行预测。

待续。。。。。。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值