DL4J实战之二:鸢尾花分类

蜂信物联FastBee平台https://gitee.com/beecue/fastbee

阿里资料开源项目https://gitee.com/vip204888

百度低代码前端框架https://gitee.com/baidu/amis

OpenHarmony开源项目https://gitcode.com/openharmony

仓颉编程语言开放项目https://gitcode.com/Cangjie

import org.nd4j.evaluation.classification.Evaluation;

import org.nd4j.linalg.activations.Activation;

import org.nd4j.linalg.api.ndarray.INDArray;

import org.nd4j.linalg.dataset.DataSet;

import org.nd4j.linalg.dataset.SplitTestAndTrain;

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;

import org.nd4j.linalg.learning.config.Sgd;

import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;

/**

  • @author will (zq2599@gmail.com)

  • @version 1.0

  • @description: 鸢尾花训练

  • @date 2021/6/13 17:30

*/

@SuppressWarnings(“DuplicatedCode”)

@Slf4j

public class Iris {

public static void main(String[] args) throws Exception {

//第一阶段:准备

// 跳过的行数,因为可能是表头

int numLinesToSkip = 0;

// 分隔符

char delimiter = ‘,’;

// CSV读取工具

RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);

// 下载并解压后,得到文件的位置

String dataPathLocal = DownloaderUtility.IRISDATA.Download();

log.info(“鸢尾花数据已下载并解压至 : {}”, dataPathLocal);

// 读取下载后的文件

recordReader.initialize(new FileSplit(new File(dataPathLocal,“iris.txt”)));

// 每一行的内容大概是这样的:5.1,3.5,1.4,0.2,0

// 一共五个字段,从零开始算的话,标签在第四个字段

int labelIndex = 4;

// 鸢尾花一共分为三类

int numClasses = 3;

// 一共150个样本

int batchSize = 150; //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

// 加载到数据集迭代器中

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);

DataSet allData = iterator.next();

// 洗牌(打乱顺序)

allData.shuffle();

// 设定比例,150个样本中,百分之六十五用于训练

SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training

// 训练用的数据集

DataSet trainingData = testAndTrain.getTrain();

// 验证用的数据集

DataSet testData = testAndTrain.getTest();

// 指定归一化器:独立地将每个特征值(和可选的标签值)归一化为0平均值和1的标准差。

DataNormalization normalizer = new NormalizerStandardize();

// 先拟合

normalizer.fit(trainingData);

// 对训练集做归一化

normalizer.transform(trainingData);

// 对测试集做归一化

normalizer.transform(testData);

// 每个鸢尾花有四个特征

final int numInputs = 4;

// 共有三种鸢尾花

int outputNum = 3;

// 随机数种子

long seed = 6;

//第二阶段:训练

log.info(“开始配置…”);

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()

.seed(seed)

.activation(Activation.TANH) // 激活函数选用标准的tanh(双曲正切)

.weightInit(WeightInit.XAVIER) // 权重初始化选用XAVIER:均值 0, 方差为 2.0/(fanIn + fanOut)的高斯分布

.updater(new Sgd(0.1)) // 更新器,设置SGD学习速率调度器

.l2(1e-4) // L2正则化配置

.list() // 配置多层网络

.layer(new DenseLayer.Builder().nIn(numInputs).nOut(3) // 隐藏层

.build())

.layer(new DenseLayer.Builder().nIn(3).nOut(3) // 隐藏层

.build())

.layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) // 损失函数:负对数似然

.activation(Activation.SOFTMAX) // 输出层指定激活函数为:SOFTMAX

.nIn(3).nOut(outputNum).build())

.build();

// 模型配置

MultiLayerNetwork model = new MultiLayerNetwork(conf);

// 初始化

model.init();

// 每一百次迭代打印一次分数(损失函数的值)

model.setListeners(new ScoreIterationListener(100));

long startTime = System.currentTimeMillis();

log.info(“开始训练”);

// 训练

for(int i=0; i<1000; i++ ) {

model.fit(trainingData);

}

log.info(“训练完成,耗时[{}]ms”, System.currentTimeMillis()-startTime);

// 第三阶段:评估

// 在测试集上评估模型

Evaluation eval = new Evaluation(numClasses);

INDArray output = model.output(testData.getFeatures());

eval.eval(testData.getLabels(), output);

log.info(“评估结果如下\n” + eval.stats());

}

}

  • 编码完成后,运行main方法,可见顺利完成训练并输出了评估结果,还有混淆矩阵用于辅助分析:

在这里插入图片描述

  • 至此,咱们的第一个实战就完成了,通过经典实例体验的DL4J训练和评估的常规步骤,对重要API也有了初步认识,接下来会继续实战,接触到更多的经典实例;

你不孤单,欣宸原创一路相伴

  1. Java系列

  2. Spring系列

  3. Docker系列

  4. kubernetes系列

更多:Java进阶核心知识集

包含:JVM,JAVA集合,网络,JAVA多线程并发,JAVA基础,Spring原理,微服务,Zookeeper,Kafka,RabbitMQ,Hbase,MongoDB,Cassandra,设计模式,负载均衡,数据库,一致性哈希,JAVA算法,数据结构,加密算法,分布式缓存等等

image

高效学习视频

创一路相伴

  1. Java系列

  2. Spring系列

  3. Docker系列

  4. kubernetes系列

更多:Java进阶核心知识集

包含:JVM,JAVA集合,网络,JAVA多线程并发,JAVA基础,Spring原理,微服务,Zookeeper,Kafka,RabbitMQ,Hbase,MongoDB,Cassandra,设计模式,负载均衡,数据库,一致性哈希,JAVA算法,数据结构,加密算法,分布式缓存等等

[外链图片转存中…(img-xfIOwGg0-1725130604605)]

高效学习视频

  • 20
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值