实战分类模型之数据读取与展示

fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
#训练集有60000张图片,前5000张图片作为验证集,后55000作为训练集
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]

print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

#输入是一个image的numpy数组
#显示一张图片
def show_single_image(img_arr):
    plt.imshow(img_arr, cmap="binary")
    plt.show()

show_single_image(x_train[1])

def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
    #y_data里存的是类别名的索引
    #显示多张图片
    assert len(x_data) == len(y_data)
    assert n_rows * n_cols < len(x_data)
    #定义一张大图
    plt.figure(figsize = (n_cols * 1.4, n_rows * 1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            #当前行*总列数+列的偏移量
            index = n_cols * row + col 
            plt.subplot(n_rows, n_cols, index+1)  #在大图上画一张子图
            plt.imshow(x_data[index], cmap="binary",  
                       interpolation = 'nearest')
            plt.axis('off') #关闭坐标
            plt.title(class_names[y_data[index]])
    plt.show()
class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress',
               'Coat', 'Sandal', 'Shirt', 'Sneaker',
               'Bag', 'Ankle boot']
show_imgs(3, 5, x_train, y_train, class_names)

note:如标题所示,本文代码着重于数据读取与展示,并未包含模型构建与训练部分,如果您想要关注其它部分的代码,请参照主页的其他博客

本文使用的数据集是keras里的资源fashion_mnist数据集

plt.subplot函数

它的调用是这样子的:subplotnumbRow numbCol plotNum or  subplot(numbRow numbCol plotNum),注意:可以不用逗号分开直接写在一起也是对的;

   numbRowplot图的行数;numbColplot图的列数;plotNum是指第几行第几列的第几幅图

   举个例子,如果是subplot 2 2 1),那么这个figure就是个2*2的矩阵图,也就是总共有4个图,1就代表了第一幅图

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值