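Reading notes on the paper "ImageNet Classification with Deep Convolutional Neural Networks" and a TensorFlow 2 reproduction of AlexNet
Highlights of the paper
This week I read the classic paper "ImageNet Classification with Deep Convolutional Neural Networks" by Hinton and his student Alex Krizhevsky. Its highlights are summarized below: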
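- ReLU activation (Rectified Linear Unit), $ReLU(x) = \max(0, x)$. ReLU was introduced earlier by Nair and Hinton; this paper showed its benefit for deep CNNs. For $x > 0$ the function grows without bound and its gradient is constantly 1, so it does not saturate the way tanh or sigmoid do. This mitigates vanishing gradients and makes training much faster. Alex calls it a non-saturating nonlinearity.
- Training on two GPUs: the network is split across two GPUs that communicate only in certain layers. The third convolutional layer takes input from all feature maps of the second layer, and so do the fully connected layers. The remaining layers compute independently on each GPU.
- Local Response Normalization (LRN) layers, which normalize the values at the same spatial position across adjacent channels.
- Overlapping pooling: every max-pooling layer in AlexNet uses a 3×3 window with stride 2, so neighboring pooling windows overlap.
- Data augmentation to reduce overfitting:
  1) Horizontal reflections double the data. Random 224×224 patches are cropped from the 256×256 images with stride 1, which the paper counts as 32×32 = 1024 crop positions. Together with the flips, this enlarges the training set by a factor of 2048.
  2) PCA on the RGB pixel values of the training set is used to randomly change the intensity and color of the lighting, which augments the data further.
- Dropout: during training, units in the first two fully connected layers are set to zero with probability 0.5, which substantially reduces overfitting.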
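The two augmentation schemes above can be sketched in NumPy. This is a minimal illustration, not the paper's code. The function names `random_crop_and_flip`, `rgb_pca` and `pca_lighting` are hypothetical. The standard deviation 0.1 for the random coefficients $\alpha_i$ follows the value reported in the paper.

```python
import numpy as np

def random_crop_and_flip(img, crop=224, rng=None):
    """Random crop x crop patch, mirrored horizontally with probability 0.5.

    img: H x W x C array, e.g. a 256 x 256 x 3 training image.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape[:2]
    top = rng.integers(0, h - crop + 1)    # any offset from 0 to h - crop
    left = rng.integers(0, w - crop + 1)
    patch = img[top:top + crop, left:left + crop]
    if rng.random() < 0.5:
        patch = patch[:, ::-1]             # horizontal reflection
    return patch

def rgb_pca(images):
    """Eigenvalues/eigenvectors of the 3x3 covariance of RGB values over (N, H, W, 3) images."""
    pixels = images.reshape(-1, 3).astype(np.float64)
    cov = np.cov(pixels, rowvar=False)
    eigvals, eigvecs = np.linalg.eigh(cov)  # column i of eigvecs is p_i
    return eigvals, eigvecs

def pca_lighting(img, eigvals, eigvecs, std=0.1, rng=None):
    """Add [p1, p2, p3] @ [a1*l1, a2*l2, a3*l3] to every pixel, with a_i ~ N(0, std).

    img must be on the same scale as the pixels used for rgb_pca.
    """
    rng = np.random.default_rng() if rng is None else rng
    alpha = rng.normal(0.0, std, size=3)    # drawn once per image
    shift = eigvecs @ (alpha * eigvals)     # one RGB offset for the whole image
    return img + shift

# Example on a random stand-in image (values in [0, 1])
rng = np.random.default_rng(0)
img = rng.random((256, 256, 3))
patch = random_crop_and_flip(img, rng=rng)
vals, vecs = rgb_pca(img[None])
lit = pca_lighting(patch, vals, vecs, rng=rng)
print(patch.shape, lit.shape)   # (224, 224, 3) (224, 224, 3)
```

In a real pipeline the PCA would be computed once over the whole training set, and new crops, flips and $\alpha_i$ would be drawn each time an image is presented.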
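Model architecture
The figure above shows the AlexNet architecture, which was trained on two GPUs. I do not have that much compute, so I implemented half of the model (the part on one GPU), with the sizes shown in the figure. The parameters of each layer are listed in the figure below.
Code reproduction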
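The feature-map sizes in the figures can be checked with the usual output-size formula $\lfloor (W - K + 2P)/S \rfloor + 1$. The sketch below walks through the half-width model implemented in this post, starting from a 227×227 input; the helper name `out_size` is hypothetical. The paper's text quotes 224×224. However, 227 is the size for which the first 11×11, stride-4 convolution yields 55×55 maps, and it is also the size the code below uses.

```python
def out_size(size, kernel, stride=1, pad=0):
    """Spatial output size of a conv/pool layer: floor((W - K + 2P) / S) + 1."""
    return (size + 2 * pad - kernel) // stride + 1

# (name, kernel, stride, padding) for the spatial layers of the half-width model
layers = [
    ("conv1 48@11x11/4", 11, 4, 0),
    ("pool1 3x3/2",       3, 2, 0),
    ("conv2 128@5x5/1",   5, 1, 2),
    ("pool2 3x3/2",       3, 2, 0),
    ("conv3 192@3x3/1",   3, 1, 1),  # padding='same' for a 3x3 kernel pads 1 pixel
    ("conv4 192@3x3/1",   3, 1, 1),
    ("conv5 128@3x3/1",   3, 1, 1),
    ("pool5 3x3/2",       3, 2, 0),
]

size = 227
for name, k, s, p in layers:
    size = out_size(size, k, s, p)
    print(f"{name:18s} -> {size} x {size}")

print("Flatten:", size * size * 128)  # 6 * 6 * 128 = 4608 inputs to the first Dense layer
```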
import os
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import tensorflow as tf
import time
Define the convolution layer
class MyConv2DLayer(tf.keras.layers.Layer):
    '''
    Custom convolution layer: 128 kernels of 5x5, stride 1, padding 2, no bias.
    * `__init__()`: Save configuration in member variables
    * `build()`: Called once from `__call__`, when we know the shapes of inputs
      and `dtype`. Should have the calls to `add_weight()`, and then
      call the super's `build()` (which sets `self.built = True`, which is
      nice in case the user wants to call `build()` manually before the
      first `__call__`).
    * `call()`: Called in `__call__` after making sure `build()` has been called
      once. Should actually perform the logic of applying the layer to the
      input tensors (which should be passed in as the first argument).
    '''
    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        # Kernel shape: 5 x 5 x in_channels x 128
        self.w = self.add_weight(name='weights', shape=[5, 5, input_shape[-1], 128])
        super().build(input_shape)

    def call(self, inputs):
        '''
        inputs' default dimension order is [batch, height, width, channels]
        filter's order is [filter_height, filter_width, in_channels, out_channels]
        '''
        # Explicit padding: 2 pixels on each side of height and width, none on batch/channels
        output = tf.nn.conv2d(input=inputs, filters=self.w, strides=1,
                              padding=[[0, 0], [2, 2], [2, 2], [0, 0]])
        return output
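The explicit padding `[[0, 0], [2, 2], [2, 2], [0, 0]]` is exactly what `padding='SAME'` produces for a 5×5 kernel at stride 1. The sketch below reproduces the SAME-padding rule from the TensorFlow convolution documentation; the helper `same_padding` is hypothetical. It suggests that `tf.keras.layers.Conv2D(128, 5, padding='same')` would work as an alternative, with the only difference being that it also learns a bias by default.

```python
import math

def same_padding(in_size, kernel, stride):
    """(pad_before, pad_after) that padding='SAME' adds along one spatial axis."""
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + kernel - in_size, 0)
    pad_before = pad_total // 2          # any odd extra pixel goes after
    return pad_before, pad_total - pad_before

# Second conv layer of the half model: 27 x 27 input, 5 x 5 kernel, stride 1
print(same_padding(27, 5, 1))   # (2, 2)
```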
Define the LRN layer
Let $a_{x,y}^{i}$ denote the activity of a neuron obtained by applying kernel $i$ at position $(x, y)$ and then the ReLU nonlinearity. After the LRN layer, the response-normalized activity $b_{x,y}^{i}$ that is fed to the next layer becomes:
$$b_{x,y}^{i} = \frac{a_{x,y}^{i}}{\left(k + \alpha \sum\limits_{j=\max(0,\, i-n/2)}^{\min(N-1,\, i+n/2)} \left(a_{x,y}^{j}\right)^{2}\right)^{\beta}}$$
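Here $N$ is the total number of feature maps (kernels) in the layer. The sum runs over a window of $n$ adjacent feature maps at the same position, centered on map $i$ ($n/2$ on each side) and clipped at the first and last map. The remaining hyperparameters, as used in the paper, are $k = 2$, $n = 5$, $\alpha = 10^{-4}$ and $\beta = 0.75$.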
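A direct NumPy translation of the formula can serve as a reference when checking a layer implementation. The function `lrn` is a hypothetical helper. It uses the paper's hyperparameters as defaults and a channels-last layout as in TensorFlow, and it takes $n/2$ as integer division ($n // 2 = 2$ for $n = 5$).

```python
import numpy as np

def lrn(a, k=2.0, n=5, alpha=1e-4, beta=0.75):
    """Local response normalization across channels.

    a: activations with shape [batch, height, width, N] (channels last).
    b[..., i] = a[..., i] / (k + alpha * sum_j a[..., j]**2) ** beta,
    where j runs from max(0, i - n//2) to min(N - 1, i + n//2).
    """
    N = a.shape[-1]
    sq = a.astype(np.float64) ** 2
    b = np.empty_like(sq)
    for i in range(N):
        lo, hi = max(0, i - n // 2), min(N - 1, i + n // 2)
        denom = (k + alpha * sq[..., lo:hi + 1].sum(axis=-1)) ** beta
        b[..., i] = a[..., i] / denom
    return b

# 8 channels, all equal to 2: edge channels see a smaller window than middle ones
x = np.full((1, 1, 1, 8), 2.0)
b = lrn(x)
```

With these defaults, the result should agree with `tf.nn.local_response_normalization(x, depth_radius=2, bias=2.0, alpha=1e-4, beta=0.75)`. TensorFlow applies `alpha` to the plain sum of squares and does not divide it by the window size.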
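class myLRNLayer(tf.keras.layers.Layer):
    '''
    Custom LRN (local response normalization) layer.
    '''
    def __init__(self, alpha, beta, k, n):
        super().__init__()
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.k = float(k)
        self.n = n

    def call(self, inputs):
        '''
        inputs' dimension order is [batch, height, width, channel].
        For each channel, sum the squares of the n/2 channels before and after it
        (clipped at the edges), multiply by alpha, add k, raise to the power beta
        and divide. tf.nn.local_response_normalization computes exactly this
        (its alpha is not divided by the window size).
        The model below uses BatchNormalization in place of this layer.
        '''
        return tf.nn.local_response_normalization(
            inputs, depth_radius=self.n // 2, bias=self.k,
            alpha=self.alpha, beta=self.beta)

# Quick check with the paper's hyperparameters (input must be 4-D):
# x = tf.ones([1, 2, 2, 8]) * 2
# m = myLRNLayer(alpha=1e-4, beta=0.75, k=2, n=5)
# m(x)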
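Data preprocessing
The experiments use the Caltech-256 dataset. The full dataset is over 1 GB, which is more than my laptop can handle. The raw data is stored on disk as shown in the figure below, with one folder per category (e.g. `001.ak47`).
We use 1/6 of it, namely the first 5,102 images in folder order, which cover categories 001 to 045. Of these, 4,000 images form the training set and the remaining 1,102 form the test set.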
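PATH = 'D:\\迅雷下载\\Caltech-256\\data\\256_ObjectCategories'
# One sub-folder per category, e.g. "001.ak47"; sort so the order is the same on every OS
dirNames = sorted(os.listdir(PATH))
labels = [int(dirName.split('.')[0]) for dirName in dirNames]
picDirs = [os.path.join(PATH, dirName) for dirName in dirNames]
# Count the pictures per category and in total
# (`sum` shadows the built-in here; it is reused below as the total count)
sum = 0
for iter, typeDir in enumerate(picDirs):
    tempSum = 0
    for picName in os.listdir(typeDir):
        sum += 1
        tempSum += 1
    print(tempSum, end='\t')
print("\nTOTAL NUMBER OF THE DATASET IS: %d" % sum)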
Output:
98 97 151 127 148 90 106 232 102 94 278 216 98 86 122 91 104 101 124 83 142 97 110 112 114 106 100 110 103 104 90 101 102 100 87 106 120 110 85 124 87 87 124 121 85 133 94 103 106 97 114 85 82 118 98 102 106 93 83 87 102 83 122 131 101 83 83 110 99 84 99 118 100 115 83 84 92 90 99 116 95 81 95 84 112 80 93 98 110 212 95 201 112 104 86 285 89 100 80 93 138 88 111 97 270 87 89 85 156 85 84 84 116 120 88 107 121 108 87 130 82 103 111 91 101 242 128 105 190 91 92 190 136 119 93 93 156 192 86 89 117 107 130 82 798 82 202 174 103 111 109 120 93 103 92 96 105 149 209 83 103 91 90 101 88 92 86 140 92 102 84 99 106 84 83 110 96 98 80 102 100 120 100 84 103 81 95 88 119 112 111 112 174 112 87 104 100 109 105 100 81 97 91 80 87 98 115 109 102 111 95 136 101 139 84 98 105 81 84 94 103 91 80 110 90 99 147 95 95 94 112 358 100 122 114 97 90 97 84 201 95 93 90 91 91 101 92 84 100 96 800 116 435 95 103 108 827
TOTAL NUMBER OF THE DATASET IS: 30608
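Loading the data
# Size of the subset used (1/6 of the full dataset)
LEN = int(np.ceil(sum / 6))
# Training-set size, chosen according to the subset size
TRAIN_SIZE = 4000
# Number of categories; sets the number of nodes in the final softmax layer
CATEGORY_NUM = 45
# uint8 is enough for raw pixel values and needs half the memory of int16
imgArray = np.zeros(shape=[LEN, 227, 227, 3], dtype=np.uint8)
labelArray = np.empty(shape=[LEN], dtype=np.int32)
index = 0
for iter, typeDir in enumerate(picDirs):
    print("Processing Label: %s" % typeDir)
    for picName in os.listdir(typeDir):
        if index >= LEN:
            break
        picPath = os.path.join(typeDir, picName)
        img = plt.imread(picPath)
        imgg = cv2.resize(img, (227, 227))
        # File names look like "001_0001.jpg"; the prefix is the 1-based category number
        label = picName.split('.')[0].split('_')[0]
        # Shift to 0-based so that tf.one_hot(depth=CATEGORY_NUM) covers every category
        label = int(label) - 1
        try:
            imgArray[index, :] = imgg
        except ValueError:
            # A grayscale picture has shape (227, 227), which cannot be broadcast
            # into (227, 227, 3); replicate its single channel three times
            print("Error Occured. Probably this picture is a grayscale one...")
            imgg = np.stack((imgg, imgg, imgg), axis=-1)
            imgArray[index, :] = imgg
        labelArray[index] = label
        index += 1
    if index >= LEN:
        break
print("Data Preprocessing Done.\nTotal Category: %3d\tTotal Number: %3d" % (iter + 1, index))

def preProcess(x, y):
    '''
    Preprocessing: scale x to floats in [0, 1]
    and convert y to a one-hot encoding.
    '''
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.one_hot(y, depth=CATEGORY_NUM)
    return x, y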
Output:
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\001.ak47
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\002.american-flag
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\003.backpack
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\004.baseball-bat
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\005.baseball-glove
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\006.basketball-hoop
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\007.bat
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\008.bathtub
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\009.bear
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\010.beer-mug
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\011.billiards
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\012.binoculars
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\013.birdbath
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\014.blimp
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\015.bonsai-101
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\016.boom-box
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\017.bowling-ball
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\018.bowling-pin
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\019.boxing-glove
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\020.brain-101
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\021.breadmaker
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\022.buddha-101
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\023.bulldozer
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\024.butterfly
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\025.cactus
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\026.cake
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\027.calculator
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\028.camel
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\029.cannon
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\030.canoe
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\031.car-tire
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\032.cartman
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\033.cd
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\034.centipede
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\035.cereal-box
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\036.chandelier-101
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\037.chess-board
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\038.chimp
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\039.chopsticks
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\040.cockroach
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\041.coffee-mug
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\042.coffin
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\043.coin
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\044.comet
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\045.computer-keyboard
Data Preprocessing Done.
Total Category: 45 Total Number: 5102
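The category numbers parsed from Caltech-256 file names start at 1, but a one-hot encoding of depth `CATEGORY_NUM` only has positions 0 to `CATEGORY_NUM - 1`. The small NumPy sketch below shows why the labels should be 0-based before encoding. The helper `one_hot` is hypothetical and only mimics `tf.one_hot`, which also maps an out-of-range index to an all-zero row. With 1-based labels, the last category would therefore silently get no target.

```python
import numpy as np

CATEGORY_NUM = 45
# Category numbers as parsed from file names such as "045_0001.jpg" (1-based)
file_labels = np.array([1, 2, 45])

def one_hot(labels, depth):
    """Mimics tf.one_hot: rows for labels outside [0, depth) stay all zeros."""
    out = np.zeros((len(labels), depth), dtype=np.float32)
    valid = (labels >= 0) & (labels < depth)
    out[np.nonzero(valid)[0], labels[valid]] = 1.0
    return out

print(one_hot(file_labels, CATEGORY_NUM).sum(axis=1))      # [1. 1. 0.]  label 45 is lost
print(one_hot(file_labels - 1, CATEGORY_NUM).sum(axis=1))  # [1. 1. 1.]  0-based labels
```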
Shuffle and batch the data:
# Shuffle images and labels with one shared permutation, so that every
# image stays paired with its label after the train/test split
indices = tf.random.shuffle(tf.range(LEN))
imgShuffled = tf.gather(imgArray, indices)
print("Image Shuffling Done...")
labelShuffled = tf.gather(labelArray, indices)
print("Label Shuffling Done...")
trainData, testData = tf.split(imgShuffled, [TRAIN_SIZE, LEN - TRAIN_SIZE], axis=0)
trainLabel, testLabel = tf.split(labelShuffled, [TRAIN_SIZE, LEN - TRAIN_SIZE], axis=0)
# Batching (the last, partial batch is kept)
batchTrain = tf.data.Dataset.from_tensor_slices((trainData, trainLabel)).batch(128)
batchTest = tf.data.Dataset.from_tensor_slices((testData, testLabel)).batch(128)
batchTrain = batchTrain.map(preProcess)
batchTest = batchTest.map(preProcess)
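By default, `.batch(128)` keeps the final partial batch (`drop_remainder=False`), so the number of steps per epoch is a ceiling division. The arithmetic below uses the numbers from this post and matches the `32/32` steps in the training log and the `9/9` steps in the evaluation.

```python
import math

TOTAL = 30608                      # images counted in the whole Caltech-256 folder
LEN = math.ceil(TOTAL / 6)         # subset size
TRAIN_SIZE = 4000
BATCH = 128

test_size = LEN - TRAIN_SIZE                 # images left for testing
train_steps = math.ceil(TRAIN_SIZE / BATCH)  # last batch holds 4000 - 31 * 128 = 32 images
test_steps = math.ceil(test_size / BATCH)
print(LEN, test_size, train_steps, test_steps)   # 5102 1102 32 9
```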
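Define the model
# Helper input for checking the model's output shapes
# x = tf.random.normal([1, 227, 227, 3])
alexNet = tf.keras.Sequential([
    # Layer 1: [conv - pool - normalization]
    # 48 kernels of 11x11, stride 4
    tf.keras.layers.Conv2D(48, 11, 4, activation='relu'),
    # Max pooling, 3x3 window, stride 2
    tf.keras.layers.MaxPooling2D((3, 3), 2),
    # Normalization: BatchNormalization stands in for LRN
    # (the paper applies LRN before pooling)
    tf.keras.layers.BatchNormalization(),
    # Layer 2: [conv - pool - normalization]
    # 128 kernels of 5x5, stride 1, padding 2 on each side, hence the custom conv layer
    MyConv2DLayer(),
    # Max pooling, 3x3 window, stride 2
    tf.keras.layers.MaxPooling2D((3, 3), 2, padding='valid'),
    # ReLU after max pooling equals ReLU before it, because ReLU is monotonic
    tf.keras.layers.ReLU(),
    # Normalization (BatchNormalization instead of LRN)
    tf.keras.layers.BatchNormalization(),
    # Layer 3: convolution only
    tf.keras.layers.Conv2D(192, kernel_size=3, strides=1, padding='same', activation='relu'),
    # Layer 4: convolution only
    tf.keras.layers.Conv2D(192, kernel_size=3, strides=1, padding='same', activation='relu'),
    # Layer 5: [conv - pool]
    tf.keras.layers.Conv2D(128, kernel_size=3, strides=1, padding='same', activation='relu'),
    # Redundant: the Conv2D above already applies ReLU
    tf.keras.layers.ReLU(),
    tf.keras.layers.MaxPooling2D((3, 3), 2),
    # Fully connected layer (the paper uses dropout probability 0.5)
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2048, activation='relu'),
    tf.keras.layers.Dropout(0.25),
    # Fully connected layer (this second Flatten is a no-op: its input is already flat)
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2048, activation='relu'),
    tf.keras.layers.Dropout(0.25),
    # Output layer: one node per category
    tf.keras.layers.Dense(CATEGORY_NUM),
    tf.keras.layers.Softmax()
])
alexNet.build(input_shape=(None, 227, 227, 3))
alexNet.summary()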
The model summary output:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) multiple 17472
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple 0
_________________________________________________________________
batch_normalization (BatchNo multiple 192
_________________________________________________________________
my_conv2d_layer (MyConv2DLay multiple 153600
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 multiple 0
_________________________________________________________________
re_lu (ReLU) multiple 0
_________________________________________________________________
batch_normalization_1 (Batch multiple 512
_________________________________________________________________
conv2d_1 (Conv2D) multiple 221376
_________________________________________________________________
conv2d_2 (Conv2D) multiple 331968
_________________________________________________________________
conv2d_3 (Conv2D) multiple 221312
_________________________________________________________________
re_lu_1 (ReLU) multiple 0
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 multiple 0
_________________________________________________________________
flatten (Flatten) multiple 0
_________________________________________________________________
dense (Dense) multiple 9439232
_________________________________________________________________
dropout (Dropout) multiple 0
_________________________________________________________________
flatten_1 (Flatten) multiple 0
_________________________________________________________________
dense_1 (Dense) multiple 4196352
_________________________________________________________________
dropout_1 (Dropout) multiple 0
_________________________________________________________________
dense_2 (Dense) multiple 92205
_________________________________________________________________
softmax (Softmax) multiple 0
=================================================================
Total params: 14,674,221
Trainable params: 14,673,869
Non-trainable params: 352
_________________________________________________________________
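Every number in the summary can be reproduced by hand, which gives a quick way to confirm the layer sizes. The sketch below recomputes them; the helper names are hypothetical. It also explains the 352 non-trainable parameters, which are the moving mean and variance of the two BatchNormalization layers. Most of the 14.7M parameters sit in the first Dense layer (4608 × 2048 weights).

```python
def conv_params(k, c_in, c_out, bias=True):
    """Weights of a k x k convolution, plus one bias per output channel."""
    return k * k * c_in * c_out + (c_out if bias else 0)

def dense_params(n_in, n_out):
    return n_in * n_out + n_out

def bn_params(c):
    """gamma, beta (trainable) + moving mean, moving variance (non-trainable)."""
    return 4 * c

params = {
    "conv2d":                conv_params(11, 3, 48),
    "batch_normalization":   bn_params(48),
    "my_conv2d_layer":       conv_params(5, 48, 128, bias=False),  # custom layer has no bias
    "batch_normalization_1": bn_params(128),
    "conv2d_1":              conv_params(3, 128, 192),
    "conv2d_2":              conv_params(3, 192, 192),
    "conv2d_3":              conv_params(3, 192, 128),
    "dense":                 dense_params(6 * 6 * 128, 2048),
    "dense_1":               dense_params(2048, 2048),
    "dense_2":               dense_params(2048, 45),
}
total = 0
for count in params.values():
    total += count
non_trainable = 2 * (48 + 128)   # moving mean/variance of the two BN layers
print(total, non_trainable)      # 14674221 352
```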
Model training
Train the model for 30 epochs:
alexNet.compile(optimizer='adam',
                loss=tf.keras.losses.CategoricalCrossentropy(),
                metrics=['mse', 'accuracy'])
alexNet.fit(batchTrain, epochs=30)
Training metrics output:
Epoch 1/30
32/32 [==============================] - 127s 4s/step - loss: 3.8214 - mse: 0.0218 - accuracy: 0.0698
Epoch 2/30
32/32 [==============================] - 117s 4s/step - loss: 3.4564 - mse: 0.0211 - accuracy: 0.1328
Epoch 3/30
32/32 [==============================] - 117s 4s/step - loss: 3.1594 - mse: 0.0203 - accuracy: 0.1898
Epoch 4/30
32/32 [==============================] - 117s 4s/step - loss: 2.9747 - mse: 0.0197 - accuracy: 0.2265
Epoch 5/30
32/32 [==============================] - 125s 4s/step - loss: 2.7330 - mse: 0.0188 - accuracy: 0.2743
Epoch 6/30
32/32 [==============================] - 118s 4s/step - loss: 2.4830 - mse: 0.0177 - accuracy: 0.3315
Epoch 7/30
32/32 [==============================] - 118s 4s/step - loss: 2.2459 - mse: 0.0167 - accuracy: 0.3842
Epoch 8/30
32/32 [==============================] - 118s 4s/step - loss: 2.0411 - mse: 0.0158 - accuracy: 0.4290
Epoch 9/30
32/32 [==============================] - 145s 5s/step - loss: 1.8709 - mse: 0.0149 - accuracy: 0.4737
Epoch 10/30
32/32 [==============================] - 160s 5s/step - loss: 1.6553 - mse: 0.0135 - accuracy: 0.5280
Epoch 11/30
32/32 [==============================] - 144s 5s/step - loss: 1.4724 - mse: 0.0122 - accuracy: 0.5840
Epoch 12/30
32/32 [==============================] - 153s 5s/step - loss: 1.3277 - mse: 0.0114 - accuracy: 0.6168
Epoch 13/30
32/32 [==============================] - 160s 5s/step - loss: 0.9913 - mse: 0.0089 - accuracy: 0.7117
Epoch 14/30
32/32 [==============================] - 136s 4s/step - loss: 0.9707 - mse: 0.0085 - accuracy: 0.7258
Epoch 15/30
32/32 [==============================] - 120s 4s/step - loss: 0.8679 - mse: 0.0078 - accuracy: 0.7490
Epoch 16/30
32/32 [==============================] - 116s 4s/step - loss: 0.7359 - mse: 0.0068 - accuracy: 0.7862
Epoch 17/30
32/32 [==============================] - 121s 4s/step - loss: 0.6765 - mse: 0.0063 - accuracy: 0.8045
Epoch 18/30
32/32 [==============================] - 117s 4s/step - loss: 0.6738 - mse: 0.0062 - accuracy: 0.8095
Epoch 19/30
32/32 [==============================] - 121s 4s/step - loss: 0.5028 - mse: 0.0048 - accuracy: 0.8520
Epoch 20/30
32/32 [==============================] - 131s 4s/step - loss: 0.3758 - mse: 0.0037 - accuracy: 0.8810
Epoch 21/30
32/32 [==============================] - 132s 4s/step - loss: 0.3741 - mse: 0.0036 - accuracy: 0.8885
Epoch 22/30
32/32 [==============================] - 128s 4s/step - loss: 0.2951 - mse: 0.0028 - accuracy: 0.9155
Epoch 23/30
32/32 [==============================] - 126s 4s/step - loss: 0.2647 - mse: 0.0026 - accuracy: 0.9170
Epoch 24/30
32/32 [==============================] - 122s 4s/step - loss: 0.2661 - mse: 0.0025 - accuracy: 0.9243
Epoch 25/30
32/32 [==============================] - 124s 4s/step - loss: 0.2080 - mse: 0.0021 - accuracy: 0.9370
Epoch 26/30
32/32 [==============================] - 126s 4s/step - loss: 0.1995 - mse: 0.0020 - accuracy: 0.9400
Epoch 27/30
32/32 [==============================] - 128s 4s/step - loss: 0.2182 - mse: 0.0020 - accuracy: 0.9417
Epoch 28/30
32/32 [==============================] - 1590s 50s/step - loss: 0.1874 - mse: 0.0017 - accuracy: 0.9477
Epoch 29/30
32/32 [==============================] - 134s 4s/step - loss: 0.1802 - mse: 0.0017 - accuracy: 0.9507
Epoch 30/30
32/32 [==============================] - 134s 4s/step - loss: 0.2488 - mse: 0.0021 - accuracy: 0.9390
<tensorflow.python.keras.callbacks.History at 0x18f4a3e0888>
Testing
Evaluate the model on the test set:
alexNet.evaluate(batchTest)
Result:
9/9 [==============================] - 7s 803ms/step - loss: 7.0082 - mse: 0.0263 - accuracy: 0.2613
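Summary
The model clearly overfits. Training accuracy reaches about 94% after 30 epochs, while accuracy on the test set is only 26.1% (test loss 7.01). A likely cause is that the dataset is too small and none of the data augmentation methods described in the paper were applied. I will summarize further results after modifying the data.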