基于Fruits-360水果数据集的TensorSpace神经网络3D可视化(水果识别可视化)

1. 简介

TensorSpace 是一套用于构建神经网络3D可视化应用的框架。开发者可以使用类 Keras 风格的 TensorSpace API, 轻松创建可视化网络、加载神经网络模型并在浏览器中基于已加载的模型进行3D可交互呈现。通过使用 TensorSpace,可以更直观的观察并理解基于 TensorFlow、Keras 或者 TensorFlow.js 等开发的神经网络模型。
TensorSpace 使用开发流程

图1 ensorSpace 使用开发流程

2. 项目环境要求

TensorSpace官网给出的适配TensorSpace.js 浏览器有:

浏览器版本号
Chrome64+
Firefox58+
Safari12+

TensorSpace.js 需要以下依赖库:

名称版本号
TensorFlow.js1.0.0+
Three.jsr101+
Tween.js17.2.0+

3. 安装

  1. 第一步:下载依赖库
依赖库文件备注
TensorFlow.jstf.min.js
Three.jsthree.min.js
Tween.jstween.min.js
TrackballControlsTrackballControls.js
  1. 第二步:下载 TensorSpace
    下载链接: Github.
<!-- 将”VERSION”替换成需要的版本 -->
<script src="tensorspace.min.js"></script>
  1. 第三步:在页面中引入库文件
<script src="tf.min.js"></script>
<script src="three.min.js"></script>
<script src="Tween.min.js"></script>
<script src="TrackballControls.js"></script>
<script src="tensorspace.min.js"></script>

4. 使用

  1. 第一步:模型预处理
  • 在我之前写的博客基于Fruits-360数据集构建CNN进行水果识别实验中,我们已经完成训练并得到了的.h5文件,接下来我们要做的就是将神经网络模型通过一系列过程转换至 TensorSpace 可以使用的格式,而这一过程就被称为模型预处理
    模型的预处理
图2 模型的预处理
  • 这里我们使用到的是一个名为TensorSpace-Converter的模型转换工具,它可以帮助我们快速完成 TensorSpace 预处理过程。由于我之前的实验使用的是Keras,并用它训练得到了一个Keras模型,并且其模型结构和权重保存在一个HDF5文件里,所以我们编写一个bash脚本来进行模型转化。
// An highlighted block
#!/usr/bin/env bash
tensorspacejs_converter \
    --input_model_from="keras" \
    --input_model_format="topology_weights_combined" \
    --output_layer_names="conv2d_1,max_pooling2d_1,conv2d_2,max_pooling2d_2,conv2d_3,max_pooling2d_3,conv2d_4,max_pooling2d_4,flatten_1,dense_1,dense_2" \
    ./model/model_demo.h5 \
    ./convertedModel\

  • 以上 TensorSpace-Converter预处理脚本将会在 convertedModel 文件夹中生成经过预处理的模型:
    (1)一份 model.json 文件:包含所得到的模型结构信息(包括中间层输出)。
    (2)一些权重文件:包含模型训练所得到的权重信息。权重文件的数量取决于模型的结构。
    模型转换后
图3 模型转后的结果

在这里插入图片描述

图4 将模型 Layer 名取出并设置 output_layer_names
  • 我的实验中构建的神经网络模型结构如下所示:
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 64, 64, 16)        448       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 32, 32, 16)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 32)        4640      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 32)        9248      
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 8, 8, 32)          0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 64)          18496     
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 4, 4, 64)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               262400    
_________________________________________________________________
dense_2 (Dense)              (None, 131)               33667     
=================================================================
Total params: 328,899
Trainable params: 328,899
Non-trainable params: 0
_________________________________________________________________
  1. 第二步:使用 TensorSpace 可视化模型
  • 载入并可视化
    通过 TensorSpace API 构建 TensorSpace 可视化模型。
let modelContainer = document.getElementById("container");
            let model = new TSP.models.Sequential( modelContainer );

    		model.add( new TSP.layers.RGBInput({ shape: [64, 64, 3] }) );
    		model.add( new TSP.layers.Conv2d({ kernelSize: 3, filters: 16, strides: 1 }));
    		model.add( new TSP.layers.Pooling2d({ poolSize: [2, 2], strides: [2, 2] }) );
    		model.add( new TSP.layers.Conv2d({ kernelSize: 3, filters: 32, strides: 1 }) );
    		model.add( new TSP.layers.Pooling2d({ poolSize: [2, 2], strides: [2, 2] }) );
            model.add( new TSP.layers.Conv2d({ kernelSize: 3, filters: 32, strides: 1 }) );
    		model.add( new TSP.layers.Pooling2d({ poolSize: [2, 2], strides: [2, 2] }) );
    		model.add( new TSP.layers.Conv2d({ kernelSize: 3, filters: 64, strides: 1 }) );
    		model.add( new TSP.layers.Pooling2d({ poolSize: [2, 2], strides: [2, 2] }) );
    		model.add( new TSP.layers.Dense({ units: 256 }) );
    		model.add( new TSP.layers.Dense({ units: 131 }) );
    		model.add( new TSP.layers.Output1d({
                units :  131 ,
    			outputs: ['Apple Braeburn', 'Apple Crimson Snow', 'Apple Golden 1', 'Apple Golden 2', 'Apple Golden 3',
                    'Apple Granny Smith', 'Apple Pink Lady', 'Apple Red 1', 'Apple Red 2', 'Apple Red 3',
                    'Apple Red Delicious', 'Apple Red Yellow 1', 'Apple Red Yellow 2', 'Apricot', 'Avocado',
                    'Avocado ripe', 'Banana', 'Banana Lady Finger', 'Banana Red', 'Beetroot',
                    'Blueberry', 'Cactus fruit', 'Cantaloupe 1', 'Cantaloupe 2', 'Carambula',
                    'Cauliflower', 'Cherry 1', 'Cherry 2', 'Cherry Rainier', 'Cherry Wax Black',
                    'Cherry Wax Red', 'Cherry Wax Yellow', 'Chestnut', 'Clementine', 'Cocos',
                    'Corn', 'Corn Husk', 'Cucumber Ripe', 'Cucumber Ripe 2', 'Dates',
                    'Eggplant', 'Fig', 'Ginger Root', 'Granadilla', 'Grape Blue',
                    'Grape Pink', 'Grape White', 'Grape White 2', 'Grape White 3', 'Grape White 4',
                    'Grapefruit Pink', 'Grapefruit White', 'Guava', 'Hazelnut', 'Huckleberry',
                    'Kaki', 'Kiwi', 'Kohlrabi', 'Kumquats', 'Lemon',
                    'Lemon Meyer', 'Limes', 'Lychee', 'Mandarine',
                    'Mango', 'Mango Red', 'Mangostan', 'Maracuja', 'Melon Piel de Sapo',
                    'Mulberry', 'Nectarine', 'Nectarine Flat', 'Nut Forest', 'Nut Pecan',
                    'Onion Red', 'Onion Red Peeled', 'Onion White', 'Orange', 'Papaya',
                    'Passion Fruit', 'Peach', 'Peach 2', 'Peach Flat', 'Pear',
                    'Pear 2', 'Pear Abate', 'Pear Forelle', 'Pear Kaiser', 'Pear Monster',
                    'Pear Red', 'Pear Stone', 'Pear Williams', 'Pepino', 'Pepper Green',
                    'Pepper Orange', 'Pepper Red', 'Pepper Yellow', 'Physalis', 'Physalis with Husk',
                    'Pineapple', 'Pineapple Mini', 'Pitahaya Red', 'Plum', 'Plum 2',
                    'Plum 3', 'Pomegranate', 'Pomelo Sweetie', 'Potato Red', 'Potato Red Washed',
                    'Potato Sweet', 'Potato White', 'Quince', 'Rambutan', 'Raspberry',
                    'Redcurrant', 'Salak', 'Strawberry', 'Strawberry Wedge', 'Tamarillo',
                    'Tangelo', 'Tomato 1', 'Tomato 2', 'Tomato 3', 'Tomato 4',
                    'Tomato Cherry Red', 'Tomato Heart', 'Tomato Maroon', 'Tomato not Ripened', 'Tomato Yellow',
                    'Walnut', 'Watermelon']

    		}) );
  • 载入经过 TensorSpace-Converter 预处理的模型,然后将模型进行初始化:
model.load({
    			type: "tfjs",
    			url: "convertedModel/model.json",
    			onComplete: function() {
    				console.log( "\"Hello World!\" from TensorSpace Loader." );
    			}
    		});

    		model.init( function() {
                $.ajax({
    				url: "json/banana_107_100.json",
    				type: 'GET',
    				async: true,
    				dataType: 'json',
    				success: function (d) {
    					model.predict( d);
    					console.log( d);
    				}
    			});
            } );
  1. 可视化结果展示
    展示结果(以banana图像为例)可进行拖拽、展开、旋转,可以详细地了解到神经网络的每一层。
    可视化1
图5 可视化展示1

可视化2

图6 可视化展示2

可视化3

图7 可视化展示3

可视化4

图8 可视化展示4

预测结果

图9 可视化展示5

最后一张图(图9)展示的即为我们的TensorSpace可视化最终的预测结果,我们的输入为banana的一张图片(已转化为banana_107_100.json),可以看到最终的Output层将其输出为banana这一分类。

5. 结语

通过基于 TensorSpace 所开发的3D可视化神经网络模型实例,我们可以体验不同的可交互模型,包括但不限于:物体分类、物体探测、图片生成等。通过展示这些模型实例,我们能更好、更直观地体现 TensorSpace 的应用场景、操作方法以及展示效果。

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
1. 数据集介绍 fruits 360是一个开源的水果图像数据集,包含了75种不同的水果,共约8万张图片。每种水果的图片数量不同,最多的是苹果(约7,000张),最少的是柠檬(约200张)。数据集中的图片都是经过调整大小和中心裁剪的,大小为100x100像素。数据集中的每种水果都有多个变体,例如不同成熟度的香蕉、不同颜色的苹果等等。 2. 算法设计 本算法采用卷积神经网络(CNN)进行图像分类。CNN是一种特殊的神经网络,可以自动提取图像中的特征,并将其用于分类。CNN的核心是卷积层和池化层,可以有效地减少参数数量,从而避免过拟合现象。此外,本算法还采用了数据增强技术,对训练集进行随机旋转、翻转、缩放等操作,以增加模型的鲁棒性。 3. 算法实现 本算法使用PyTorch框架进行实现。具体实现过程如下: 3.1 数据预处理 将fruits 360数据集下载到本地,并将其分为训练集和测试集。使用PyTorch提供的transforms模块对数据进行预处理,包括调整大小、随机旋转、随机水平翻转、随机竖直翻转、随机裁剪等操作。为了防止过拟合,训练集还进行了随机缩放操作。最终得到了训练集和测试集的数据加载器。 3.2 网络设计 本算法采用了一个简单的卷积神经网络,包括3个卷积层、3个池化层和3个全连接层。卷积层的卷积核大小为3x3,步长为1,补零为1,激活函数为ReLU;池化层的池化核大小为2x2,步长为2;全连接层的输出大小为75,即水果的种类数。具体网络结构如下: Conv2d(3, 32, 3, padding=1) ReLU(inplace=True) MaxPool2d(2, 2) Conv2d(32, 64, 3, padding=1) ReLU(inplace=True) MaxPool2d(2, 2) Conv2d(64, 128, 3, padding=1) ReLU(inplace=True) MaxPool2d(2, 2) Flatten() Linear(128 * 12 * 12, 512) ReLU(inplace=True) Linear(512, 256) ReLU(inplace=True) Linear(256, 75) 3.3 模型训练 采用交叉熵损失函数和随机梯度下降(SGD)优化器进行模型训练。初始学习率为0.01,每20个epoch衰减一次为原来的0.1。训练过程中,每个epoch会计算训练集和测试集的损失和准确率,并将结果保存到日志文件中。 4. 实验结果 经过100个epoch的训练,本算法在测试集上的准确率达到了96.8%。部分预测结果如下图所示: ![image](https://github.com/ShiniuPython/fruit_classification/blob/master/result.png) 可以看到,本算法在大多数情况下都能正确识别水果的种类。但是有些水果的不同变体之间相似度较高,如橙子和柠檬,有时候难以区分。此外,本算法对于水果的形状、颜色等变化较大的情况下也有一定的识别误差。 5. 总结 本算法采用了卷积神经网络进行图像分类,通过数据增强技术提高了模型的鲁棒性。实验结果表明,本算法可以有效地识别大多数水果的种类。但是,对于一些相似度较高的水果和变化较大的水果,还需要进一步改进。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值