pytorch——用resnet18做动物多分类问题(含可视化结果)

代码以及数据集下载:https://github.com/duchp/python-all/tree/master/CV%20code/动物多分类项目

一、任务介绍

  1. 纲分类问题,预测该动物是属于哺乳纲(Mammals)还是鸟纲(Birds)
  2. 种分类问题,预测该动物是兔子、老鼠还是鸡
  3. 多任务分类问题,同时预测动物的“纲”和“种”

二、数据预处理

使用的数据集结构如下:

数据集结构
上图为文件夹数据结构,同时,我们对于每一个任务,都有两个.csv文件,用于存储对应的训练数据与测试数据的标签。

训练的第一步是将数据存为pytorch里面的Dataset类,便于后续处理和读取数据。

因为三个任务的数据一致,不同的是标签情况,在代码上差异不大,所以以第一个任务——纲分类为例,数据存为Dataset类的代码及注释如下:

class MyDataset(torch.utils.data.Dataset):
    def __init__(self,root,transform=None):
        super(MyDataset,self).__init__()
        #读取csv文件
        file_info = pd.read_csv(root, index_col=0)
        #读取图像路径
        file_path = file_info['path']
        #读取图像纲分类标签
        file_class = file_info['classes']
        imgs = []
        imglb = []
        #依次处理数据
        for i in range(len(file_path)):
            path = file_path[i]
            path = path.replace('\\','/')
            if not os.path.isfile(path):
                print(path + '  does not exist!')
                return None
       #读对应路径的图像,存为img,依次存入imgs
            img = Image.open(path).convert('RGB')
            imgs.append(img)
       #读对应图像的标签,依次存入imglb  
            imglb.append(int(file_class[i]))
        self.image = imgs
        self.imglb = imglb
        self.root = root
        self.size = len(file_info)
        self.transform = transform
        
    def __getitem__(self,index):
        img = self.image[index]
        label = self.imglb[index]
        sample = {'image': img,'classes':label}
        if self.transform:
            sample['image'] = self.transform(img)
        return sample
        
    def __len__(self):
        return self.size

存好数据后,开始为训练网络做准备,在深度学习训练中,一般会对数据进行一些变换,比如:翻转,旋转等等,用来增强鲁棒性,因此,这里需要定义一个数据变换函数:

train_transforms = transforms.Compose([transforms.Resize((500, 500)),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       ])
val_transforms = transforms.Compose([transforms.Resize((500, 500)),
                                     transforms.ToTensor()
                                     ])

这里我们只做了训练图像的resize和随机水平翻转变换,实际上图像变换还有很多选择,具体可查询torchvision.transforms.Compose这个函数的参数设置。这个设定是要赋值到上面的MyDataset里面的,可以看到上面的transform默认为None,如果不赋值,则不做变换。

到了这一步,我们的数据还需要一个非常重要的处理,那就是分batch,因为如果一次性把数据放入训练中,往往内存是不够的,如果一张一张的训练,每次修正方向以

  • 24
    点赞
  • 158
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
动物分类是一个常见的图像分类问题,可以使用卷积神经网络(CNN)来解决。以下是一个基于 MATLAB 的动物分类示例: 1. 数据集准备:准备一个包不同类别动物图像的数据集。可以使用公开的数据集如 ImageNet 或自己收集数据集。 2. 数据预处理:将每个图像缩放为相同的大小,并进行数据增强,以增加模型的泛能力。 3. 模型设计:使用 MATLAB 的深度学习工具箱中的卷积神经网络函数设计模型,例如 AlexNet、VGGNet 或 ResNet。 4. 模型训练:使用训练图像和标签训练模型,并使用验证集评估模型的性能。 5. 模型测试:使用测试图像测试模型的性能,并计算分类准确率。 以下是一个简单的动物分类示例代码: ```matlab % 加载数据集 imds = imageDatastore('path to dataset', 'IncludeSubfolders', true, 'LabelSource', 'foldernames'); % 数据预处理 inputSize = [224 224]; augmenter = imageDataAugmenter('RandRotation',[-20,20],'RandXReflection',true,'RandYReflection',true); auimds = augmentedImageDatastore(inputSize, imds, 'DataAugmentation', augmenter); % 模型设计 net = alexnet; % 修改全连接层 numClasses = numel(categories(imds.Labels)); layers = net.Layers; fc = fullyConnectedLayer(numClasses, 'Name','fc8'); layers(end-2) = fc; layers(end) = classificationLayer; % 设置训练选项 options = trainingOptions('sgdm', ... 'MiniBatchSize', 10, ... 'MaxEpochs', 20, ... 'InitialLearnRate', 1e-4, ... 'Verbose', true, ... 'ValidationData', auimds, ... 'Plots', 'training-progress'); % 训练模型 net = trainNetwork(auimds, layers, options); % 测试模型 testimds = imageDatastore('path to test dataset', 'IncludeSubfolders', true, 'LabelSource', 'foldernames'); testaugimds = augmentedImageDatastore(inputSize, testimds); YPred = classify(net, testaugimds); YTest = testimds.Labels; accuracy = sum(YPred == YTest)/numel(YTest) ``` 这个示例使用 AlexNet 模型和随机旋转和镜像的数据增强。可以根据实际情况进行调整和修改。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值