Dropout与过拟合抑制、函数式API

如何添加Dropout层

在网络中添加Dropout层,主要是在隐藏层中使用,依然是使用之前的例子,如下:

# 建立模型
model = tf.keras.Sequential()
# 添加层
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))  # Flatten将二维数据扁平为28*28向量
model.add(tf.keras.layers.Dense(128, activation='relu'))  # 接着使用Dense处理一维数据,128个隐藏单元
model.add(tf.keras.layers.Dropout(0.5)) # 随机丢弃50%的隐藏单元
model.add(tf.keras.layers.Dense(128, activation='relu'))  # 接着使用Dense处理一维数据,128个隐藏单元
model.add(tf.keras.layers.Dropout(0.5)) # 随机丢弃50%的隐藏单元
model.add(tf.keras.layers.Dense(128, activation='relu'))  # 接着使用Dense处理一维数据,128个隐藏单元
model.add(tf.keras.layers.Dropout(0.5)) # 随机丢弃50%的隐藏单元
model.add(tf.keras.layers.Dense(10, activation='softmax'))  # 输出层,使用softmax激活函数,将输出变为10个概率分类

最后可以看到训练集和测试集的损失和正确率的曲线比较可以看出都没有过拟合。但是因为Dropout层数多啦,发现训练集的损失和准确率要比测试集的高和低。

减小网络规模也是抑制过拟合的非常好的方法。
正则化的原理就是控制网络规模,控制参数规模。

函数式API

Sequential()模型就只有一个输入和一个输出,中间的隐藏层都是顺序连接的,结构单一。如果要建立残差网络,还有输入直接连着输出,这就需要使用函数式API。

下面是函数式API代码示例:(依然是使用Fashion_MNIST数据集为例)

# -*- coding: UTF-8 -*-
"""
Author: LGD
FileName: functional_api
DateTime: 2020/11/12 09:23 
SoftWare: PyCharm
"""
import matplotlib.pyplot as plt
from tensorflow import keras

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# 数据归一化
train_images = train_images / 255.0
test_images = test_images / 255.0

# 使用函数式api建立模型
# 输入层
input = keras.Input(shape=(28, 28))  # [(None, 28, 28)] None代表它是个数维度,任意个
x = keras.layers.Flatten()(input)
# 隐藏层
x = keras.layers.Dense(32, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(64, activation='relu')(x)
# 输出层
output = keras.layers.Dense(10, activation='softmax')(x)
# 输入参数建立模型
model = keras.Model(inputs=input, outputs=output)
# 查看模型
model.summary()

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

history = model.fit(
    train_images,
    train_labels,
    epochs=30,
    validation_data=(test_images, test_labels)
)

test_loss, test_acc = model.evaluate(test_images, test_labels)

plt.plot(history.epoch, history.history['loss'], 'r', label='loss')
plt.plot(history.epoch, history.history['val_loss'], 'b--', label='val_loss')
plt.legend(title='loss curve of tests and trains')
plt.show()
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Hong_Youth

您的鼓励将是我创作的动力。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值