图像分类数据集Fashion-MNIST数据集(2行9列复现)

'''
3.5图像分类数据集Fashion-MNIST数据集

'''
import torch

import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

#读取数据集:网上下载
#ToTensor由torchvision.transforms导入
fashion_train=torchvision.datasets.FashionMNIST(root="./dataset",train=True,download=True,transform=transforms.ToTensor())
fashion_test=torchvision.datasets.FashionMNIST(root="./dataset",train=False,download=True,transform=transforms.ToTensor())
#print(len(fashion_train),len(fashion_test))#长度60000,10000
#OSError: [WinError 126] 找不到指定的模块。解决办法:之前线性回归的从零实现将dll文件删除,需要重新放回去
#TypeError: __init__() takes 1 positional argument but 2 was given。解决办法:上面的ToTensor忘了加()。

#feature,label=fashion_train[0]
#print(feature.shape,label)#查看第一个图像的通道数、高度、宽度,标签---torch.Size([1, 28, 28]) 9
#print(fashion_train.classes)#输出数据集中的类别属性,共10个类别
#创建函数——数字标签索引及其文本名称之间进行转换
def get_fashion_mnist_lbels(labels):
    text_labels=['t-shirt','trouser','pullover','dress','coat','sandal','shirt','sneaker','bag','ankle boot']
    return text_labels[int(labels)]
#创建函数可视化
#展示2行9列的图片
_,axes=plt.subplots(nrows=2,ncols=9,sharex=True,sharey=True)
axes=axes.flatten()
for i in range(18):
    img=fashion_train.data[i]
    axes[i].imshow(img)
    axes[i].set(title=get_fashion_mnist_lbels(fashion_train[i][1]))
axes[0].set_xticks([])
axes[0].set_yticks([])
plt.tight_layout()
plt.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值