关于torchvision加载数据集的小问题

现有以下完整程序可以成功加载数据集,使用ImageFolder函数:

import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms,utils
import numpy as np
# 使用ImageFolder需要保证数据集以下列形式组织:
'''
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
'''
img_data = torchvision.datasets.ImageFolder(
    root = r'E:\机器学习数据集\flower_photos',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()])
        )
print('数据集类别:',img_data.classes)
print('数据集大小:',len(img_data))

# 使用torch.utils.data.DataLoader加载,形成一个DataLoader类实例
data_loader = torch.utils.data.DataLoader(img_data,batch_size=36, shuffle=True)
print(len(data_loader))

def imshow(img):
#    img = img / 2 + 0.5     # unnormalize
    img = torchvision.utils.make_grid(img, nrow=6)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title('Batch from dataloader')
    plt.xticks([])
    plt.yticks([])
    plt.show()

# get some random training images
dataiter = iter(data_loader)
images, labels = dataiter.next()
print(images.shape, labels)
# show images
imshow(images)

上面程序用了三种变换:Resize,Crop和ToTensor,问题就出现在这里了,

  1. 问题1:如果去掉Crop,只留下Resize和ToTensor,程序报错:
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 408 and 341 in dimension 3 at ..\aten\src\TH/generic/THTensor.cpp:711
  1. 问题2 :去掉ToTensor或者将ToTensor放到前面而不是最后一项,程序报错:
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.Image.Image'>

其他情况不好有问题,先留着这两个问题,以后研究深入了再解答。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值