pytorch(07)数据模型的读取

DataLoader与Dataset

pytorch中的数据读取机制

graph TB DataLoader --> DataLoaderIter DataLoaderIter --> Sampler Sampler --> Index Sampler --> DatasetFetcher Index -->DatasetFetcher DatasetFetcher -->Dataset Dataset --> getitem getitem -->img,label img,label --> collate_fn collate_fn --> BatchData
  1. 人民币二分类
    可以把人民币当成自变量x,类别是y。
    数据模块可以分为
  2. 数据收集->原始样本和标签,img,label
  3. 数据划分->划分train,valid,test。验证集来调整过拟合
  4. 数据读取->数据读取,DataLoader
    DataLoader分为两个子模块,分别是
  • Sampler生成索引,样本的序号index
  • DataSet根据索引,读取img和label
  1. 数据预处理->transforms
  2. DataLoader与Dataset
    DataLoader和Dataset是数据读取的核心
  3. DataLoader
    DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,work_init_fn=None,multiprocessing_context=None)
    主要是构建可迭代的数据转载器
    dataloader,我们在训练的时候在每一次循环中,就是从dataset中读取每一个batch_size大小的数据
  • dataset:Dataset类,决定数据从哪读取及如何读取
  • batchsize:批大小
  • num_works:是否多进程读取数据
  • shuffle:每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
    epoch,iteration,batchsize
  • Epoch:所有训练样本都已输入到模型中,称为一个Epoch
  • Iteration:一批样本输入到模型中,称为一个Iteration
  • Batchsize:批大小,决定一个Epoch有多少个Iteration
    样本总数:80,BatchSize:8
    1 Epoch = 10 Iteration
    如果样本总数不能被整除
    样本总数:87,Batchsize:8
  • 1 Epoch = 10 Iteration,drop_last=True
  • 1 Epoch = 11 Iteration,drop_last=False
  1. Dataset
    torch.utils.data.Dataset
    class Dataset(object):
    def getitem(self,index):
    ​ raise NotImplementedError
    def add(self,other):
    ​ return ConcatDataset([self,other])
    功能:Dataset抽象类,所有自定义的Dataset需要继承,并复写
    __getitem__()
    getitem: 接收一个索引,返回一个样本

数据读取机制

  1. 读哪些数据,在每一个iteration中读取哪些数据?
  2. 从哪读数据,在硬盘中如何读取?
  3. 怎么读数据?
import os
import random
import shutil
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    DATA_DIR = os.path.abspath(os.path.join(BASE_DIR, ".", "RMB_data"))
    SPLIT_DIR = os.path.abspath(os.path.join(BASE_DIR, ".", "rmb_split"))
    TRAIN_DIR = os.path.join(SPLIT_DIR, "train")
    VALID_DIR = os.path.join(SPLIT_DIR, "valid")
    TEST_DIR = os.path.join(SPLIT_DIR, "test")

    if not os.path.exists(DATA_DIR):
        raise Exception("\n{}不存在,请下载RMBdata放到{}路径下".format(DATA_DIR, os.path.dirname(DATA_DIR)))

    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for paths, dirs, files in os.walk(DATA_DIR):
        for sub_dirs in dirs:
            imgs = os.listdir(os.path.join(paths, sub_dirs))
            imgs = list(filter(lambda x: x.endswith('.jpg'),imgs))
            # print(imgs)
            random.shuffle(imgs)
            # print(imgs)
            imgs_count = len(imgs)
            # print(imgs_count)

            train_pic = int(train_pct*imgs_count)
            valid_pic = int((valid_pct+train_pct)*imgs_count)

            if imgs_count == 0 :
                print("{}目录下,无图片,请检查".format(os.path.join(paths, sub_dirs)))
                import sys
                sys.exit(0)

            for i in range(imgs_count):
                if i < train_pic :
                    out_dir = os.path.join(TRAIN_DIR, sub_dirs)
                elif i < valid_pic :
                    out_dir = os.path.join(VALID_DIR, sub_dirs)
                else:
                    out_dir = os.path.join(TEST_DIR, sub_dirs)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(DATA_DIR, sub_dirs, imgs[i])

                shutil.copy(src_path, target_path)

            print("Class:{}, train:{}, valid:{}, test:{}".format(sub_dirs, train_pic, valid_pic-train_pic, imgs_count-valid_pic-train_pic))
            print("已在{}划分好".format(out_dir)
Class:1, train:80, valid:10, test:-70
已在D:\pythonProject\04_DataLoader\rmb_split\test\1划分好
Class:100, train:80, valid:10, test:-70
已在D:\pythonProject\04_DataLoader\rmb_split\test\100划分好
import numpy as np
import torch
import os
import random
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
BASE_PATH = os.path.abspath(__file__)
# print(BASE_PATH)
base_path = os.path.abspath(os.path.join(BASE_PATH, '..', 'TestDir'))
# print(base_path)
data_dir = os.path.abspath(os.path.join(BASE_PATH, '..', 'RMB_data'))
random.seed(1)
# print(data_dir)
test_label = {"1": 0, "100": 1}
data_info = list()
for path, dirs, files in os.walk(base_path):
    for sub_dir in dirs:
        # print(sub_dir)
        sub_dirlist = os.listdir(os.path.join(base_path, sub_dir))
        pynames = list(filter(lambda y: y.endswith('.jpg'), sub_dirlist))
        # print(pynames)
        # print(test_label[sub_dir])
        for pyname in pynames:
            datainfo_dir = os.path.join(base_path, sub_dir, pyname)
            t_label=test_label[sub_dir]
            t_label = int(t_label)
            data_info.append((datainfo_dir, t_label))
# print(data_info)
new_data_info = list()
for data_info_e in data_info:
    x_dir, x_label = data_info_e
    x_img = Image.open(x_dir).convert('RGB')
    ok_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
    ])
    x_img = ok_transform(x_img)
    new_data_info.append((x_img,x_label))

# print(len(new_data_info[0][0]))
print(len(new_data_info))
newdataLoader = DataLoader(new_data_info,batch_size=14, shuffle=True)
for ids, data in enumerate(newdataLoader):
    print(ids)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值