笔者想在JAVA项目中做机器学习的分类想使用循环神经网络的时候苦于没有找到开源的代码,最后终于找到lipiji所写的LSTM和GRU,项目GitHub链接在这:项目GitHub地址,但是这个项目的demo只是简单的做了一个文本序列的预测,无法达到自己做分类的目的,于是笔者新写了一个demo来实现分类的目的,这里所使用的数据集是Iris。Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。(来源:百度百科)点击下载Iris数据集 没有积分的也可以自己去找不需要积分的数据集。
数据预处理:首先将数据集里的花的类别修改成0,1,2三类,然后将每类中取15条数据共45条做测试集,余下105个做训练集分别存在两个文件中。新写一个类放在com.lipiji.mllib.rnn.gru包中,这里的输出层有三个节点,代表三个类别。笔者这里采用的GRU实验,要做LSTM的话将GRU类改成Cell类即可。测试代码如下:
package com.lipiji.mllib.rnn.gru;
import com.lipiji.mllib.layers.MatIniter;
import com.lipiji.mllib.rnn.lstm.Cell;
import com.lipiji.mllib.rnn.lstm.LSTM;
import com.lipiji.mllib.utils.LossFunction;
import org.jblas.DoubleMatrix;