''' 3.5图像分类数据集Fashion-MNIST数据集 ''' import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt #读取数据集:网上下载 #ToTensor由torchvision.transforms导入 fashion_train=torchvision.datasets.FashionMNIST(root="./dataset",train=True,download=True,transform=transforms.ToTensor()) fashion_test=torchvision.datasets.FashionMNIST(root="./dataset",train=False,download=True,transform=transforms.ToTensor()) #print(len(fashion_train),len(fashion_test))#长度60000,10000 #OSError: [WinError 126] 找不到指定的模块。解决办法:之前线性回归的从零实现将dll文件删除,需要重新放回去 #TypeError: __init__() takes 1 positional argument but 2 was given。解决办法:上面的ToTensor忘了加()。 #feature,label=fashion_train[0] #print(feature.shape,label)#查看第一个图像的通道数、高度、宽度,标签---torch.Size([1, 28, 28]) 9 #print(fashion_train.classes)#输出数据集中的类别属性,共10个类别 #创建函数——数字标签索引及其文本名称之间进行转换 def get_fashion_mnist_lbels(labels): text_labels=['t-shirt','trouser','pullover','dress','coat','sandal','shirt','sneaker','bag','ankle boot'] return text_labels[int(labels)] #创建函数可视化 #展示2行9列的图片 _,axes=plt.subplots(nrows=2,ncols=9,sharex=True,sharey=True) axes=axes.flatten() for i in range(18): img=fashion_train.data[i] axes[i].imshow(img) axes[i].set(title=get_fashion_mnist_lbels(fashion_train[i][1])) axes[0].set_xticks([]) axes[0].set_yticks([]) plt.tight_layout() plt.show()
图像分类数据集Fashion-MNIST数据集(2行9列复现)
于 2023-10-26 11:03:34 首次发布