用Pytorch实现CIFAR10数据集分类(持续修改中)

1.导包

"""导包"""
import collections
import numpy as np
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
import matplotlib.pyplot as plt
from torch import nn
import torch .nn.functional as F

2.整理数据集

"""整理数据集"""

def read_csv_labels(fname):
    """读取 fname 来给标签字典返回一个文件名"""
    with open(fname, 'r') as f:
        # 跳过文件头行 (列名)
        lines = f.readlines()[1:] #将文件按行读取,返回值为列表
    tokens = [l.rstrip().split(',') for l in lines] #将每一行用逗号分隔开来,并去掉多余的空格
    return dict(((name, label) for name, label in tokens)) #将token里面的信息遍历出来,并赋值给name和label并做成字典

labels = read_csv_labels('../CIFAR10 Image-Classification/trainLabels.csv')
for item in labels.items():
    print(item)
print('类别 :', len(set(labels.values())))

整理好后的样子
在这里插入图片描述

3.将验证集从原始的数据集中拆分出来

"""将验证集从原始的数据集中拆分出来"""
data_dir=('../CIFAR10 Image-Classification/')

def copyfile(filename, target_dir):
    """将文件复制到目标目录"""
    os.makedirs(target_dir, exist_ok=True) #创建多层目录,目录名为target—dir
    shutil.copy(filename, target_dir) #将filename的文件复制到target—dir当中
    
def reorg_train_valid(data_dir, labels, valid_ratio):
    # 训练数据集中示例最少的类别中的示例数
    n = collections.Counter(labels.values()).most_common()[-1][1]
    # 验证集中每个类别的示例数
    n_valid_per_label = max(1, math.floor(n * valid_ratio)) #示例数与1比大小,n为训练数据集中示例最少的类别中的示例数,valid_ratio 是验证集中的示例数与原始训练集中的示例数之比
    label_count = {
   }
    for train_file in os.listdir(os.path.join(data_dir,'train')): #列出train 路径下所有的文件并遍历给train_file,该数据集是打乱的
        label = labels[train_file.split('.')[0]] #用.将train_file切片,切片后只有两行,第一行为类别标号,第二行为类别名称,我们只取第一行赋值给label
        
        fname = os.path.join(data_dir, 'train', train_file)  #fname=data_dir/train/train_file,train_file有label和名称组成
        copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                     'train_valid', label)) #将data_dir/train/train_file的文件复制到data_dir/train_valid_test/train_valid/label,原始的cifar10数据未经过划分
        if label not in label_count or label_count[label] < n_valid_per_label: 
            copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                         'valid', label))  #将数据添加到验证集
            label_count[label] = label_count.get(label, 0) + 1 #返回字典label_count中label元素所对应的值,label_count[label]给字典label——count中元素label赋值为label值+1
        else:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                         'train', label)) #将数据添加到训练集
    return n_valid_per_label

4.在预测期间整理测试集,以方便读取

def reorg_test(data_dir):
    for test_file in os.listdir(os.path.join(data_dir, 'test')): #将测试集的数据全部列出,然后遍历出来
        copyfile(os.path.join(data_dir, 'test', test_file),
                 os.path.join(data_dir, 'train_valid_test', 'test',
                              'unknown')) #将data_dir/test/test_file文件夹复制到data_dir/train_valid_test/test,标签放到unknown下面

5.调用前面定义的函数

def reorg_cifar10_data(data_dir, valid_ratio):
    labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv')) #读取data_dir/trainLabels.csv中的name和label以字典形式输出
    reorg_train_valid(data_dir, labels, valid_ratio)
    reorg_test(data_dir)
    
# batch_size = 32
valid_ratio = 0.1 #验证集与训练集为1比9
reorg_cifar10_data(data_dir, valid_ratio) ##输出一个data_dir文件,里面有划分好i的test,train,valid

6.图像增广

transform_train = torchvision.transforms.Compose([
    # 在高度和宽度上将图像放大到40像素的正方形     图片本身是32*32太小了
    torchvision.transforms.Resize(40),
    # 随机裁剪出一个高度和宽度均为40像素的正方形图像,
    # 生成一个面积为原始图像面积0.64到1倍的小正方形,
    # 然后将其缩放为高度和宽度均为32像素的正方形
    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), #裁剪的图里面有高和宽至少是原来的64%
                                                   ratio=(1.0, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    # 标准化图像的每个通道
    torchvision.transforms.Normalize([0.4914
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Tomorrow;

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值