Dataset类时pytorch图像数据集中最重要的一个类, 是pytorch所有数据集加载应该继承的父类;
若要加载自己的数据集,Dataset中的两个私有成员函数必须重新编写:
def __getitem__(self, index):
def __len__(self):
关于__getitem__、setitem、delitem、len【点击此处】
getitem函数:
接收的index是一个list的index,这个list的每个元素包含 图片的路径和标签;
返回图片数据和标签;
len函数:
返回数据集的大小;
list的制作:
将所有图片的路径和标签存储在一个txt中,
如:txt每行包括一个样本数据的路径和标签,逐行读取,放入list中,即可;
下面演示一种简单情况:假设不同类别图像在不同文件夹中,文件夹已编好序号(0,1,2,3,4,5),制作这种txt文件代码如下(具体要按照自己的数据集形式进行调整):
import os
a = 0
while (a < 6): # 6为类别数(六个类别为0,1,2,3,4,5)
dir = './data/test/%d' % a # 图片文件的地址
label = a
files = os.listdir(dir) # 列出dirname下的目录和文件,list集
train = open('./data/train.txt', 'a')
text = open('./data/text.txt', 'a')
i = 0
for file in files:
if i < 20: # 训练集中每类图片有20张(每类其余图片做测试集)
fileType = os.path.split(file) # os.path.split():按照路径将文件名和路径分割开
if fileType[1] == '.txt':
continue
name = str(dir) + file + ' ' + str(int(label)) + '\n'
train.write(name)
i = i + 1
else:
fileType = os.path.split(file)
if fileType[1] == '.txt':
continue
name = str(dir) + file + ' ' + str(int(label)) + '\n'
text.write(name)
i = i + 1
text.close()
train.close()
a += 1
运行得到的txt文件内容如下(截取test.txt其中一部分):
加载自己的数据集整体代码:
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
def default_loader(path):
return Image.open(path).convert('RGB')
# 首先自己构建一个MyDataset类
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
super(MyDataset, self).__init__()
fh = open(txt, 'r')
images = []
for line in fh:
line = line.strip('\n')
line = line.rsplit()
words = line.split() # 将该行分隔成列表
images.append((words[0], int(words[1])))
self.imgs = images
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is None:
img = torch.from_numpy(img)
else:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
transform = transforms.Compose([
transforms.Scale((227, 227)), # 将所有图片resize到统一的尺寸
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 归一化
])
train_data = MyDataset(txt='traindata.txt', transform=transform)
train_loader = DataLoader(
dataset=train_data,
batch_size=50,
shuffle=True,
num_workers=2
)