DL4J实战之二:鸢尾花分类,java程序设计实验实训教程蔡木生

<maven.compiler.target>8</maven.compiler.target>

com.bolingcavalry

commons

${project.version}

org.projectlombok

lombok

org.nd4j

${nd4j.backend}

ch.qos.logback

logback-classic

  • 上述pom.xml有一处需要注意的地方,就是${nd4j.backend}参数的值,该值在决定了后端线性代数计算是用CPU还是GPU,本篇为了简化操作选择了CPU(因为个人的显卡不同,代码里无法统一),对应的配置就是nd4j-native;

  • 源码全部在Iris.java文件中,并且代码中已添加详细注释,就不再赘述了:

package com.bolingcavalry.classifier;

import com.bolingcavalry.commons.utils.DownloaderUtility;

import lombok.extern.slf4j.Slf4j;

import org.datavec.api.records.reader.RecordReader;

import org.datavec.api.records.reader.impl.csv.CSVRecordReader;

import org.datavec.api.split.FileSplit;

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

import org.deeplearning4j.nn.conf.layers.DenseLayer;

import org.deeplearning4j.nn.conf.layers.OutputLayer;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

import org.deeplearning4j.nn.weights.WeightInit;

import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

import org.nd4j.evaluation.classification.Evaluation;

import org.nd4j.linalg.activations.Activation;

import org.nd4j.linalg.api.ndarray.INDArray;

import org.nd4j.linalg.dataset.DataSet;

import org.nd4j.linalg.dataset.SplitTestAndTrain;

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;

import org.nd4j.linalg.learning.config.Sgd;

import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;

/**

  • @author will (zq2599@gmail.com)

  • @version 1.0

  • @description: 鸢尾花训练

  • @date 2021/6/13 17:30

*/

@SuppressWarnings(“DuplicatedCode”)

@Slf4j

public class Iris {

public static void main(String[] args) throws Exception {

//第一阶段:准备

// 跳过的行数,因为可能是表头

int numLinesToSkip = 0;

// 分隔符

char delimiter = ‘,’;

// CSV读取工具

RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);

// 下载并解压后,得到文件的位置

String dataPathLocal = DownloaderUtility.IRISDATA.Download();

log.info(“鸢尾花数据已下载并解压至 : {}”, dataPathLocal);

// 读取下载后的文件

recordReader.initialize(new FileSplit(new File(dataPathLocal,“iris.txt”)));

// 每一行的内容大概是这样的:5.1,3.5,1.4,0.2,0

// 一共五个字段,从零开始算的话,标签在第四个字段

int labelIndex = 4;

// 鸢尾花一共分为三类

int numClasses = 3;

// 一共150个样本

int batchSize = 150; //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

// 加载到数据集迭代器中

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);

DataSet allData = iterator.next();

// 洗牌(打乱顺序)

allData.shuffle();

// 设定比例,150个样本中,百分之六十五用于训练

SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training

// 训练用的数据集

DataSet trainingData = testAndTrain.getTrain();

// 验证用的数据集

DataSet testData = testAndTrain.getTest();

// 指定归一化器:独立地将每个特征值(和可选的标签值)归一化为0平均值和1的标准差。

DataNormalization normalizer = new NormalizerStandardize();

// 先拟合

normalizer.fit(trainingData);

// 对训练集做归一化

normalizer.transform(trainingData);

// 对测试集做归一化

normalizer.transform(testData);

// 每个鸢尾花有四个特征

final int numInputs = 4;

// 共有三种鸢尾花

int outputNum = 3;

// 随机数种子

long seed = 6;

//第二阶段:训练

log.info(“开始配置…”);

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()

.seed(seed)

.activation(Activation.TANH) // 激活函数选用标准的tanh(双曲正切)

.weightInit(WeightInit.XAVIER) // 权重初始化选用XAVIER:均值 0, 方差为 2.0/(fanIn + fanOut)的高斯分布

.updater(new Sgd(0.1)) // 更新器,设置SGD学习速率调度器

.l2(1e-4) // L2正则化配置

.list() // 配置多层网络

.layer(new DenseLayer.Builder().nIn(numInputs).nOut(3) // 隐藏层

自我介绍一下,小编13年上海交大毕业,曾经在小公司待过,也去过华为、OPPO等大厂,18年进入阿里一直到现在。

深知大多数Java工程师,想要提升技能,往往是自己摸索成长或者是报班学习,但对于培训机构动则几千的学费,着实压力不小。自己不成体系的自学效果低效又漫长,而且极易碰到天花板技术停滞不前!

因此收集整理了一份《2024年Java开发全套学习资料》,初衷也很简单,就是希望能够帮助到想自学提升又不知道该从何学起的朋友,同时减轻大家的负担。
img
img
img
img
img
img

既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,基本涵盖了95%以上Java开发知识点,真正体系化!

由于文件比较大,这里只是将部分目录大纲截图出来,每个节点里面都包含大厂面经、学习笔记、源码讲义、实战项目、讲解视频,并且后续会持续更新

如果你觉得这些内容对你有帮助,可以添加V获取:vip1024b (备注Java)
img

复习的面试资料

这些面试全部出自大厂面试真题和面试合集当中,小编已经为大家整理完毕(PDF版)

  • 第一部分:Java基础-中级-高级

image

  • 第二部分:开源框架(SSM:Spring+SpringMVC+MyBatis)

image

  • 第三部分:性能调优(JVM+MySQL+Tomcat)

image

  • 第四部分:分布式(限流:ZK+Nginx;缓存:Redis+MongoDB+Memcached;通讯:MQ+kafka)

image

  • 第五部分:微服务(SpringBoot+SpringCloud+Dubbo)

image

  • 第六部分:其他:并发编程+设计模式+数据结构与算法+网络

image

进阶学习笔记pdf

  • Java架构进阶之架构筑基篇(Java基础+并发编程+JVM+MySQL+Tomcat+网络+数据结构与算法

image

  • Java架构进阶之开源框架篇(设计模式+Spring+SpringMVC+MyBatis

image

image

image

  • Java架构进阶之分布式架构篇 (限流(ZK/Nginx)+缓存(Redis/MongoDB/Memcached)+通讯(MQ/kafka)

image

image

image

  • Java架构进阶之微服务架构篇(RPC+SpringBoot+SpringCloud+Dubbo+K8s)

image

image

一个人可以走的很快,但一群人才能走的更远。不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎扫码加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!
img

mg-uEGw0AhH-1712784476464)]

[外链图片转存中…(img-cn9855vO-1712784476464)]

  • Java架构进阶之微服务架构篇(RPC+SpringBoot+SpringCloud+Dubbo+K8s)

[外链图片转存中…(img-trNMBWJl-1712784476464)]

[外链图片转存中…(img-ifId5Hzx-1712784476465)]

一个人可以走的很快,但一群人才能走的更远。不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎扫码加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!
[外链图片转存中…(img-xSdLCVtY-1712784476465)]

  • 9
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值