pytorch 加载自己的数据集
pytorch 加载自己的数据集,需要写一个继承自torch.utils.data中Dataset类,并修改其中的__init__方法、__getitem__方法、__len__方法。默认加载的都是图片,__init__的目的是得到一个包含数据和标签的list,每个元素能找到图片位置和其对应标签。然后用__getitem__方法得到每个元素的图像像素矩阵和标签,返回img和label。
以加载一个图像放在某个文件夹下,并在当前目录下生成了一个.txt的文件,大致如下train、test文件夹下放图片,test.txt和train.txt以如下方式存放图片路径和标签:
import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root = "/home/zlab/zhangshun/torch1/data_et/"
# -----------------ready the dataset--------------------------
def default_loader(path):
return Image.open(path).convert('RGB')
class MyDataset (Dataset):
# 构造函数带有默认参数
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
# 移除字符串首尾的换行符
# 删除末尾空
# 以空格为分隔符 将字符串分成
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))#imgs中包含有图像路径和标签
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
#调用定义的loader方法
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
train_data = MyDataset(txt=root + 'train.txt', transform=transforms.ToTensor())
test_data = MyDataset(txt=root + 'test.txt', transform=transforms.ToTensor())
#train_data 和test_data包含多有的训练与测试数据,调用DataLoader批量加载
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)