在火力教育群里请教了大佬,虽然还没弄明白出错的原因,但问题已经解决了。下面贴出更改前后的代码对比以及大佬提供的代码,供大家参考,希望能有所帮助。
更改前
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from torch.autograd import Variable
# Training preprocessing: Resize(225), center-crop to 224, convert to a
# tensor, then normalize every channel with mean 0.5 and std 0.5.
transforms_train = transforms.Compose(
    [
        transforms.Resize(225),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)
# The test pipeline applies exactly the same steps as the training one.
transforms_test = transforms.Compose(
    [
        transforms.Resize(225),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

batch_sizes = 64
test_data_dir = './FIRE-SMOKE-DATASET/Test'
train_data_dir = './FIRE-SMOKE-DATASET/Train'

# One ImageFolder dataset per split, each with its own transform pipeline.
train_data = datasets.ImageFolder(
    root=train_data_dir,
    transform=transforms_train,
)
test_data = datasets.ImageFolder(
    root=test_data_dir,
    transform=transforms_test,
)

# Both loaders shuffle and use the same batch size.
train_data_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_sizes,
    shuffle=True,
)
test_data_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_sizes,
    shuffle=True,
)

# Pull a single batch from the training loader for a quick visual check.
images, labels = next(iter(train_data_loader))
def image_display(image, title=None, unnormalize=True):
    """Display a channel-first image tensor with matplotlib.

    Args:
        image: Tensor laid out as (channels, height, width); it is
            transposed to (height, width, channels) before plotting.
        title: Optional figure title; omitted when falsy.
        unnormalize: When True (the default), map values back with
            ``image / 2 + 0.5``, which reverses the
            ``Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])`` transform used
            by the loaders in this snippet. Pass False for tensors that
            were never normalized.
    """
    if unnormalize:
        image = image / 2 + 0.5
    numpy_image = image.numpy()
    # matplotlib expects the channel axis last.
    transposed_numpy_image = np.transpose(numpy_image, (1, 2, 0))
    plt.figure(figsize=(20, 4))
    plt.imshow(transposed_numpy_image)
    plt.yticks([])
    plt.xticks([])
    if title:
        plt.title(title)
    # The original wrote `plt.show` without parentheses, which only
    # references the function and never actually shows the figure.
    plt.show()


image_display(torchvision.utils.make_grid(images))
以上代码运行时会报错。
下面是更改后的代码,可以正确读取图片:
# NOTE: batch_sizes is not passed to the loaders below, so they fall back
# to DataLoader's default batch size.
batch_sizes = 64
# Raw strings keep the Windows backslashes literal; in a normal string
# `\F` and `\T` are invalid escape sequences that trigger a warning.
test_data_dir = r'D:\FIRE-SMOKE-DATASET\Test'
train_data_dir = r'D:\FIRE-SMOKE-DATASET\Train'
# Reuse the directory variables instead of repeating the path literals.
# The only transform is ToTensor: no resize, crop or normalization.
train_data = dset.ImageFolder(train_data_dir, transforms.Compose([transforms.ToTensor()]))
test_data = dset.ImageFolder(test_data_dir, transforms.Compose([transforms.ToTensor()]))
train_data_loader = torch.utils.data.DataLoader(train_data)
test_data_loader = torch.utils.data.DataLoader(test_data)
# Fetch the first batch from the training loader.
iterator = iter(train_data_loader)
images, labels = next(iterator)
def image_display(image, title=None, unnormalize=True):
    """Display a channel-first image tensor with matplotlib.

    Args:
        image: Tensor laid out as (channels, height, width); it is
            transposed to (height, width, channels) before plotting.
        title: Optional figure title; omitted when falsy.
        unnormalize: When True (the default, matching the original
            behavior), apply ``image / 2 + 0.5`` to reverse a
            mean-0.5 / std-0.5 normalization. Pass False for tensors that
            were never normalized.
    """
    if unnormalize:
        image = image / 2 + 0.5
    numpy_image = image.numpy()
    # matplotlib expects the channel axis last.
    transposed_numpy_image = np.transpose(numpy_image, (1, 2, 0))
    plt.figure(figsize=(20, 4))
    plt.imshow(transposed_numpy_image)
    plt.yticks([])
    plt.xticks([])
    if title:
        plt.title(title)
    # The original wrote `plt.show` without parentheses, which only
    # references the function and never actually shows the figure.
    plt.show()


# The datasets in this snippet are built with ToTensor only (no Normalize),
# so there is nothing to undo; un-normalizing would wash out the colors.
image_display(torchvision.utils.make_grid(images), unnormalize=False)
结果
大佬提供的代码
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as dset
import numpy
import torchvision.transforms as transforms
# Usage: dset.ImageFolder(root="root folder path", [transform, target_transform])
dataset = dset.ImageFolder(
    root='./Assets',
    transform=transforms.Compose([transforms.ToTensor()]),  # convert images to tensors
)
# Wrap the dataset in a loader with default settings and read one batch.
train_data_loader = DataLoader(dataset)
iterator = iter(train_data_loader)
images, labels = next(iterator)
特别感谢火力教育大佬