深度学习基础:如何使用自己的数据集
通过继承torch的Dataset类,来实现加载自己的数据集。
本文以ISIC2018数据集(这个是开源的数据集)为例。
需要重写的三个函数
import torch
# 数据加载器
class Reader(torch.utils.data.Dataset): # 数据读取
"""
读取数据
"""
def __init__(self):
super().__init__()
# 这里可以进行一些数据的预处理,比如类型转换、数据增强等。
# 一般都在在这里生成两列表,一个是所有输入数据的列表,另一个是所有标签的列表,它们是一一对应的
pass
def __getitem__(self, item):
pass
# 这里需要返回数据集对应item的两Tenser类型的数据,一个是输入数据,另一个是标签数据
def __len__(self):
# 返回数据集长度
pass
使用自己的数据集
这里是图像预处理部分。根据需要对数据进行处理,实现数据增强的效果。
def get_new_data(img, label):
"""
数据预处理:把图像矩阵填充为方阵。
:param img: Image类型,输入特征图
:param label: Image类型,单通道
:return:
"""
try:
# 这个异常要不要都行,防止有一些图像处理不了程序直接全部异常退出的。
img_arr = np.array(img)
max_size = max(img.size)
# 填充图像矩阵成为(max_len, max_len)
img_arr = np.pad(img_arr, ((0, max_size - img_arr.shape[0]), (0, max_size - img_arr.shape[1]), (0, 0)),
'constant', constant_values=255)
new_img = Image.fromarray(img_arr).convert('RGB')
label_arr = np.array(label)
label_arr = np.pad(label_arr, ((0, max_size - label_arr.shape[0]), (0, max_size - label_arr.shape[1])),
'constant', constant_values=0)
new_label = Image.fromarray(label_arr).convert('L')
return new_img, new_label
except:
return img, label
Dataset类
import os
import numpy as np
import torch
from PIL import Image
# 数据集读取
class Reader(torch.utils.data.Dataset): # 数据读取
"""
读取数据
"""
def __init__(self, images_path, labels_path, transform):
"""
:param images_path: 图片地址
:param labels_path: 标签地址
:param transform: 类型转换对象
"""
super().__init__()
# 获取数据列表
# transform:将类型转换为Tenser类型。
self.transform = transform
# 用于保存原始图像信息,一般保存图像名称,需要用时根据名称打开图像。
datas = []
labels_list = os.listdir(labels_path) # 生成路径
# 读取图像原始信息,这里为图像地址
for i in labels_list: # 遍历图片存放的目录
if '.png' in i:
# 我使用的数据集标签和原图的名称是基本一样的,只是目录不同
label_path = os.path.join(labels_path, i)
image_path = os.path.join(images_path, i.split('_')[0] + '_' + i.split('_')[1] + '.jpg')
# print(label_path, image_path)
datas.append((image_path, label_path))
print("共{}个数据".format(datas.__len__()))
self.datas = datas
def __getitem__(self, item):
img_path, label_path = self.datas[item]
# 根据需要读取RGB图或者灰度图
img = Image.open(img_path)
label = Image.open(label_path).convert('L')
# 数据预处理
img, label = get_new_data(img, label)
# 将图片从Image类型转为Tensor类型
img = self.transform(img)
label = self.transform(label)
# 返回数据集对应item位置的图像和标签,都是Tensor类型
return img, label
def __len__(self):
# 返回数据集长度
return self.datas.__len__()
这样我们就设置好了一个Dataset。由于数据集格式不同,Dataset的数据读取方式也不同,需要灵活运用。
构造一个数据加载器
if __name__ == '__main__':
import torchvision
from torch.utils.data import DataLoader
transform = torchvision.transforms.Compose(
[torchvision.transforms.Resize((64, 64)), # 缩放
torchvision.transforms.ToTensor()]) # 类型转换
images_path = r"E:\数据集\1-2_Validation_Input\ISIC2018_Task1-2_Validation_Input"
label_path = r"E:\数据集\1_Validation_GroundTruth\ISIC2018_Task1_Validation_GroundTruth"
dateset = Reader(images_path, label_path, transform)
train_DataLoader = DataLoader( # 数据加载器
dataset=dateset, # 选择dataset
batch_size=3, # 选择batch大小
num_workers=0, # windows得设置这个,不然有时候会报错
shuffle=True) # 是否乱序
查看是否正常读取数据
for data in train_DataLoader:
img, label = data
print(img.shape, label.shape)
输出:
torch.Size([3, 3, 64, 64]) torch.Size([3, 1, 64, 64])
每个维度的含义[batch_size, 图像通道数, img_x_size, img_y_size]
batch_size:批次大小
图像通道数:比如RGB图像是三通道图像,灰度图是单通道图像