RMB分类模型实现
数据集不一定必须是RMB数据集,其他数据集同理,注意读取路径的设置
数据读取
数据收集 ------ Img, label
数据划分 ------ train, valid, test
数据读取 ------ DataLoader ------Sampler 生成索引,即样本的序号 DataSet 根据索引读取图片和标签
数据预处理 ----- transforms
DataLoader 与 DataSet
torch.utils.data.DataLoader
功能 : 构建可迭代的数据装载器
dataset: Dataset类, 决定数据从哪读取及如何读取
batchsize: 批大小
num_works: 是否多进程读取数据
shuffle: 每个epoch是否乱序
drop_last: 当样本数不能被batchsize整除时, 是否舍弃最后一批数据
# DataLoader(dataset,
# batch_size=1,
# shuffle=False,
# sampler=None,
# batch_sampler=None,
# num_workers=0,
# collate_fn=None,
# pin_memory=False,
# drop_last=False,
# timeout=0,
# worker_int_fn=None,
# multiprocessing_context=None)
torch.utils.data.Dataset
功能:Dataset 抽象类,所有自定义的Dataset需要继承它,并且复写
# # 功能:Dataset 抽象类,所有自定义的Dataset需要继承它,并且复写
# # __getitem__()
# # getitem: 接受一个索引,返回一个样本及标签
# # class Dataset(object):
# # def __getitem__(self, index):
# # raise NotImplementedError
# #
# # def __add__(self, other):
# # return ConcatDataset([self, other])
Epoch: 所有训练样本都已输入到模型中, 称为一个Epoch;
Iteration: 一批样本输入到模型中, 称为为一个Iteration;
Batchsize: 批大小,决定一个Epoch有多少个Iteration;
读哪些数据? Sample输出的Index
从哪读数据? Dataset中的data_dir
怎么读数据? Dataset中的getitem
数据读取具体实现
代码为深度之眼课程内容,在其中加入了比较详细的注释便于理解。
第一步,分割数据集,1_split_dataset.py
# -*- coding: utf-8 -*-
"""
# @file name : 1_split_dataset.py
# @author : tingsongyu
# @date : 2019-09-07 10:08:00
# @brief : 将数据集划分为训练集,验证集,测试集
"""
import os
import random
import shutil
def makedir(new_dir):
if not os.path.exists(new_dir):
os.makedirs(new_dir)
if __name__ == '__main__':
random.seed(1)
dataset_dir = os.path.join("..", "..", "data", "RMB_data")
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
test_dir = os.path.join(split_dir, "test")
train_pct = 0.8
valid_pct = 0.1
test_pct = 0.1
for root, dirs, files in os.walk(dataset_dir):
for sub_dir in dirs:
imgs = os.listdir(os.path.join(root, sub_dir))
imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
random.shuffle(imgs)
img_count = len(imgs)
train_point = int(img_count * train_pct)
valid_point = int(img_count * (train_pct + valid_pct))
for i in range(img_count):
if i < train_point:
out_dir = os.path.join(train_dir, sub_dir)
elif i < valid_point:
out_dir = os.path.join(valid_dir, sub_dir)
else:
out_dir = os.path.join(test_dir, sub_dir)
makedir(out_dir)
target_path = os.path.join(out_dir, imgs[i])
src_path = os.path.join(dataset_dir, sub_dir, imgs[i])
shutil.copy(src_path, target_path)
print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
img_count-valid_point))
第二步,训练数据,2_train_lenet.py
重点是数据读取部分
# -*- coding: utf-8 -*-
"""
# @file name : 2_train_lenet.py
# @author : tingsongyu
# @date : 2019-09-07 10:08:00
# @brief : 人民币分类模型训练
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
def set_seed(seed=1):
random.seed(seed)
np.random