比赛地址:植物种子分类
注意的点:
-
在网络中,一般训练过程中设置
shuffle=True
,在测试集中设置shuffle=false。
-
使用
datasets.ImageFolder
读取数据,并且制作数据集。分类任务与图像分割任务不同。分类任务的数据是:【图片,标签(字符串类型)】,所以两者的数据读取方式不同。在分割任务中,常常需要重写Dataset
便于图像预处理,而在该分类任务中,不需要重写Dataset
,在datasets.ImageFolder
中,可以接收transform
参数对读入的图像进行处理,而不对标签(字符串)处理,且会将标签自动转为标签索引形式。关于datasets.ImageFolde -
torch的
Dataloader
接受的是(data, labels)
的元组形式,在 PyTorch 的 DataLoader 中,元组列表中元素的数据类型要求相对较松。每个元组的第一个元素通常是输入数据,第二个元素是对应的标签。这两个元素可以是任何 PyTorch 支持的数据类型,例如张量(torch.Tensor)、NumPy 数组、PIL 图像等。 -
对于使用预训练好的Resnet-18,可以通过更改网络最后一层,来适应该分类任务。对于很多模型,model.fc 是最后一层的全连接层。
-
在这个比赛中,最初得分总是很低。最后发现原因是:在提交submission中,图片名称是按照顺序读入的,但是在使用Dataloader读入测试集数据时,使用了
shuffle=True
,导致读入的顺序被打乱,从而使得图片名称和预测标签不对应,导致得分很低。改为shuffle=Flase
问题解决。
代码,按照ipynb顺序排列:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
# for filename in filenames:
# print(os.path.join(dirname, filename))
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session
import numpy as np
import os
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from torch import optim
from torch import nn
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import random_split
from tqdm import tqdm
import imageio
from torchvision import datasets
from PIL import Image # Image模块是在Python PIL图像处理中常见的模块,对图像进行基础操作的功能基本都包含于此模块内。
work_dir = '/kaggle/input/plant-seedlings-classification'
os.listdir(work_dir)
`import glob
#读取数据,用于后续制作数据集
train_path = os.path.join(work_dir,'train')
# 使用glob列出train文件夹下的所有文件夹
folders = glob.glob(os.path.join(train_path, '*'))
print(f'总的类别数量:{
len(folders)}')``
```python
# values from ImageNet, recommended by PyTorch
transform_mean = [0.485, 0.456, 0.406]
transform_std = [0.229, 0.224, 0.225]
transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=transform_mean, std=transform_std),
])
dataset = datasets.ImageFolder(root=train_path,transform=transforms)
# self.classes:用一个 list 保存类别名称
# self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
# self.imgs:保存(img-path, class) tuple的 list
#查看有多少个样例和多少个类别
print('samples',len