一、从下载数据集开始说起
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
通过上面代码下来的数据集的数据的shape是:
print(x_train.shape)
(60000, 28, 28)
print(y_train.shape)
(60000, 10)
print(x_test.shape)
(10000, 28, 28)
print(y_test.shape)
(10000, 10)
查过很多篇文章,几乎都是(28,28)变成(784,4)然后用全连接网络进行分类。而如果想要卷积神经网路来进行手写数据集的分类时,大多数时候传入网络的图像通道数都为3,那么这时候就需要把(60000,28,28)变为(60000,28,28,3)。然而这个变换过程怎么实现呢?
二、(28,28)变为(28,28,3)
我倒腾出来的方法有两个:
(1)首先是将(28,28)通过reshape,变为(28,28,1),再通过concatenate进行拼接;
(2)直接通过np.stack方法实现。这种方法是和别人交流得知。
import numpy as np
#方法(1)
x_281 = x_train.reshape(len(x_train), 28,28,1)
x_283 = np.concatenate((x_281, x_281, x_281), axis=-1)#注意,concatenate方法只能用于数据类型为整型的np数组
方法(2)
x_s283 = np.stack((x_train, x_train, x_train), axis=3)
需要注意的是两个方法的axis参数,一个是aixs=-1,一个是axis=3。至于这两种方法的使用说明,以及axis参数传入不同的值得到不同的shape,现在还没搞清楚,留个空白以后补充。
##############################################################
2020年6月17日23:01:56
这个空白可以现在补了。
比如说一个numpy数组的维度是(28,28,1),分别是第0个维度,第1个维度,第3个维度,如果是从最后一个维度开始数,分别是第-3个维度,第-2个维度,第-1个维度。
所以axis就是指定拼接是在哪个维度进行的。
总之我是理解了,不过表述不是很清楚。
还有一点,stack要求所有维度都要一样。
2020年6月27日 再更
原来x_train的维度是(60000, 28, 28),第0个维度是60000,第1个维度是28,第2个维度是28。
x_283 = np.stack((x_train, x_train, x_train), axis=3),得到的x_283的维度是(60000, 28, 28, 3)。
此时第0个维度是60000,第1个维度是28,第2个维度是28,第3个维度是3。过程差不多是像下面这张图:
##############################################################
结果查看:
import matplotlib.pyplot as plt
plt.figure(figsize=(10,4))
plt.subplot(1, 3, 1)
plt.imshow(x_train[0])
plt.legend('x_train[0]')
plt.xlabel('x_train[0]')
plt.title('x_train[0]')
plt.subplot(1, 3, 2)
plt.imshow(x_283[0])
plt.legend('x_283[0]')
plt.xlabel('x_283[0]')
plt.title('x_283[0]')
plt.subplot(1, 3, 3)
plt.imshow(x_s283[0])
plt.legend('x_s283[0]')
plt.xlabel('x_s283[0]')
plt.title('x_s283[0]')
plt.show()
三、结果分析
从结果查看的结果来看,把(28,28)到(28,28,3)的两种方法是可行的。