pytorch学习(八)Dataset加载分类数据集

我们之前用torchvision加载了pytorch的网络数据集,现在我们用Dataset加载自己的数据集,并且使用DataLoader做成训练数据集。

图像是从网上下载的,网址是 点这里,标签是图像文件夹名字。下载完成后作为自己的数据集。

1.加载自己的数据集的思路

    1)要完成继承自Dataset的类的构建

          由于Dataset是一个包含了虚函数的类,因此继承Dataset后,必须实现这些虚函数。

   2)第一个要完成的是__init__的构建,一般的方法是在__init__(self,root_dir, label_dir)中设置数据集的根目录root_dir,和类别数据集label_dir,然后用os.listdir得到label_dir中的图像名字

    3)第二个要完成的就是

__getitem__(self, item):

       item就是所要取数据的索引,这个函数主要是返回一个训练数据(比如一个图像),和一个结果数据,比如(该图像的分类结果是一个ant),因此用到刚os.listdir所列出的文件名字,用os.path.join加入路径,得到图像的绝对路径,用PIL导入图像,并给label赋值,返回图像和;abel即可。

   4)第三个要实现的就是数据集的长度

  __len__(self):

可以直接len(os.listdir所列出的文件名的数组),就可以得到数据集的长度。

2.需要注意的问题

   我在调试的时候发现

for imgs, labels  in train_loader:

一直报错,查找原因,发现是该数据集中的图像存在两个问题,第一个是大小不一,第二个貌似通道个数也不一致。

大小不一

因此使用transform做了处理

transform=transforms.Compose([ transforms.Resize((320,320),interpolation=Image.BILINEAR),
                                transforms.Grayscale(),
                                transforms.ToTensor()])

3.代码如下:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import os
import torch
from torch.utils.tensorboard import SummaryWriter


writer = SummaryWriter("logs")
transform=transforms.Compose([ transforms.Resize((320,320),interpolation=Image.BILINEAR),
                                transforms.Grayscale(),
                                transforms.ToTensor()])


class MyDataLoader(Dataset):
    def __init__(self,root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, item):
        img_name = self.img_path[item]
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
        img = Image.open(img_item_path)
        img = transform(img)
        label = self.label_dir
        return img,label
    def __len__(self):
        return len(self.img_path)

root_dir = "E:/TOOLE/slam_evo/pythonProject/data/hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"

ants_dataset = MyDataLoader(root_dir,ants_label_dir)
bees_dataset = MyDataLoader(root_dir,bees_label_dir)
train_data = ants_dataset + bees_dataset

img0, label0 = train_data[12]
# img0.show()
img1, label1 = train_data[124]
# img1.show()
# 一次处理数据10个
BATCH_SIZE = 10
# 把数据集装载到DataLoader里
train_loader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)

A = len(train_loader)
num_iter = 0
for imgs, labels  in train_loader:

    print(imgs.shape)
    print(labels)
    # print(train_data.classes)
    writer.add_images("ant-bees",imgs,num_iter)
    num_iter = num_iter +1

writer.close()


用tensorboard显示,batch_size= 10,因此每次迭代有10张图像

标签为:

  • 8
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值