TensorFlow 2 image classification in practice: the SELU activation function

1. Recap

The previous article covered the theory behind activation functions and some points to watch out for. This article focuses on how to configure activation functions and train a model in TensorFlow 2.2.

2. Hands-on code walkthrough

We continue with the image-classification example; see the earlier article 《TensorFlow2 Fashion-MNIST图像分类》 for the setup.
As mentioned previously, the SELU activation function comes with several built-in advantages, such as self-normalization and relatively fast training, which have made it widely used. We will therefore run our experiments with it.
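For reference, SELU is defined as λx for x > 0 and λα(eˣ − 1) for x ≤ 0, with fixed constants α ≈ 1.6733 and λ ≈ 1.0507 chosen so that activations drift towards zero mean and unit variance as they pass through the layers. A minimal plain-Python sketch of the definition (the model below uses TensorFlow's built-in `selu`, not this function):

```python
import math

# Constants from Klambauer et al., "Self-Normalizing Neural Networks" (2017)
ALPHA = 1.6732632423543772
LAMBDA = 1.0507009873554805

def selu(x):
    """SELU: a scaled ELU with fixed alpha and lambda."""
    if x > 0:
        return LAMBDA * x                       # positive inputs scaled by lambda
    return LAMBDA * ALPHA * (math.exp(x) - 1)   # negative inputs saturate

print(selu(1.0))
print(selu(0.0))
print(selu(-50.0))  # approaches -lambda*alpha ~ -1.758 for very negative inputs
```

The bounded negative saturation value is what allows the affine variant of dropout (AlphaDropout, used later in this article) to keep activations normalized.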

The model code is as follows:

# tf.keras.models.Sequential()

model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28, 28]))
# Add the fully connected layers in a loop
# Use selu as the activation: it is self-normalizing and mitigates vanishing
# gradients to some extent, so the training curve does not plateau early on
for _ in range(20):
    model.add(keras.layers.Dense(100, activation="selu"))

model.add(keras.layers.Dense(10, activation="softmax"))

model.compile(loss="sparse_categorical_crossentropy", optimizer='sgd', metrics=["accuracy"])

Data loading and preprocessing follow the earlier article; here we only discuss the activation-function part.

Here a deep neural network is built directly in a loop, with every hidden layer using the SELU activation. The final layer is a fully connected classification layer with a softmax activation.

The per-layer parameter structure is as follows:

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 100)               78500     
_________________________________________________________________
dense_1 (Dense)              (None, 100)               10100     
_________________________________________________________________
dense_2 (Dense)              (None, 100)               10100     
_________________________________________________________________
dense_3 (Dense)              (None, 100)               10100     
_________________________________________________________________
dense_4 (Dense)              (None, 100)               10100     
_________________________________________________________________
dense_5 (Dense)              (None, 100)               10100     
_________________________________________________________________
dense_6 (Dense)              (None, 100)               10100     
_________________________________________________________________
dense_7 (Dense)              (None, 100)               10100     
_________________________________________________________________
dense_8 (Dense)              (None, 100)               10100     
_________________________________________________________________
dense_9 (Dense)              (None, 100)               10100     
_________________________________________________________________
dense_10 (Dense)             (None, 100)               10100     
_________________________________________________________________
dense_11 (Dense)             (None, 100)               10100     
_________________________________________________________________
dense_12 (Dense)             (None, 100)               10100     
_________________________________________________________________
dense_13 (Dense)             (None, 100)               10100     
_________________________________________________________________
dense_14 (Dense)             (None, 100)               10100     
_________________________________________________________________
dense_15 (Dense)             (None, 100)               10100     
_________________________________________________________________
dense_16 (Dense)             (None, 100)               10100     
_________________________________________________________________
dense_17 (Dense)             (None, 100)               10100     
_________________________________________________________________
dense_18 (Dense)             (None, 100)               10100     
_________________________________________________________________
dense_19 (Dense)             (None, 100)               10100     
_________________________________________________________________
dense_20 (Dense)             (None, 10)                1010      
=================================================================
Total params: 271,410
Trainable params: 271,410
Non-trainable params: 0
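The parameter counts in the summary above can be verified by hand: a Dense layer with `in_dim` inputs and `out_dim` units has `in_dim * out_dim` weights plus `out_dim` biases. A quick sanity check:

```python
def dense_params(in_dim, out_dim):
    """Weight matrix (in_dim x out_dim) plus one bias per output unit."""
    return in_dim * out_dim + out_dim

first = dense_params(784, 100)    # 78500: flattened 28x28 image -> 100 units
hidden = dense_params(100, 100)   # 10100: each of the 19 remaining hidden layers
output = dense_params(100, 10)    # 1010: final softmax layer
total = first + 19 * hidden + output
print(total)                      # 271410, matching model.summary()
```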

Training code:

import os

# TensorBoard, EarlyStopping, ModelCheckpoint
logdir = './dnn-selu-callbacks'
if not os.path.exists(logdir):
    os.mkdir(logdir)
output_model_file = os.path.join(logdir, "fashion_mnist_model.h5")

callbacks = [
    keras.callbacks.TensorBoard(logdir),
    keras.callbacks.ModelCheckpoint(output_model_file, save_best_only=True),
    keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]

history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks=callbacks)

Training results:

Epoch 1/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.5535 - accuracy: 0.7996 - val_loss: 0.4365 - val_accuracy: 0.8408
Epoch 2/10
1719/1719 [==============================] - 11s 7ms/step - loss: 0.4072 - accuracy: 0.8507 - val_loss: 0.4379 - val_accuracy: 0.8402
Epoch 3/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3645 - accuracy: 0.8663 - val_loss: 0.3726 - val_accuracy: 0.8618
Epoch 4/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3357 - accuracy: 0.8748 - val_loss: 0.3553 - val_accuracy: 0.8760
Epoch 5/10
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3139 - accuracy: 0.8838 - val_loss: 0.3295 - val_accuracy: 0.8844
Epoch 6/10
1719/1719 [==============================] - 11s 7ms/step - loss: 0.2977 - accuracy: 0.8881 - val_loss: 0.3457 - val_accuracy: 0.8798
Epoch 7/10
1719/1719 [==============================] - 11s 7ms/step - loss: 0.2843 - accuracy: 0.8952 - val_loss: 0.3292 - val_accuracy: 0.8838
Epoch 8/10
1719/1719 [==============================] - 11s 7ms/step - loss: 0.2735 - accuracy: 0.8972 - val_loss: 0.3182 - val_accuracy: 0.8840
Epoch 9/10
1719/1719 [==============================] - 11s 7ms/step - loss: 0.2596 - accuracy: 0.9029 - val_loss: 0.3302 - val_accuracy: 0.8760
Epoch 10/10
1719/1719 [==============================] - 11s 6ms/step - loss: 0.2517 - accuracy: 0.9051 - val_loss: 0.3313 - val_accuracy: 0.8812
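The `EarlyStopping(patience=5, min_delta=1e-3)` callback used above stops training once the monitored metric has failed to improve by at least `min_delta` for `patience` consecutive epochs (here it never triggered within 10 epochs). A minimal sketch of that rule in plain Python, using hypothetical loss values and simplified from what Keras actually does:

```python
def early_stop_epoch(val_losses, patience=5, min_delta=1e-3):
    """Return the 0-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:   # improved by at least min_delta
            best = loss
            wait = 0
        else:
            wait += 1                 # no meaningful improvement this epoch
            if wait >= patience:
                return epoch
    return None                       # early stopping never triggered

# Hypothetical validation losses that plateau after epoch 2
losses = [0.50, 0.40, 0.35, 0.3495, 0.3502, 0.3499, 0.3501, 0.3500]
print(early_stop_epoch(losses))  # 7: five consecutive epochs with no real improvement
```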

Next, let's add a dropout layer for comparison. Straight to the code:

# tf.keras.models.Sequential()

model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28, 28]))
# Add the fully connected layers in a loop
# Use selu as the activation: it is self-normalizing and mitigates vanishing
# gradients to some extent, so the training curve does not plateau early on
for _ in range(20):
    model.add(keras.layers.Dense(100, activation="selu"))
model.add(keras.layers.AlphaDropout(rate=0.5))
model.add(keras.layers.Dense(10, activation="softmax"))

model.compile(loss="sparse_categorical_crossentropy", optimizer='sgd', metrics=["accuracy"])

The only difference from before is the added dropout layer. Note that with the SELU activation, dropout in TensorFlow 2 should use AlphaDropout rather than the ordinary Dropout layer: AlphaDropout preserves the mean and variance of its inputs, so it does not break SELU's self-normalizing property.
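Concretely, AlphaDropout sets dropped units to SELU's negative saturation value −λα (≈ −1.758) instead of 0, then applies an affine correction so that mean and variance are unchanged. A simplified NumPy sketch of the training-time transform, following the formulas from the self-normalizing networks paper (Keras' `AlphaDropout` handles all of this internally):

```python
import numpy as np

def alpha_dropout(x, rate, rng):
    """Training-time alpha dropout: drop to -lambda*alpha, then rescale."""
    alpha_p = -1.7580993408473766        # -lambda * alpha, SELU's saturation value
    keep = rng.random(x.shape) >= rate   # mask of units to keep
    dropped = np.where(keep, x, alpha_p)
    # Affine correction that restores zero mean / unit variance
    a = ((1 - rate) * (1 + rate * alpha_p ** 2)) ** -0.5
    b = -a * alpha_p * rate
    return a * dropped + b

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)       # standard-normal activations
y = alpha_dropout(x, rate=0.5, rng=rng)
print(y.mean(), y.std())                 # both stay close to 0 and 1
```

With ordinary dropout, zeroing half the units would shift the mean and shrink the variance, undoing the normalization that SELU provides; this variant avoids that.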
The training results are as follows:

Epoch 1/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.6936 - accuracy: 0.7647 - val_loss: 0.6364 - val_accuracy: 0.8528
Epoch 2/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.4541 - accuracy: 0.8436 - val_loss: 0.6049 - val_accuracy: 0.8614
Epoch 3/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.4017 - accuracy: 0.8594 - val_loss: 0.5722 - val_accuracy: 0.8702
Epoch 4/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3731 - accuracy: 0.8692 - val_loss: 0.5503 - val_accuracy: 0.8586
Epoch 5/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3514 - accuracy: 0.8767 - val_loss: 0.5628 - val_accuracy: 0.8684
Epoch 6/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3318 - accuracy: 0.8812 - val_loss: 0.7251 - val_accuracy: 0.8326
Epoch 7/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3156 - accuracy: 0.8875 - val_loss: 0.5052 - val_accuracy: 0.8822
Epoch 8/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3056 - accuracy: 0.8891 - val_loss: 0.5531 - val_accuracy: 0.8782
Epoch 9/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.2919 - accuracy: 0.8951 - val_loss: 0.5647 - val_accuracy: 0.8806
Epoch 10/10
1719/1719 [==============================] - 12s 7ms/step - loss: 0.2871 - accuracy: 0.8963 - val_loss: 0.4928 - val_accuracy: 0.8886

Whether or not dropout is added makes little difference on this image-classification dataset, which suggests the model was not overfitting in the first place.
For the complete code, follow the WeChat public account 【瞧不死的AI】 and reply '激活函数'.
