pytorch入门学习 (七)------ RMB分类模型实现之数据读取

RMB分类模型实现

数据集不一定必须是RMB数据集,其他数据集同理,注意读取路径的设置

数据读取

数据收集 ------ Img, label

数据划分 ------ train, valid, test

数据读取 ------ DataLoader ------Sampler 生成索引,即样本的序号 DataSet 根据索引读取图片和标签

数据预处理 ----- transforms

DataLoader 与 DataSet

torch.utils.data.DataLoader

功能 : 构建可迭代的数据装载器
dataset: Dataset类, 决定数据从哪读取及如何读取
batchsize: 批大小
num_works: 是否多进程读取数据
shuffle: 每个epoch是否乱序
drop_last: 当样本数不能被batchsize整除时, 是否舍弃最后一批数据

# DataLoader(dataset,
#            batch_size=1,
#            shuffle=False,
#            sampler=None,
#            batch_sampler=None,
#            num_workers=0,
#            collate_fn=None,
#            pin_memory=False,
#            drop_last=False,
#            timeout=0,
#            worker_int_fn=None,
#            multiprocessing_context=None)
torch.utils.data.Dataset

功能:Dataset 抽象类,所有自定义的Dataset需要继承它,并且复写

# # 功能:Dataset 抽象类,所有自定义的Dataset需要继承它,并且复写
# # __getitem__()
# # getitem: 接受一个索引,返回一个样本及标签
# # class Dataset(object):
# #     def __getitem__(self, index):
# #         raise NotImplementedError
# #
# #     def __add__(self, other):
# #         return ConcatDataset([self, other])

Epoch: 所有训练样本都已输入到模型中, 称为一个Epoch;
Iteration: 一批样本输入到模型中, 称为为一个Iteration;
Batchsize: 批大小,决定一个Epoch有多少个Iteration;

读哪些数据? Sample输出的Index

从哪读数据? Dataset中的data_dir

怎么读数据? Dataset中的getitem

数据读取具体实现

代码为深度之眼课程内容,在其中加入了比较详细的注释便于理解。

第一步,分割数据集,1_split_dataset.py
# -*- coding: utf-8 -*-
"""
# @file name  : 1_split_dataset.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : 将数据集划分为训练集,验证集,测试集
"""

import os
import random
import shutil


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    random.seed(1)

    dataset_dir = os.path.join("..", "..", "data", "RMB_data")
    split_dir = os.path.join("..", "..", "data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for root, dirs, files in os.walk(dataset_dir):
        for sub_dir in dirs:

            imgs = os.listdir(os.path.join(root, sub_dir))
            imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
            random.shuffle(imgs)
            img_count = len(imgs)

            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            for i in range(img_count):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

                shutil.copy(src_path, target_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
                                                                 img_count-valid_point))

第二步,训练数据,2_train_lenet.py

重点是数据读取部分

# -*- coding: utf-8 -*-
"""
# @file name  : 2_train_lenet.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : 人民币分类模型训练
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed
  • 9
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值