Java实现LSTM和GRU做分类(以IRIS数据集为例)

本文介绍如何在Java项目中利用GRU神经网络对IRIS数据集进行分类。通过修改lipiji的开源代码,创建一个新的DEMO,将Iris数据集的类别转化为0,1,2,并划分为训练集和测试集。实验结果显示,45个测试样本中有44个正确分类,但未能达到100%准确率。" 132411806,19721532,使用Iceberg Java API操作Hive Catalog,"['Hive', 'Java', '数据库开发', '数据存储', '大数据处理']
摘要由CSDN通过智能技术生成

笔者想在JAVA项目中做机器学习的分类想使用循环神经网络的时候苦于没有找到开源的代码,最后终于找到lipiji所写的LSTM和GRU,项目GitHub链接在这:项目GitHub地址,但是这个项目的demo只是简单的做了一个文本序列的预测,无法达到自己做分类的目的,于是笔者新写了一个demo来实现分类的目的,这里所使用的数据集是Iris。Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。(来源:百度百科)点击下载Iris数据集 没有积分的也可以自己去找不需要积分的数据集。

数据预处理:首先将数据集里的花的类别修改成0,1,2三类,然后将每类中取15条数据共45条做测试集,余下105个做训练集分别存在两个文件中。新写一个类放在com.lipiji.mllib.rnn.gru包中,这里的输出层有三个节点,代表三个类别。笔者这里采用的GRU实验,要做LSTM的话将GRU类改成Cell类即可。测试代码如下:

package com.lipiji.mllib.rnn.gru;

import com.lipiji.mllib.layers.MatIniter;
import com.lipiji.mllib.rnn.lstm.Cell;
import com.lipiji.mllib.rnn.lstm.LSTM;
import com.lipiji.mllib.utils.LossFunction;
import org.jblas.DoubleMatrix;
  • 3
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值