【kaggle代码】Plant Seedlings Classification (使用Resnet-18完成分类任务)

本文讲述了在Kaggle植物种子分类比赛中,如何解决由于数据加载时shuffle=True导致得分过低的问题,通过调整测试集数据加载的shuffle选项,确保图片名称与预测标签对应,从而提高模型性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

比赛地址:植物种子分类

注意的点:

  1. 在网络中,一般训练过程中设置shuffle=True,在测试集中设置shuffle=false。

  2. 使用datasets.ImageFolder读取数据,并且制作数据集。分类任务与图像分割任务不同。分类任务的数据是:【图片,标签(字符串类型)】,所以两者的数据读取方式不同。在分割任务中,常常需要重写Dataset便于图像预处理,而在该分类任务中,不需要重写Dataset,在datasets.ImageFolder中,可以接收transform参数对读入的图像进行处理,而不对标签(字符串)处理,且会将标签自动转为标签索引形式。关于datasets.ImageFolde

  3. torch的Dataloader接受的是(data, labels)的元组形式,在 PyTorch 的 DataLoader 中,元组列表中元素的数据类型要求相对较松。每个元组的第一个元素通常是输入数据,第二个元素是对应的标签。这两个元素可以是任何 PyTorch 支持的数据类型,例如张量(torch.Tensor)、NumPy 数组、PIL 图像等。

  4. 对于使用预训练好的Resnet-18,可以通过更改网络最后一层,来适应该分类任务。对于很多模型,model.fc 是最后一层的全连接层。
    在这里插入图片描述

  5. 在这个比赛中,最初得分总是很低。最后发现原因是:在提交submission中,图片名称是按照顺序读入的,但是在使用Dataloader读入测试集数据时,使用了shuffle=True,导致读入的顺序被打乱,从而使得图片名称和预测标签不对应,导致得分很低。改为shuffle=Flase问题解决。

代码,按照ipynb顺序排列:

# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session
import numpy as np
import os
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from torch import optim
from torch import nn
import cv2 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import random_split
from tqdm import tqdm
import imageio
from torchvision import datasets
from PIL import Image # Image模块是在Python PIL图像处理中常见的模块,对图像进行基础操作的功能基本都包含于此模块内。
work_dir = '/kaggle/input/plant-seedlings-classification'
os.listdir(work_dir)
`import glob
#读取数据,用于后续制作数据集
train_path = os.path.join(work_dir,'train')

# 使用glob列出train文件夹下的所有文件夹
folders = glob.glob(os.path.join(train_path, '*'))

print(f'总的类别数量:{
     len(folders)}')``

```python
# values from ImageNet, recommended by PyTorch
transform_mean = [0.485, 0.456, 0.406]
transform_std = [0.229, 0.224, 0.225]

transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std),
])

dataset = datasets.ImageFolder(root=train_path,transform=transforms)

# self.classes:用一个 list 保存类别名称
# self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
# self.imgs:保存(img-path, class) tuple的 list
#查看有多少个样例和多少个类别
print('samples',len
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

超好的小白

没体验过打赏,能让我体验一次吗

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值