深度学习-利用卷积网络识别动物

本文介绍了利用深度学习框架deeplearning4j结合Spark进行动物图片分类,数据包括熊、鸭、鹿和龟四种动物。尽管分类效果不理想,但提出了改进策略,如增加样本量、数据预处理、增强训练次数、调整模型配置及学习率等超参数。
摘要由CSDN通过智能技术生成

深度学习框架很多,我讲的是deeplearning4j,因为它能和spark结合,代码是java,虽然我java也很烂

数据源是4种动物的照片,有熊,鸭,鹿,龟

示例的分类结果不是很理想,建议我们通过以下方式提高:

1.增加照片数量

2.进行更多的数据预处理

3.增加训练次数,所有数据都训练完了才叫一次训练

4.调整模型配置

5.调整学习率,更新器,激活函数,损失函数,正则化参数等

这也是实战中深度学习问题需要面对的问题
贴出代码,dl4j几乎没什么注释,我只能根据相关资料来分析,有不对的地方还望指出

public class AnimalsClassification {
    protected static final Logger log = LoggerFactory.getLogger(AnimalsClassification.class);//通过反射获取日志名
    protected static int height = 100;//照片是100*100
    protected static int width = 100;
    protected static int channels = 3;//过滤器数量,就是输入层和几个过滤器连接,每个过滤器都按不同规则对输入层进行处理
    protected static int numExamples = 80;//80个样本
    protected static int numLabels = 4;//4个类别
    protected static int batchSize = 20;//每次处理20个样本,这80个样本分4批训练完,参数会更新4次,80个样本训练完了才是一步训练

    protected static long seed = 42;
    protected static Random rng = new Random(seed);//随机数生成器
    protected static int listenerFreq = 1;//参数更新一次,就打印一次score
    protected static int iterations = 1;//每步训练的迭代次数,正常来讲4批是一次,但是我们可以增加迭代次数让每步训练迭代更多次
    protected static int epochs = 50;//训练步数,够多的,时间应该很长
    protected static double splitTrainTest = 0.8;//80%训练,20%测试
    protected static int nCores = 2;//装载数据的队列数
    protected static boolean save = false;//不存储

    protected static String modelType = "AlexNet"; // LeNet, AlexNet or Custom but you need to fill it out//使用AlexNet网络


    public void run(String[] args) throws Exception {

        log.info("Load data....");
        /**cd
         * Data Setup -> organize and limit data file paths:
         *  - mainPath = path to image files//图片路径
         *  - fileSplit = define basic dataset split with limits on format//定义数据划分
         *  - pathFilter = define additional file load filter to limit size and balance batch content//定义额外的文件加载过滤器用来限制大小平衡批内容
         **/
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();//按文件名产生标签0,1,2,3
        File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/");//图片主路径
        FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);//把所有图片弄成一个经过shuffle的数组
        BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize);//平衡每个batch中label的数量


        /**
         * Data Setup -> train test split//划分训练和测试集
         *  - inputSplit = define train and test split
         **/
        InputSplit[] inputSplit = fileSplit.sample(pathFilter, numExamples * (1 + splitTrainTest), numExamples * (1 - splitTrainTest));//第二个参数是144,第三个是16,这里不是普通的80%,20%,内部有自己的策略,所以最终训练测试条数也跟预想的不一致

        InputSplit trainData = inputSplit[0];//训练数据68条
        InputSplit testData = inputSplit[1];//测试数据8条

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值