#图像分类数据集
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
# d2l helper; presumably switches matplotlib output to SVG rendering.
d2l.use_svg_display()
# ToTensor converts each PIL image into a 32-bit float tensor and divides by
# 255, so every pixel value lies between 0 and 1.
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="./data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="./data", train=False, transform=trans, download=True)
# Number of examples in the training and test splits.
print(len(mnist_train), len(mnist_test))
# Shape of the image tensor of the first training example.
print(mnist_train[0][0].shape)
def get_fashion_mnist_labels(labels):
    """Translate numeric Fashion-MNIST class indices into text labels.

    Args:
        labels: Iterable of class indices; each element only needs to be
            convertible with ``int()`` (plain ints, 0-d tensors, ...).

    Returns:
        A list with one label string per input index, in input order.
    """
    class_names = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [class_names[idx] for idx in map(int, labels)]
def show_images(imgs, num_rows, num_cols, title=None, scale=1.5):
    """Plot a list of images in a ``num_rows`` x ``num_cols`` grid.

    Args:
        imgs: Sequence of images. Each is either a torch tensor or something
            ``Axes.imshow`` accepts directly (e.g. a PIL image).
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.
        title: Optional sequence of per-image titles, indexed in the same
            order as ``imgs``. When ``None`` (the default) no titles are set.
        scale: Width/height in inches of each grid cell.

    Returns:
        A flat array of the grid's matplotlib ``Axes`` objects.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False keeps ``axes`` a 2-D array even for 1x1 / 1xN grids,
    # so ``flatten`` below always works.
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                               squeeze=False)
    axes = axes.flatten()
    # zip stops at the shorter input, so surplus axes are simply left empty.
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        # Previously title[i] was evaluated unconditionally, which raised
        # TypeError for the default title=None.
        if title is not None:
            ax.set_title(title[i])
        if torch.is_tensor(img):
            # Image tensor: detach and move to CPU so .numpy() also works for
            # tensors that require grad or live on an accelerator.
            ax.imshow(img.detach().cpu().numpy())
        else:
            # PIL image (or other array-like) is drawn directly.
            ax.imshow(img)
    return axes
# Preview the first training examples in a grid titled with their text
# labels. The loader's batch size is derived from the grid shape so the two
# cannot drift apart (previously 18 was hard-coded twice, apart from 2x9).
preview_rows, preview_cols = 2, 9
x, y = next(iter(data.DataLoader(mnist_train,
                                 batch_size=preview_rows * preview_cols)))
# Flatten each example to a 28x28 image for imshow; the batch dimension is
# inferred. (Assumes each image tensor holds 28*28 values, i.e. 1x28x28.)
show_images(x.reshape(-1, 28, 28), preview_rows, preview_cols,
            title=get_fashion_mnist_labels(y))
# Uncomment when running as a plain script: outside a notebook matplotlib
# only opens the figure window when show() is called.
# d2l.plt.show()
# Minibatch size used by the loaders below.
batch_size = 256
def get_dataloader_workers():
    """Return the number of worker processes DataLoaders should use.

    Returns:
        The fixed worker count, 4.
    """
    num_workers = 4
    return num_workers
# Time one full pass over the training set with multi-process loading.
# The __main__ guard is required on platforms whose default multiprocessing
# start method is "spawn" (Windows, macOS): every DataLoader worker re-imports
# this script, and without the guard each worker would re-run this loop and
# try to start processes of its own, which raises a RuntimeError.
# NOTE(review): ideally the other top-level script code (downloads, plotting)
# should move under this guard too, so workers do not repeat it.
if __name__ == "__main__":
    train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                                 num_workers=get_dataloader_workers())
    timer = d2l.Timer()
    # Consume batches without doing any work, so only loading time is measured.
    for x, y in train_iter:
        continue
    print(f'{timer.stop():.2f} sec')
def load_data_fashion_mnist(batch_size, resize=None, root="./data"):
    """Download Fashion-MNIST and wrap both splits in DataLoaders.

    Args:
        batch_size: Number of examples per minibatch.
        resize: Optional size passed to ``transforms.Resize``. When falsy,
            images keep their original size.
        root: Directory the dataset is downloaded to / read from.

    Returns:
        A ``(train_iter, test_iter)`` tuple of DataLoaders. Only the training
        loader shuffles; the test loader iterates in a fixed order.
    """
    trans = [transforms.ToTensor()]
    if resize:
        # Insert Resize at the front so it is applied before ToTensor.
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=trans, download=True)
    train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                                 num_workers=get_dataloader_workers())
    # Evaluation data needs no shuffling; a fixed order keeps evaluation
    # reproducible (the previous shuffle=True here was unnecessary).
    test_iter = data.DataLoader(mnist_test, batch_size, shuffle=False,
                                num_workers=get_dataloader_workers())
    return train_iter, test_iter
# 深度学习09图片数据集 (Deep Learning 09: Image Datasets)
# 最新推荐文章于 2024-08-20 07:03:46 发布 (latest recommended article published 2024-08-20 07:03:46)