Pytorch学习之数据加载
一、Dataset类
这个类可以看成是自定义的数据集类(是一个抽象类,不能直接实例化,只能继承)
代码如下(示例):
class Mydataset(Dataset):
def __init__(self,):
pass
def __len__(self):
pass
def __getitem__(self,idx)
pass
一、当数据集比较小时,可以把整个数据集放入init中(即放入内存中),再根据getitem的索引来读出
二、当数据集比较大时(如图像数据集),一般要先做一个列表,来记录下每张图像的id。在getitem函数里读取列表中第i个图像id,系统会从文件夹中将图片读出,返回
二、torchvision.transforms.Compose使用
这个类的主要作用是串联多个图片变换的操作。Compose里面的参数实际上就是个列表
通常预处理步骤:
- 所有图片转化为相同大小。
- 把图片数据集转换为Pytorch张量
- 用数据集的均值和标准差把数据集归一化
代码如下(示例):
transforms = transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224)
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
三、torchvision.datasets.ImageFolder使用详解
ImageFolder是一个通用的数据加载器,数据如放在文件夹中的图片
使用详情
dataset=torchvision.datasets.ImageFolder(
root,
transform=None,
target_transform=None,
loader=<function default_loader>,
is_valid_file=None)
1.参数详解
root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…即label是按照文件夹命名从0开始的数字
loader:表示数据集加载方式,通常默认加载方式即可。
is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)
2.返回的dataset都有以下三种属性:
self.classes:用一个 list 保存类别名称
self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
self.imgs:保存(img-path, class) tuple的 list
代码如下(示例):
train_dataset = datasets.ImageFolder(root=./data/train,
transform=transforms)
我们得到的train_dataset,它的结构就是[(img_data,class_id),(img_data,class_id),…]
print(train_dataset[995]) # 第995个图片 class_id=1
'''
输出:
(tensor([[[-0.1765, -0.1686, -0.1686, ..., -0.2941, -0.2941, -0.3020],
[-0.1765, -0.1765, -0.1608, ..., -0.2941, -0.2941, -0.2863],
[-0.1765, -0.1765, -0.1608, ..., -0.2863, -0.2863, -0.2784],
...,
[-0.2078, -0.1922, -0.1843, ..., -0.1608, -0.1608, -0.1608],
[-0.1608, -0.1922, -0.1843, ..., -0.1608, -0.1608, -0.1608],
[-0.1922, -0.1686, -0.2000, ..., -0.1686, -0.1608, -0.1529]],
[[-0.2392, -0.2314, -0.2314, ..., -0.3176, -0.3176, -0.3176],
[-0.2392, -0.2392, -0.2235, ..., -0.3176, -0.3098, -0.3020],
[-0.2392, -0.2392, -0.2235, ..., -0.3176, -0.3176, -0.3098],
...,
[-0.3490, -0.3569, -0.3333, ..., -0.3020, -0.3020, -0.3020],
[-0.3098, -0.3412, -0.3333, ..., -0.3020, -0.3020, -0.3020],
[-0.3490, -0.3098, -0.3490, ..., -0.3098, -0.3020, -0.2941]],
[[-0.7255, -0.7176, -0.7176, ..., -0.8745, -0.8824, -0.8824],
[-0.7255, -0.7255, -0.7098, ..., -0.8745, -0.8902, -0.8824],
[-0.7255, -0.7255, -0.7098, ..., -0.8588, -0.8745, -0.8667],
...,
[-0.8353, -0.8745, -0.7882, ..., -0.6784, -0.6784, -0.6784],
[-0.7882, -0.8588, -0.7882, ..., -0.6784, -0.6784, -0.6784],
[-0.8039, -0.7882, -0.8039, ..., -0.6941, -0.6784, -0.6706]]]), 1)
'''
再看三个属性
print(dataset.classes) #根据分的文件夹的名字来确定的类别
print(dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
print(dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
'''
输出:
['cat', 'dog']
{'cat': 0, 'dog': 1}
[('./data/train\\cat\\1.jpg', 0),
('./data/train\\cat\\2.jpg', 0),
('./data/train\\dog\\1.jpg', 1),
('./data/train\\dog\\2.jpg', 1)]
'''
四、按批加载数据-----DataLoader类
数据集过大不能一次性全部加载到内存里,可按批次来加载数据
使用详情
train_loader = torch.utils.data.DataLoader(train_dataset, # 导入的训练集
batch_size=4, # 每批训练的样本数
shuffle=True, # 是否打乱训练集
num_workers=0) # 使用线程数,在windows下设置为0。多线程能提高读取效率