论文《ImageNet Classification with Deep Convolutional Neural Networks》阅读及AlexNet的Tensorflow2复现

论文《ImageNet Classification with Deep Convolutional Neural Networks》阅读及AlexNet的Tensorflow2复现

论文亮点

这周阅读了Hinton和他的学生Alex的经典论文《ImageNet Classification with Deep Convolutional Neural Networks》,总结亮点包括如下几部分:

  1. 提出了ReLU激活函数(ReLU, Rectified Linear Unit, 修正线性单元), R e L U ( x ) = m a x ( 0 , x ) ReLU(x)=max(0, x) ReLU(x)=max(0,x),由于该激活函数在趋近正无穷是也趋近正无穷,因此不会存在梯度消失的现象,可以加快训练过程。Alex将该函数称为非饱和函数(Non-Saturating Function)
  2. 双GPU加速,Alex首次使用双通道GPU,并在第三层进行全部连接,而剩余层则各自单独计算。
  3. 局部响应归一化(LRN, local Response Normalization)层的提出,对多个channel同一位置的数据进行归一化
  4. Pooling重叠,在该文中,AlexNet全部使用3*3的最大池化窗口,步长为2
  5. 数据增强方法,为防止过拟合,1)Alex首先使用翻转,将数据集扩大了2倍。然后用224224大小的窗口对256256的原始图片进行裁切,裁切过程中步长为1,形成32*32=1024个子图,从而再次扩大1024倍。因此将数据集扩大了2048倍。2)使用PCA对像素进行分析,改变图片的光照强度,进一步扩展数据集。
  6. DropOut的提出,通过使用DropOut方法,对全连接层的单元进行随机置零,从而避免了模型过拟合。

模型原理

alexNet.jpg
AlexNet原理图如上,使用的双GPU进行训练。由于没有这么多资源,我完成了一半的模型,大小如图中所示。具体的每一层的参数,如下图所示。
AlexModel.jpg

代码复现

import os
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import tensorflow as tf
import time

定义卷积层

class MyConv2DLayer(tf.keras.layers.Layer):
    '''
        自定义卷积层
        * `__init__()`: Save configuration in member variables
        * `build()`: Called once from `__call__`, when we know the shapes of inputs
          and `dtype`. Should have the calls to `add_weight()`, and then
          call the super's `build()` (which sets `self.built = True`, which is
          nice in case the user wants to call `build()` manually before the
          first `__call__`).
        * `call()`: Called in `__call__` after making sure `build()` has been called
          once. Should actually perform the logic of applying the layer to the
          input tensors (which should be passed in as the first argument).
    '''
    def __init__(self):
        super().__init__()
        
    
    def build(self, input_shape):
        # 卷积核大小
        self.w = self.add_weight(name='weights', shape=[5, 5, input_shape[-1], 128])
    
    def call(self, input):
        '''
        input's default dimension order is [batch, height, width, channels]
        filter's order is [filter_height, filter_width, in_channels, out_channels]
        '''
        output = tf.nn.conv2d(input=input, filters=self.w, strides=1, padding=[[0, 0], [2, 2], [2, 2], [0, 0]])
        return output

定义LRN层

a x , y i a_{x, y}^{i} ax,yi表示在(x, y)这个位置上,以kernel i来计算后,经过激励后的输出。经过LRN层后,作为下一层输入的数据 b x , y i b_{x, y}^{i} bx,yi变为:
b x , y i = a x , y i ( k + α ∑ j = m a x ( 0 , i − n / 2 ) m i n ( N − 1 , i + n / 2 ) ( a x , y i ) 2 ) β b_{x, y}^{i} = \frac{a_{x, y}^{i}}{(k+\alpha\sum\limits_{j=max(0, i-n/2)}^{min(N-1, i+n/2)}{(a_{x, y}^{i})^2)^\beta}} bx,yi=(k+αj=max(0,in/2)min(N1,i+n/2)(ax,yi)2)βax,yi
其中,N表示改成的feature map(kernel)总数,n表示该feature map左右两边各n/2大小的窗口做平均。其余的几个参数如下: k = 2 k=2 k=2 n = 5 n=5 n=5 α = 1 0 − 4 \alpha=10^{-4} α=104 β = 0.75 \beta=0.75 β=0.75

class myLRNLayer(tf.keras.layers.Layer):
    '''
        自定义LRN归一化层
    '''
    def __init__(self, alpha, beta, k, n):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.k = k
        self.n = n
    
    def build(self, input_shape):
        print("SHAPE: ", input_shape)
    
    def call(self, input):
        '''
            input's dimension order is [batch, height, width, channel]
            具体操作为沿着channel的方向向前、向后取n/2个channel进行平方加和,乘以apha、beta等系数之后进行归一化,此处略,具体操作使用BN进行替换
        '''
        return input / (self.k + self.alpha * input) ** self.beta
# x = tf.ones([3, 5]) * 2
# m = myLRNLayer(1, 2, 1, 2)
# m(x)

数据预处理

使用Caltech-256数据集进行实验,由于整个数据集1G多,笔记本跑不起来,硬盘中原始数据的存储格式如下所示:
dataset
我们使用其中的1/6作为数据集进行实验,其中一共包含45个分类,共5102张图片。将其中4000张图片作为训练集,剩余的作为测试集进行实验。

PATH = 'D:\\迅雷下载\\Caltech-256\\data\\256_ObjectCategories'
dirNames = os.listdir(PATH)
labels = [int(dirName.split('.')[0]) for dirName in dirNames]
picDirs = [os.path.join(PATH, dirName) for dirName in dirNames]
sum = 0
for iter, typeDir in enumerate(picDirs):
    tempSum = 0
    for picName in os.listdir(typeDir):
        sum += 1
        tempSum +=1
    print(tempSum, end='\t')
print("\nTOTAL NUMBER OF THE DATASET IS: %d" % sum)

输出如下:

98	97	151	127	148	90	106	232	102	94	278	216	98	86	122	91	104	101	124	83	142	97	110	112	114	106	100	110	103	104	90	101	102	100	87	106	120	110	85	124	87	87	124	121	85	133	94	103	106	97	114	85	82	118	98	102	106	93	83	87	102	83	122	131	101	83	83	110	99	84	99	118	100	115	83	84	92	90	99	116	95	81	95	84	112	80	93	98	110	212	95	201	112	104	86	285	89	100	80	93	138	88	111	97	270	87	89	85	156	85	84	84	116	120	88	107	121	108	87	130	82	103	111	91	101	242	128	105	190	91	92	190	136	119	93	93	156	192	86	89	117	107	130	82	798	82	202	174	103	111	109	120	93	103	92	96	105	149	209	83	103	91	90	101	88	92	86	140	92	102	84	99	106	84	83	110	96	98	80	102	100	120	100	84	103	81	95	88	119	112	111	112	174	112	87	104	100	109	105	100	81	97	91	80	87	98	115	109	102	111	95	136	101	139	84	98	105	81	84	94	103	91	80	110	90	99	147	95	95	94	112	358	100	122	114	97	90	97	84	201	95	93	90	91	91	101	92	84	100	96	800	116	435	95	103	108	827	
TOTAL NUMBER OF THE DATASET IS: 30608

读取数据

# 数据集大小
LEN = int(np.ceil(sum / 6))
# 训练集大小, 根据数据集大小确定
TRAIN_SIZE = 4000
# 分类数目,决定softmax最终的输出层节点数
CATEGORY_NUM = 45

imgArray = np.zeros(shape = [LEN, 227, 227, 3], dtype=np.int16)
labelArray = np.empty(shape = [LEN], dtype=np.int32)

index = 0
for iter, typeDir in enumerate(picDirs):
    print("Processing Label: %s" % typeDir)
    for picName in os.listdir(typeDir):
        if index >= LEN:
            break
        picPath = os.path.join(typeDir, picName)
        
        img = plt.imread(picPath)
        imgg = cv2.resize(img, (227, 227))
        label = picName.split('.')[0].split('_')[0]
        label = int(label)
        try:
            imgArray[index, :] = imgg
        except Exception as e:
            print("Error Occured. Probably this picture is a grayscale one...")
            imgg = np.stack((imgg, imgg, imgg), axis=-1)
            imgArray[index, :] = imgg
        labelArray[index] = label
        index += 1
    if index >= LEN:
        break
print("Data Preprocessing Done.\nTotal Category: %3d\tTotal Number: %3d" % (iter + 1, index))

def preProcess(x, y):
    '''
        预处理,x转换为[0, 1)范围的小数
        y处理为one-hot编码
    '''
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.one_hot(y, depth=CATEGORY_NUM)
    return x, y

输出如下:

Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\001.ak47
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\002.american-flag
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\003.backpack
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\004.baseball-bat
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\005.baseball-glove
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\006.basketball-hoop
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\007.bat
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\008.bathtub
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\009.bear
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\010.beer-mug
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\011.billiards
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\012.binoculars
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\013.birdbath
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\014.blimp
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\015.bonsai-101
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\016.boom-box
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\017.bowling-ball
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\018.bowling-pin
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\019.boxing-glove
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\020.brain-101
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\021.breadmaker
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\022.buddha-101
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\023.bulldozer
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\024.butterfly
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\025.cactus
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\026.cake
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\027.calculator
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\028.camel
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\029.cannon
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\030.canoe
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\031.car-tire
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\032.cartman
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\033.cd
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\034.centipede
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\035.cereal-box
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\036.chandelier-101
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\037.chess-board
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\038.chimp
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\039.chopsticks
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\040.cockroach
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\041.coffee-mug
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\042.coffin
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\043.coin
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\044.comet
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Error Occured. Probably this picture is a grayscale one...
Processing Label: D:\迅雷下载\Caltech-256\data\256_ObjectCategories\045.computer-keyboard
Data Preprocessing Done.
Total Category:  45	Total Number: 5102

对数据进行shuffle和batch操作:

# 设置同一个种子可以保证分割后data和label依旧一一对应
randomSeed = time.time()
tf.random.set_seed(randomSeed)
imgShuffled = tf.random.shuffle(imgArray)
print("Image Shuffling Done...")
tf.random.set_seed(randomSeed)
labelShuffled = tf.random.shuffle(labelArray)
print("Label Shuffling Done...")
imgShuffled = tf.cast(imgShuffled, dtype=tf.int16)

trainData, testData = tf.split(imgShuffled, [TRAIN_SIZE, LEN - TRAIN_SIZE], axis=0)
trainLabel, testLabel = tf.split(labelShuffled, [TRAIN_SIZE, LEN - TRAIN_SIZE], axis=0)

# batch处理
batchTrain = tf.data.Dataset.from_tensor_slices((trainData, trainLabel)).batch(128)
batchTest = tf.data.Dataset.from_tensor_slices((testData, testLabel)).batch(128)

batchTrain = batchTrain.map(preProcess)
batchTest = batchTest.map(preProcess)

定义模型

# 用于辅助确定模型输出shape
# x = tf.random.normal([1, 227, 227, 3])
alexNet = tf.keras.Sequential([
    # 第一层:[卷积层-池化层-LRN层]
    # 卷积核大小为48@11*11, 步长为4
    tf.keras.layers.Conv2D(48, 11, 4, activation='relu'),
    # 最大池化层大小为3*3, 步长为2
    tf.keras.layers.MaxPooling2D((3, 3), 2),
    # 归一化层
    tf.keras.layers.BatchNormalization(),
    
    # 第二层:[卷积层-池化层-LRN层]
    # 卷积核大小为128@5*5, 步长为1, 同时需要指定padding,因此自己实现卷积层
    MyConv2DLayer(),
    # 最大池化层大小为3*3, 步长为2
    tf.keras.layers.MaxPooling2D((3, 3), 2, padding='valid'),
    tf.keras.layers.ReLU(),
    # 归一化层
    tf.keras.layers.BatchNormalization(),
    
    #第三层:[只包含卷积层]
    tf.keras.layers.Conv2D(192, kernel_size=3, strides=1, padding='same', activation='relu'),
    
    #第四层:[只包含卷积层]
    tf.keras.layers.Conv2D(192, kernel_size=3, strides=1, padding='same', activation='relu'),
    
    # 第五层:[]
    tf.keras.layers.Conv2D(128, kernel_size=3, strides=1, padding='same', activation='relu'),
    tf.keras.layers.ReLU(),
    tf.keras.layers.MaxPooling2D((3, 3), 2),
    
    # 全连接层
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2048, activation='relu'),
    tf.keras.layers.Dropout(0.25),
    
    # 全连接层
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2048, activation='relu'),
    tf.keras.layers.Dropout(0.25),
    
    # 输出层
    tf.keras.layers.Dense(CATEGORY_NUM),
    tf.keras.layers.Softmax()
])
alexNet.build(input_shape=(None, 227, 227, 3))
alexNet.summary()

模型summary输出如下:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              multiple                  17472     
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
batch_normalization (BatchNo multiple                  192       
_________________________________________________________________
my_conv2d_layer (MyConv2DLay multiple                  153600    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 multiple                  0         
_________________________________________________________________
re_lu (ReLU)                 multiple                  0         
_________________________________________________________________
batch_normalization_1 (Batch multiple                  512       
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  221376    
_________________________________________________________________
conv2d_2 (Conv2D)            multiple                  331968    
_________________________________________________________________
conv2d_3 (Conv2D)            multiple                  221312    
_________________________________________________________________
re_lu_1 (ReLU)               multiple                  0         
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 multiple                  0         
_________________________________________________________________
flatten (Flatten)            multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  9439232   
_________________________________________________________________
dropout (Dropout)            multiple                  0         
_________________________________________________________________
flatten_1 (Flatten)          multiple                  0         
_________________________________________________________________
dense_1 (Dense)              multiple                  4196352   
_________________________________________________________________
dropout_1 (Dropout)          multiple                  0         
_________________________________________________________________
dense_2 (Dense)              multiple                  92205     
_________________________________________________________________
softmax (Softmax)            multiple                  0         
=================================================================
Total params: 14,674,221
Trainable params: 14,673,869
Non-trainable params: 352
_________________________________________________________________

模型训练

训练模型,共30个epoch

alexNet.compile(optimizer='adam',
                loss=tf.keras.losses.CategoricalCrossentropy(),
                metrics=['mse', 'accuracy'])
alexNet.fit(batchTrain, epochs=30)

模型训练的输出metric:

Epoch 1/30
32/32 [==============================] - 127s 4s/step - loss: 3.8214 - mse: 0.0218 - accuracy: 0.0698
Epoch 2/30
32/32 [==============================] - 117s 4s/step - loss: 3.4564 - mse: 0.0211 - accuracy: 0.1328
Epoch 3/30
32/32 [==============================] - 117s 4s/step - loss: 3.1594 - mse: 0.0203 - accuracy: 0.1898
Epoch 4/30
32/32 [==============================] - 117s 4s/step - loss: 2.9747 - mse: 0.0197 - accuracy: 0.2265
Epoch 5/30
32/32 [==============================] - 125s 4s/step - loss: 2.7330 - mse: 0.0188 - accuracy: 0.2743
Epoch 6/30
32/32 [==============================] - 118s 4s/step - loss: 2.4830 - mse: 0.0177 - accuracy: 0.3315
Epoch 7/30
32/32 [==============================] - 118s 4s/step - loss: 2.2459 - mse: 0.0167 - accuracy: 0.3842
Epoch 8/30
32/32 [==============================] - 118s 4s/step - loss: 2.0411 - mse: 0.0158 - accuracy: 0.4290
Epoch 9/30
32/32 [==============================] - 145s 5s/step - loss: 1.8709 - mse: 0.0149 - accuracy: 0.4737
Epoch 10/30
32/32 [==============================] - 160s 5s/step - loss: 1.6553 - mse: 0.0135 - accuracy: 0.5280
Epoch 11/30
32/32 [==============================] - 144s 5s/step - loss: 1.4724 - mse: 0.0122 - accuracy: 0.5840
Epoch 12/30
32/32 [==============================] - 153s 5s/step - loss: 1.3277 - mse: 0.0114 - accuracy: 0.6168
Epoch 13/30
32/32 [==============================] - 160s 5s/step - loss: 0.9913 - mse: 0.0089 - accuracy: 0.7117
Epoch 14/30
32/32 [==============================] - 136s 4s/step - loss: 0.9707 - mse: 0.0085 - accuracy: 0.7258
Epoch 15/30
32/32 [==============================] - 120s 4s/step - loss: 0.8679 - mse: 0.0078 - accuracy: 0.7490
Epoch 16/30
32/32 [==============================] - 116s 4s/step - loss: 0.7359 - mse: 0.0068 - accuracy: 0.7862
Epoch 17/30
32/32 [==============================] - 121s 4s/step - loss: 0.6765 - mse: 0.0063 - accuracy: 0.8045
Epoch 18/30
32/32 [==============================] - 117s 4s/step - loss: 0.6738 - mse: 0.0062 - accuracy: 0.8095
Epoch 19/30
32/32 [==============================] - 121s 4s/step - loss: 0.5028 - mse: 0.0048 - accuracy: 0.8520
Epoch 20/30
32/32 [==============================] - 131s 4s/step - loss: 0.3758 - mse: 0.0037 - accuracy: 0.8810
Epoch 21/30
32/32 [==============================] - 132s 4s/step - loss: 0.3741 - mse: 0.0036 - accuracy: 0.8885
Epoch 22/30
32/32 [==============================] - 128s 4s/step - loss: 0.2951 - mse: 0.0028 - accuracy: 0.9155
Epoch 23/30
32/32 [==============================] - 126s 4s/step - loss: 0.2647 - mse: 0.0026 - accuracy: 0.9170
Epoch 24/30
32/32 [==============================] - 122s 4s/step - loss: 0.2661 - mse: 0.0025 - accuracy: 0.9243
Epoch 25/30
32/32 [==============================] - 124s 4s/step - loss: 0.2080 - mse: 0.0021 - accuracy: 0.9370
Epoch 26/30
32/32 [==============================] - 126s 4s/step - loss: 0.1995 - mse: 0.0020 - accuracy: 0.9400
Epoch 27/30
32/32 [==============================] - 128s 4s/step - loss: 0.2182 - mse: 0.0020 - accuracy: 0.9417
Epoch 28/30
32/32 [==============================] - 1590s 50s/step - loss: 0.1874 - mse: 0.0017 - accuracy: 0.9477
Epoch 29/30
32/32 [==============================] - 134s 4s/step - loss: 0.1802 - mse: 0.0017 - accuracy: 0.9507
Epoch 30/30
32/32 [==============================] - 134s 4s/step - loss: 0.2488 - mse: 0.0021 - accuracy: 0.9390
<tensorflow.python.keras.callbacks.History at 0x18f4a3e0888>

测试

在测试集上进行测试

alexNet.evaluate(batchTest)

结果如下:

9/9 [==============================] - 7s 803ms/step - loss: 7.0082 - mse: 0.0263 - accuracy: 0.2613

总结

可以看到最终模型训练产生过拟合,原因可能是数据集大小不够,没有进行论文中所述的数据增强方法(Data Augment)进行增强。后续对数据进行修改之后会将进一步修改的结果进行总结。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值