mindspore实现自定义CNN图像分类模型

一、数据集定义

         使用mindspore.dataset中的ImageFolderDataset接口加载图像分类数据集,ImageFolderDataset接口传入数据集文件上层目录,每个子目录分别放入不同类别的图像。使用python定义一个create_dataset函数用于创建数据集,在函数中使用mindspore.dataset.vision接口中的Decode、Resize、Normalize、HWC2CHW对图像进行解码、调整尺寸、归一化和通道变换预处理,其中Resize根据模型需要的图像大小进行设置,归一化操作可以通过设置mean和std约束范围。如将mean设置为[127.5,127.5,127.5],std设置为[255,255,255],可以将数据归一化到[0.5~0.5]范围内。

数据集加载:

import mindspore
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import context, save_checkpoint, ops, Tensor
import mindspore.dataset as ds
import mindspore.dataset.vision as CV
import mindspore.dataset.transforms as C
from mindspore import dtype as mstype


def create_dataset(data_path, batch_size=24, repeat_num=1):
    """定义数据集"""
    data_set = ds.ImageFolderDataset(data_path, num_parallel_workers=8, shuffle=True)
    image_size = [100, 100]
    mean = [127.5, 127.5, 127.5]
    std = [255., 255., 255.]
    trans = [
        CV.Decode(),
        CV.Resize(image_size),
        CV.Normalize(mean=mean, std=std),
        CV.HWC2CHW()
    ]
    # 实现数据的map映射、批量处理和数据重复的操作
    type_cast_op = C.TypeCast(mstype.int32)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    data_set = data_set.repeat(repeat_num)
    return data_set

二、定义网络结构

         定义神经网络需要使用mindspore.nn模块,使用python创建一个cnn_net类并继承nn.Cell,在init中初始化模型需要用到的各种算子,该卷积神经网络需要用到的算子分别为卷积层nn.Conv2d、激活函数nn.Relu、池化层nn.Maxpool2d、打平操作nn.Flatten、全连接层nn.Dense。这里用的自定义卷积神经网络由4层卷积+2层全连接组成,每个卷积层后接一个激活函数和最大池化层,每个池化层通过设置步长为2对特征图进行尺寸减半,因此在经过四层卷积后特征图变为输入的1/16,也就是6*6。在卷积层后接一个打平操作,将特征图从二维转换为一维,特征图打平以才能后进入全连接层,最后一层全连接层输出通道数与分类类别数一致。模型中每层输入输出通道定义如下:

卷积层1:输入通道3,输出通道8,卷积核3*3

卷积层2:输入通道8,输出通道16,卷积核3*3

卷积层3:输入通道16,输出通道32,卷积核3*3

卷积层4:输入通道32,输出通道64,卷积核3*3

全连接层1:输入288,输出128

全连接层2:输入128,输出分类数

网络实现:

class cnn_net(nn.Cell):
    """
    网络结构
    """
    def __init__(self, num_class=10, num_channel=3):
        super(cnn_net, self).__init__()
        # 定义所需要的运算
        self.conv1 = nn.Conv2d(in_channels=num_channel, out_channels=8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, 3)
        self.conv3 = nn.Conv2d(16, 32, 3)
        self.conv4 = nn.Conv2d(32, 64, 3)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(2304, 128, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(128, num_class, weight_init=Normal(0.02))

    def construct(self, x):
        # 使用定义好的运算构建前向网络
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv4(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

三、定义损失函数计算流程

         由于网络中没有带损失函数,需要单独定义一个类NetWithLoss用于计算损失,在计算损失前,需要将标签进行onehot编码,如分类标签为[0,1,2],那么标签1经过onehot转换为[0,1,0],之后将数据送入模型进行前向计算,得到logits,使用交叉熵损失函数对logits和label计算损失。

class NetWithLoss(nn.Cell):
    def __init__(self, backbone, loss_func, classes):
        super(NetWithLoss, self).__init__()
        self.backbone = backbone
        self.loss_func = loss_func
        self.classes = classes

    def construct(self, inputs, labels):
        labels = ops.one_hot(labels, self.classes,
                        Tensor(1, dtype=mindspore.float32),
                        Tensor(0, dtype=mindspore.float32))
        logits = self.backbone(inputs)
        loss = self.loss_func(logits, labels)
        return ops.mean(loss, axis=0)

四、定义训练流程

         定义一个train函数进行训练,在训练函数中首先定义迭代次数,学习率,批大小,分类数量、输入通道、训练集、验证集、模型、损失函数、优化器等,这里使用for循环进行训练迭代,在数据集迭代过程中使用nn.TrainOneStepCell进行模型训练。在每一轮训练结束后对模型进行验证,计算模型推理准确率。

         在开启训练之前可以通过设置运行环境来觉得模型在什么设备上运行。mindspore支持CPU、GPU、以及Ascend(昇腾训练加速卡),当然,不同设备需要安装对应版本的mindspore。

def train():
    # 数据路径
    epochs = 10
    lr = 0.001
    batch_size = 32
    num_classes = 2
    input_channel = 3
    ckpt_file = 'best.ckpt'
    train_data_path = "./datasets/dogs/train"
    eval_data_path = "./datasets/dogs/val"
    train_ds = create_dataset(train_data_path, batch_size)
    eval_ds = create_dataset(eval_data_path, 1)
    eval_ds_size = eval_ds.get_dataset_size()
    net = cnn_net(num_classes, input_channel)
    opt = nn.Adam(params=net.trainable_params(), learning_rate=lr)
    loss_func = nn.SoftmaxCrossEntropyWithLogits()
    loss_net = NetWithLoss(net, loss_func, num_classes)
    train_net = nn.TrainOneStepCell(loss_net, opt)
    train_net.set_train()
    argmax = ops.Argmax(axis=0)
    best_acc = 0
    best_epoch = 0
    for epoch in range(epochs):
        train_loss = 0
        # 训练
        for data in train_ds.create_tuple_iterator():
            images = data[0]
            lables = data[1]
            loss = train_net(images, lables)
            train_loss += loss

        # 评估
        total = 0
        for data in eval_ds.create_tuple_iterator():
            images = data[0]
            lables = data[1].squeeze()
            logits = net(images)
            pred = argmax(logits.squeeze())
            if pred == lables:
                total += 1

        acc = total / eval_ds_size
        # 保存ckpt
        if acc > best_acc:
            best_acc = acc
            best_epoch = epoch + 1
            save_checkpoint(net, ckpt_file)
        ckpt_file = f'epoch{epoch+1}.ckpt'
        save_checkpoint(net, ckpt_file)
        print(f'epoch:{epoch+1}, loss:{train_loss}, acc:{acc}')
    print(f'train success, best epoch is {best_epoch}, best acc is {best_acc}')


if __name__ == '__main__':
    train()

 

  • 4
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

TheMatrixs

你的鼓励将是我创作的最大动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值