昇思25天学习打卡营第10天 |昇思MindSpore 基于 MobileNetv2 的垃圾分类实验

一、实验目的与环境

序号内容详细描述
1实验目的需要对以下知识点进行熟悉掌握:
- 垃圾分类应用代码的编写(Python 语言)
- Linux 操作系统的基本使用
- 使用 atc 命令进行模型转换的基本操作
2实验环境本案例支持 win_x86Linux 系统,CPU/GPU/Ascend 均可运行。在动手进行实践之前,确保您已经正确安装了 MindSpore。不同平台下的环境准备请参考《MindSpore 环境搭建实验手册》。

二、MobileNetv2 模型原理介绍

MobileNet 网络是由 Google 团队于 2017 年提出的专注于移动端、嵌入式或 IoT 设备的轻量级 CNN 网络。相比于传统的卷积神经网络,MobileNet 网络使用深度可分离卷积(Depthwise Separable Convolution)的思想,在准确率小幅度降低的前提下,大大减小了模型参数与运算量,并引入宽度系数 α 和分辨率系数 β 使模型满足不同应用场景的需求。

由于 MobileNet 网络中 ReLU 激活函数处理低维特征信息时会存在大量的丢失,所以 MobileNetV2 网络提出使用倒残差结构(Inverted residual block)和线性瓶颈(Linear Bottlenecks)来设计网络,以提高模型的准确率,且优化后的模型更小。

Inverted residual block 结构是先使用 1x1 卷积进行升维,然后使用 3x3 的 Depthwise 卷积,最后使用 1x1 的卷积进行降维,与 Residual block 结构相反。Residual block 是先使用 1x1 的卷积进行降维,然后使用 3x3 的卷积,最后使用 1x1 的卷积进行升维。

说明:详细内容可参见 MobileNetV2 论文。

三、数据处理

数据准备

MobileNetV2 的代码默认使用 ImageFolder 格式管理数据集,每一类图片整理成单独的一个文件夹,数据集结构如下:

└─ImageFolder
  ├─train
  │   ├─class1Folder
  │   └─.....
  └─eval
      ├─class1Folder
      └─.....

数据加载

将模块导入,具体如下:

import math
import numpy as np
import os
import random

from matplotlib import pyplot as plt
from easydict import EasyDict
from PIL import Image
import mindspore.nn as nn
from mindspore import ops as P
from mindspore.ops import add
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision as C
import mindspore.dataset.transforms as C2
import mindspore as ms
from mindspore import set_context, nn, Tensor, load_checkpoint, save_checkpoint, export
from mindspore.train import Model
from mindspore.train import Callback, LossMonitor, ModelCheckpoint, CheckpointConfig

配置后续训练、验证、推理用到的参数:

# 垃圾分类数据集标签,以及用于标签映射的字典。
garbage_classes = {
    '干垃圾': ['贝壳', '打火机', '旧镜子', '扫把', '陶瓷碗', '牙刷', '一次性筷子', '脏污衣服'],
    '可回收物': ['报纸', '玻璃制品', '篮球', '塑料瓶', '硬纸板', '玻璃瓶', '金属制品', '帽子', '易拉罐', '纸张'],
    '湿垃圾': ['菜叶', '橙皮', '蛋壳', '香蕉皮'],
    '有害垃圾': ['电池', '药片胶囊', '荧光灯', '油漆桶']
}

class_cn = ['贝壳', '打火机', '旧镜子', '扫把', '陶瓷碗', '牙刷', '一次性筷子', '脏污衣服',
            '报纸', '玻璃制品', '篮球', '塑料瓶', '硬纸板', '玻璃瓶', '金属制品', '帽子', '易拉罐', '纸张',
            '菜叶', '橙皮', '蛋壳', '香蕉皮',
            '电池', '药片胶囊', '荧光灯', '油漆桶']

class_en = ['Seashell', 'Lighter', 'Old Mirror', 'Broom', 'Ceramic Bowl', 'Toothbrush', 'Disposable Chopsticks', 'Dirty Cloth',
            'Newspaper', 'Glassware', 'Basketball', 'Plastic Bottle', 'Cardboard', 'Glass Bottle', 'Metalware', 'Hats', 'Cans', 'Paper',
            'Vegetable Leaf', 'Orange Peel', 'Eggshell', 'Banana Peel',
            'Battery', 'Tablet capsules', 'Fluorescent lamp', 'Paint bucket']

index_en = {'Seashell': 0, 'Lighter': 1, 'Old Mirror': 2, 'Broom': 3, 'Ceramic Bowl': 4, 'Toothbrush': 5, 'Disposable Chopsticks': 6, 'Dirty Cloth': 7,
            'Newspaper': 8, 'Glassware': 9, 'Basketball': 10, 'Plastic Bottle': 11, 'Cardboard': 12, 'Glass Bottle': 13, 'Metalware': 14, 'Hats': 15, 'Cans': 16, 'Paper': 17,
            'Vegetable Leaf': 18, 'Orange Peel': 19, 'Eggshell': 20, 'Banana Peel': 21,
            'Battery': 22, 'Tablet capsules': 23, 'Fluorescent lamp': 24, 'Paint bucket': 25}

# 训练超参
config = EasyDict({
    "num_classes": 26,
    "image_height": 224,
    "image_width": 224,
    "backbone_out_channels": 1280,
    "batch_size": 16,
    "eval_batch_size": 8,
    "epochs": 10,
    "lr_max": 0.05,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "save_ckpt_epochs": 1,
    "dataset_path": "./data_en",
    "class_index": index_en,
    "pretrained_ckpt": "./mobilenetV2-200_1067.ckpt"
})

数据预处理操作

利用 ImageFolderDataset 方法读取垃圾分类数据集,并整体对数据集进行处理。读取数据集时指定训练集和测试集,首先对整个数据集进行归一化,修改图像频道等预处理操作。然后对训练集的数据依次进行 RandomCropDecodeResizeRandomHorizontalFlipRandomColorAdjustshuffle 操作,以增加训练数据的丰富度;对测试集进行 DecodeResizeCenterCrop 等预处理操作;最后返回处理后的数据集。

def create_dataset(dataset_path, config, training=True, buffer_size=1000):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        config(struct): the config of train and eval in different platform.

    Returns:
        train_dataset, val_dataset
    """
    data_path = os.path.join(dataset_path, 'train' if training else 'test')
    ds = de.ImageFolderDataset(data_path, num_parallel_workers=4, class_indexing=config.class_index)
    resize_height = config.image_height
    resize_width = config.image_width

    normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
    change_swap_op = C.HWC2CHW()
    type_cast_op = C2.TypeCast(mstype.int32)

    if training:
        crop_decode_resize = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
        horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
        color_adjust = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)

        train_trans = [crop_decode_resize, horizontal_flip_op, color_adjust, normalize_op, change_swap_op]
        train_ds = ds.map(input_columns="image", operations=train_trans, num_parallel_workers=4)
        train_ds = train_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)

        train_ds = train_ds.shuffle(buffer_size=buffer_size

)
        train_ds = train_ds.batch(config.batch_size, drop_remainder=True)

        return train_ds
    else:
        decode_op = C.Decode()
        resize_op = C.Resize((256, 256))
        center_crop = C.CenterCrop(resize_height)

        eval_trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]
        eval_ds = ds.map(input_columns="image", operations=eval_trans, num_parallel_workers=4)
        eval_ds = eval_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)

        eval_ds = eval_ds.batch(config.batch_size, drop_remainder=True)

        return eval_ds

四、模型构建与训练

网络结构

使用 MindSpore 框架中的 MobileNetV2 模型。可以使用 MindSpore 提供的预训练模型,具体代码如下:

from mindvision.classification.models import mobilenet_v2

def define_net(config, num_classes=26):
    """
    define MobilenetV2 network

    Args:
        config(struct): the config of train and eval in different platform.
        num_classes(int): the class number of dataset

    Returns:
        net
    """
    net = mobilenet_v2(num_classes=num_classes)

    if config.pretrained_ckpt:
        # Load the pre-trained checkpoint
        param_dict = load_checkpoint(config.pretrained_ckpt)
        load_param_into_net(net, param_dict)

    return net

优化器与损失函数

定义优化器和损失函数,使用交叉熵损失函数和动量优化器。具体代码如下:

def define_optimizer_and_loss(net, config):
    """
    define optimizer and loss function

    Args:
        net(Network): the neural network
        config(struct): the config of train and eval in different platform.

    Returns:
        optimizer, loss
    """
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    lr = nn.dynamic_lr.cosine_decay_lr(0.01, config.lr_max, config.epochs, config.batch_size, decay_epoch=100)
    opt = nn.Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay)

    return opt, loss

模型训练

使用定义好的数据集、模型、优化器和损失函数进行训练,具体代码如下:

def train_model(config):
    """
    train the MobilenetV2 model

    Args:
        config(struct): the config of train and eval in different platform.

    Returns:
        None
    """
    train_ds = create_dataset(config.dataset_path, config, training=True)
    eval_ds = create_dataset(config.dataset_path, config, training=False)

    # Define the MobileNetV2 network
    net = define_net(config, config.num_classes)
    opt, loss = define_optimizer_and_loss(net, config)

    # Define the model
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'accuracy'})

    # Define the callbacks
    loss_cb = LossMonitor()
    ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_epochs * train_ds.get_dataset_size(), keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix='mobilenetv2', config=ckpt_config)

    # Train the model
    model.train(config.epochs, train_ds, callbacks=[loss_cb, ckpt_cb], dataset_sink_mode=False)

    # Evaluate the model
    acc = model.eval(eval_ds, dataset_sink_mode=False)
    print("accuracy: ", acc)

保存模型

训练结束后,保存模型参数。可以使用 MindSpore 提供的 save_checkpoint 函数,具体代码如下:

def save_model(net, config):
    """
    save the MobilenetV2 model

    Args:
        net(Network): the neural network
        config(struct): the config of train and eval in different platform.

    Returns:
        None
    """
    save_checkpoint(net, os.path.join(config.dataset_path, 'mobilenetv2.ckpt'))

模型推理

加载保存的模型,进行模型推理,具体代码如下:

def infer_model(config, image_path):
    """
    infer the MobilenetV2 model

    Args:
        config(struct): the config of train and eval in different platform.
        image_path(string): the path of input image

    Returns:
        None
    """
    net = define_net(config, config.num_classes)
    param_dict = load_checkpoint(os.path.join(config.dataset_path, 'mobilenetv2.ckpt'))
    load_param_into_net(net, param_dict)

    # Preprocess the input image
    img = Image.open(image_path)
    img = img.resize((config.image_height, config.image_width))
    img = np.array(img).astype(np.float32)
    img = (img - [0.485*255, 0.456*255, 0.406*255]) / [0.229*255, 0.224*255, 0.225*255]
    img = img.transpose((2, 0, 1))
    img = img[np.newaxis, :]

    # Convert to Tensor
    input_tensor = Tensor(img, mstype.float32)

    # Infer the model
    net.set_train(False)
    output = net(input_tensor)
    pred = np.argmax(output.asnumpy(), axis=1)

    print("Predicted class: ", class_en[pred[0]])

五、总结

通过本次实验,了解了 MobileNetV2 模型的基本原理和结构,学习了数据预处理、模型训练、模型推理等操作步骤,掌握了 MindSpore 框架在移动设备端垃圾分类中的应用。实验过程涵盖了从数据准备、模型构建、模型训练到模型推理的全过程,对移动端轻量级模型的应用有了更深的理解和实践体验。

实验中遇到的问题及解决方法

  1. 数据集格式问题:初始数据集格式不符合要求,经过调整后使用 ImageFolder 格式管理数据集。
  2. 模型训练过程中内存不足:通过调整 batch size 和模型参数,降低内存占用。

通过这些问题的解决,进一步加深了对数据预处理和模型训练过程的理解,提高了处理实际问题的能力。

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 以下是一个简单的微信小程序制作学习计划打卡记录页面的代码。 在 wxml 文件中,我们可以设置页面的布局和样式,以及与用户进行交互的组件,例如按钮、输入框等。 ```html <view class="container"> <view class="header"> <text class="title">学习计划打卡记录</text> </view> <view class="card"> <view class="card-heading"> <text class="card-title">Day 1</text> <text class="card-date">2021-01-01</text> </view> <view class="card-body"> <text class="card-label">学习时间:</text> <input class="card-input" type="number" placeholder="请输入学习时间(单位:小时)" /> </view> <view class="card-footer"> <button class="card-button">提交</button> </view> </view> <!-- Day 2、Day 3 等的卡片同上 --> </view> ``` 在 wxss 文件中,我们可以设置页面的样式,例如字体、颜色、边距等。 ```css .container { padding: 20rpx; } .header { text-align: center; margin-bottom: 20rpx; } .title { font-size: 36rpx; font-weight: bold; color: #333333; } .card { background-color: #ffffff; border-radius: 10rpx; box-shadow: 0 2rpx 6rpx rgba(0, 0, 0, 0.1); margin-bottom: 20rpx; } .card-heading { display: flex; justify-content: space-between; align-items: center; padding: 20rpx; border-bottom: 1rpx solid #f2f2f2; } .card-title { font-size: 28rpx; font-weight: bold; color: #333333; } .card-date { font-size: 24rpx; color: #999999; } .card-body { display: flex; align-items: center; padding: 20rpx; } .card-label { font-size: 28rpx; color: #333333; } .card-input { flex: 1; margin-left: 20rpx; font-size: 28rpx; color: #999999; border: none; } .card-footer { text-align: center; padding: 20rpx; } .card-button { font-size: 28rpx; color: #ffffff; background-color: #007aff; border-radius: 5rpx; padding: 10rpx 20rpx; } ``` 在 js 文件中,我们可以编写页面的逻辑代码。 ```javascript Page({ data: { days: [ { title: 'Day 1', date: '2021-01-01', time: null }, { title: 'Day 2', date: '2021-01-02', time: null }, { title: 'Day 3', date: '2021-01-03', time: null } // Day 4、Day 5 等的数据同上 ] }, handleInput: function(event) { // 获取输入框的值 let index = event.currentTarget.dataset.index; let value = event.detail.value; // 更新数据 let days = this.data.days; days[index].time = value; this.setData({ days: days }); }, handleSubmit: function(event) { // 获取提交按钮的索引 let index = event.currentTarget.dataset.index; // 获取对应的数据 let days = this.data.days; let day = days[index]; // 判断是否填写了学习时间 if (day.time === null || day.time === '') { wx.showToast({ title: '请填写学习时间', icon: 'none' }); return; } // 提交数据 wx.cloud.callFunction({ name: 'submit', data: { date: day.date, time: day.time }, success: res => { wx.showToast({ title: '提交成功' }); }, fail: err => { wx.showToast({ title: '提交失败', icon: 'none' }); } }); } }) ``` 以上代码是一个简单的微信小程序制作学习计划打卡记录页面的示例,仅供参考。具体实现方式可能因个人需求而异。 ### 回答2: 微信小程序制作学习计划打卡记录页面包含前几学习时间的全部代码如下: 首先,在小程序中创建一个page,命名为"studyRecord",在studyRecord.json文件中进行配置,设置"navigationBarTitleText"为"学习打卡",并设置"usingComponents"引入相关组件: ``` { "navigationBarTitleText": "学习打卡", "usingComponents": {} } ``` 接下来,在studyRecord.wxml文件中编写页面结构,包括一个日期选择器和一个列表用于展示打卡记录: ``` <view class="container"> <view class="header"> <picker mode="date" bindchange="dateChange"> <view class="date-picker">{{ currentDate }}</view> </picker> </view> <view class="record-list"> <block wx:for="{{ studyRecords }}" wx:key="index"> <view class="record-item"> <view class="item-date">{{ item.date }}</view> <view class="item-duration">{{ item.duration }}</view> </view> </block> </view> </view> ``` 我们在studyRecord.js文件中定义相关的事件处理函数和数据: ``` Page({ data: { currentDate: '', // 当前选择的日期 studyRecords: [] // 学习打卡记录 }, onLoad: function () { // 获取最近几学习打卡记录 this.getStudyRecords(); }, dateChange: function (event) { this.setData({ currentDate: event.detail.value }); // 根据选择日期的变化更新学习打卡记录 this.getStudyRecords(); }, getStudyRecords: function () { // 根据当前日期获取学习打卡记录,假设获取到的数据格式为[{ date: '2022/01/01', duration: '2小时' }, ...] // 可以通过调用接口或其他方式获取数据 const currentDate = this.data.currentDate; const studyRecords = this.getStudyRecordsByDate(currentDate); this.setData({ studyRecords: studyRecords }); }, getStudyRecordsByDate: function (date) { // 根据日期获取学习打卡记录的逻辑实现 // ... return studyRecords; // 返回按日期查询到的学习打卡记录 } }) ``` 在studyRecord.wxss文件中定义样式: ``` .container { padding: 10px; } .header { margin-bottom: 10px; } .date-picker { font-size: 18px; color: #333; padding: 10px; background-color: #f5f5f5; border-radius: 4px; text-align: center; } .record-list { background-color: #fff; border-radius: 4px; } .record-item { padding: 10px; border-bottom: solid 1px #eee; } .item-date { font-size: 14px; color: #666; } .item-duration { font-size: 16px; color: #333; } ``` 这样,一个包含前几学习时间的微信小程序制作学习计划打卡记录页面的代码就完成了。 ### 回答3: 要制作微信小程序的学习计划打卡记录页面,可以按照以下步骤进行: 1. 首先,需要在微信开发者工具中创建一个新的小程序项目,并在app.json文件中配置页面路由信息。 2. 在项目的根目录下创建一个新的文件夹,用于存放页面相关的文件,比如study-record文件夹。 3. 在study-record文件夹中创建一个study-record.wxml文件用于编写页面的结构。 4. 在study-record文件夹中创建一个study-record.wxss文件用于编写页面的样式。 5. 在study-record文件夹中创建一个study-record.js文件用于编写页面的逻辑代码。 6. 在study-record.js中定义一个数据对象,用于存储前几学习时间。可以使用数组来存储每一学习时间,比如每个元素都是一个包含日期和学习时间的对象。 7. 在study-record.js中编写一个函数来获取前几学习时间。可以使用Date对象和相关的方法来计算前几的日期,然后根据日期从数据对象中获取对应的学习时间。 8. 在study-record.js中编写一个函数来更新学习时间。可以通过用户输入的方式来更新某一学习时间,并将更新后的数据保存到数据对象中。 9. 在study-record.wxml中使用wx:for循环来遍历数据对象中的学习时间,并将日期和学习时间显示在页面上。 10. 在study-record.wxml中添加一个按钮,用于触发更新学习时间的函数。 11. 在study-record.js中监听按钮的点击事件,并在点击时触发更新学习时间的函数。 12. 在study-record.wxss中设置页面的样式,比如学习时间的字体大小、颜色等。 通过以上步骤,就可以完成微信小程序的学习计划打卡记录页面的制作。在页面中包含了前几学习时间,并提供了更新学习时间的功能。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值