Pytorch 09—Dataset数据输入

前面讲过的图片输入方式是从图片的文件夹来读取图片的一种方式。但是必须将类别单独放在一个文件夹。我们现在创建Dataset的子类来进行输入。

  • 必须继承自data.Dataset
  •  __getitem__ 方法必须创建,只要有这个方法,我们就可以进行切片
  •  __len__ 必须被实现,有了这个方法,我们就可以使用len方法返回数据集的长度
import torch
from torch.utils import data
from PIL import Image   #  pip install pillow
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline
import glob # 可以获取某一个条件下的所有的路径


# 自定义输入Dataset类
class MyDataset(data.Dataset):
    def __init__(self,imgsPath):
        self.imgs_path=imgsPath    # 图片路径
    def __getitem__(self,index):
        return self.imgs_path[index]
    def __len__(self):
        return len(self.imgs_path)
    
# 获取所有图片的路径
all_imgs_path = glob.glob(r'E:\Codes\Python\PyTorch\dataset2\*.jpg')    # 获取路径下,所以以.jpg结尾的图片路径。在该目录下,是四种天气的所有图片
#for i in range(5):
#   print(all_imgs_path[i])
weather_dataset = MyDataset(all_imgs_path)
len(weather_dataset)    # 1122。由于My_dataset内部实现了 __len__ 方法,所以可以使用len方法
    

创建Dataloader:

from torch.utils.data import DataLoader
wh_dl = DataLoader(weather_dataset,batch_size=4)
next(iter(wh_dl)) # 返回一个批次的数据(返回4张图片的路径)。列表形式

获取图片的标签:

我们获取了图片路径后,要获取它对应得标签值。

species = ['cloudy', 'rain', 'shine', 'sunrise']
# 将这4个类别使用数值型进行编码
species_to_idx = dict((c, i) for i, c in enumerate(species)) # cloudy:0,rain:1,....
print(species_to_idx)

idx_to_species = dict((v, k) for k, v in species_to_idx.items()) # 字典的items方法会以元祖形式返回key value对象
idx_to_species

 

# 提取所有图片的标签
all_labels = []
for img in all_imgs_path:   # img就是一张图片的路径
    for i, c in enumerate(species): # species = ['cloudy', 'rain', 'shine', 'sunrise']
        if c in img:
            all_labels.append(i)    # 以数值形式代表标签

 


划分数据集:

index = np.random.permutation(len(all_imgs_path)) # 将所有图片的长度做一个乱序处理
index

all_imgs_path = np.array(all_imgs_path)[index]  # 需要先将all_imgs_path转换array才能索引
all_labels = np.array(all_labels)[index]
s = int(len(all_imgs_path)*0.8) # 百分之八十作为训练数据集
train_imgs = all_imgs_path[:s]
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels[s:]

transform = transforms.Compose([
                    transforms.Resize((96, 96)),
                    transforms.ToTensor(),  # 将图片数据在0~1之间,0维度是channel
])

创建输入:下面要开始创建输入(上面只是演示)。

class Mydataset(data.Dataset):
    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform
    
    # 对于图片的读取和转换,我们也都放在__getitem__里面,当给一个索引时,返回的是图片对象,而不再是图片的路径了
    # 所以要对图片进行读取和转换。
    def __getitem__(self, index):
        img = self.imgs[index]
        label = self.labels[index]
        
        pil_img = Image.open(img)    
        data = self.transforms(pil_img)
        
        return data, label # 返回图片对象和对应的标签
    
    def __len__(self):
        return len(self.imgs)

BATCH_SIZE = 16
weather_dataset = Mydataset(all_imgs_path,all_labels,transform) 
# 类型是 torch.utils.Dataset对象
weather_dl = data.DataLoader(weather_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0)  # num_workers表示使用多少个进程读取,0表示不管它

# 取出一个批次的数据
imgs_batch,lables_batch = next(iter(weather_dl))    
print(imgs_batch.shape) # imgs_batch的shape是:torch.Size([16,3,96,96])
print(lables_batch.shape) # torch.Size([16]))

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs_batch[-6:], lables_batch[-6:])):
    img = img.permute(1, 2, 0).numpy()  # 交换每一维的数据,要将图片的channel(3)放在最后,再转换为ndarray的类型
    plt.subplot(2, 3, i+1)
    plt.title(idx_to_species.get(label.item())) # label是一个tensor,单个tensor获取标量使用item方法
    plt.imshow(img)


创建Dataloader:

通过定义子类的方式创建图片的输入,这种方式不仅可以应用于图片,还可以应用于csv、ndarray数据都可以 。

 

train_ds = Mydataset(train_imgs,train_labels,transform)
test_ds = Mydataset(test_imgs,test_labels,transform)
train_dl = data.DataLoader(train_ds,batch_size=BATCH_SIZE,shuffle=True)
test_dl = data.DataLoader(test_ds,batch_size=BATCH_SIZE,shuffle=False)
imgs, labels = next(iter(train_dl))
imgs.shape,labels.shape


灵活的使用Dataset类构造输入: 比如train_dl已经创建好了,我要用在tf里面,通道数要放在后面,即将每个批次的数据变为[16, 96, 96,3],通过创建子类的方式可以实现。

class New_dataset(data.Dataset):
    def __init__(self, some_dataset): # 输入的是已有的Dataset
        self.ds = some_dataset
    def __getitem__(self, index):
        img, label = self.ds[index] # 返回的是对应的图片对应和标签。返回的是单个图片,没有批次大小,即[3,96,96]
        img = img.permute(1, 2, 0) # 通道数变换。
        return img, label
    def __len__(self):
        return len(self.ds)


train_new_dataset = New_dataset(train_ds)
img, label = train_new_dataset[2]
img.shape,label.shape # 现在是:h*w*c

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

心之所向便是光v

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值