加载自己的数据集

Dataset类时pytorch图像数据集中最重要的一个类, 是pytorch所有数据集加载应该继承的父类;

若要加载自己的数据集,Dataset中的两个私有成员函数必须重新编写:

    def __getitem__(self, index):
    def __len__(self):

关于__getitem__、setitemdelitemlen【点击此处】

getitem函数:
接收的index是一个list的index,这个list的每个元素包含 图片的路径和标签;
返回图片数据和标签;

len函数:
返回数据集的大小;

list的制作:
将所有图片的路径和标签存储在一个txt中,
如:txt每行包括一个样本数据的路径和标签,逐行读取,放入list中,即可;
下面演示一种简单情况:假设不同类别图像在不同文件夹中,文件夹已编好序号(0,1,2,3,4,5),制作这种txt文件代码如下(具体要按照自己的数据集形式进行调整):

import os

a = 0
while (a < 6):  # 6为类别数(六个类别为0,1,2,3,4,5)
    dir = './data/test/%d' % a  # 图片文件的地址
    label = a
    files = os.listdir(dir)  # 列出dirname下的目录和文件,list集
    train = open('./data/train.txt', 'a')
    text = open('./data/text.txt', 'a')

    i = 0
    for file in files:
        if i < 20:  # 训练集中每类图片有20张(每类其余图片做测试集)
            fileType = os.path.split(file)  # os.path.split():按照路径将文件名和路径分割开
            if fileType[1] == '.txt':
                continue
            name = str(dir) + file + ' ' + str(int(label)) + '\n'
            train.write(name)
            i = i + 1

        else:
            fileType = os.path.split(file)
            if fileType[1] == '.txt':
                continue
            name = str(dir) + file + ' ' + str(int(label)) + '\n'
            text.write(name)
            i = i + 1
    text.close()
    train.close()
    a += 1

运行得到的txt文件内容如下(截取test.txt其中一部分):
在这里插入图片描述

加载自己的数据集整体代码:

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


def default_loader(path):
    return Image.open(path).convert('RGB')


# 首先自己构建一个MyDataset类
class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        fh = open(txt, 'r')
        images = []
        for line in fh:
            line = line.strip('\n')
            line = line.rsplit()
            words = line.split()   # 将该行分隔成列表
            images.append((words[0], int(words[1])))
        self.imgs = images
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is None:
            img = torch.from_numpy(img)
        else:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)


transform = transforms.Compose([
    transforms.Scale((227, 227)),   # 将所有图片resize到统一的尺寸
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])   # 归一化
])

train_data = MyDataset(txt='traindata.txt', transform=transform)
train_loader = DataLoader(
    dataset=train_data,
    batch_size=50,
    shuffle=True,
    num_workers=2
)
  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值