神经网络搭建(Pytorch)——创建自己的数据集并重写Dataset类

参考文章:https://cloud.tencent.com/developer/article/1433735

数据集目录
– data
 – train
  – 0.jpg
  – 1.jpg
  …
 – test
  – 0.jpg
  – 1.jpg
  …
 – train.txt
 – test.txt

train.text内容

DataSet重写代码

  1. 加载数据
  2. 重写 getitem() 函数
  3. 重写 len() 函数
from torch.utils.data import Dataset
import os
from PIL import Image


class MyDataset(Dataset):
    def __init__(self, root, train=True, transform=None):
        super(MyDataset, self).__init__()
        self.root = root
        self.train = train
        self.transform = transform

        if train:
            file = open(root + 'train.txt', 'r')
        else:
            file = open(root + 'test.txt', 'r')

        imgs = []
        for line in file:    # 按行循环 txt 文本
            line = line.rstrip()    # 删除本行字符串末尾的 '\n'
            word = line.split()    # 将图片图片路径和图片标签分开
            imgs.append((word[0], int(word[1])))    # 把 txt 中的内容保存到 imgs 列表
        self.imgs = imgs

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(self.root+fn).convert('RGB')    # 按路径读取图片

        if self.transform is not None:
            img = self.transform(img)    # 对图片进行 transform 操作
        return img, label

    def __len__(self):
        return len(self.imgs)

使用重写的Dataset类

from torch.utils.data import DataLoader
from torchvision import transforms


if __name__ == '__main__':
	tran = transforms.Compose([
		transforms.Resize((256, 256)),
		transforms.ToTensor()
	])
	tr_set = MyDataset(root='./data/', train=True, tran)
	tr_loader = DataLoader(
		dataset=tr_set,
		shuffle=True,
		batch_size=32
	)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值