[Keras Study Notes] 8: Suppressing Overfitting with Dropout and Regularization

Loading and preprocessing the data

import keras
from keras import layers
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
%matplotlib inline
Using TensorFlow backend.
data = pd.read_csv("./data/credit-a.csv", header=None)
# Shuffle the rows
index = np.random.permutation(len(data))
data = data.iloc[index, :]
# Split into features (all columns but the last) and labels (last column)
x = data.iloc[:, 0:-1]
y = data.iloc[:, -1]
x.shape, y.shape
y = y.replace(-1, 0)
x.shape, y.shape
((653, 15), (653,))
k = int(len(x)*0.75)
x_train = x[:k]
x_test = x[k:]
y_train = y[:k]
y_test = y[k:]
x_train.shape, x_test.shape, y_train.shape, y_test.shape
((489, 15), (164, 15), (489,), (164,))

Suppressing overfitting with Dropout

model = keras.Sequential()
model.add(layers.Dense(128, input_dim=15, activation='relu'))
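model.add(layers.Dropout(0.5)) # rate=0.5 is the fraction of units dropped, so the keep probability is also 0.5
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dropout(0.5)) # drop half of the units during training
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dropout(0.5)) # drop half of the units during training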
model.add(layers.Dense(1, activation='sigmoid'))
WARNING:tensorflow:From E:\MyProgram\Anaconda\envs\krs\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From E:\MyProgram\Anaconda\envs\krs\lib\site-packages\keras\backend\tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 128)               2048      
_________________________________________________________________
dropout_1 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               16512     
_________________________________________________________________
dropout_2 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 128)               16512     
_________________________________________________________________
dropout_3 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 129       
=================================================================
Total params: 35,201
Trainable params: 35,201
Non-trainable params: 0
_________________________________________________________________
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['acc']
)
history = model.fit(x_train, y_train, epochs=1000, validation_data=(x_test, y_test), verbose=0)
WARNING:tensorflow:From E:\MyProgram\Anaconda\envs\krs\lib\site-packages\tensorflow\python\ops\math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.

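In the previous note, without Dropout, model.evaluate returned [loss, acc] = [0.05114140091863878, 0.9754601228212774] on the training set and [1.233076281663848, 0.8658536585365854] on the test set. Compare those with the results below.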

model.evaluate(x_train, y_train)
489/489 [==============================] - 0s 33us/step
[0.2712044728070437, 0.8793456033938746]
model.evaluate(x_test, y_test)
164/164 [==============================] - 0s 43us/step
[0.50074572170653, 0.8170731707317073]
plt.plot(history.epoch, history.history.get('val_acc'), c='g', label='validation acc')
plt.plot(history.epoch, history.history.get('acc'), c='b', label='train acc')
plt.legend()
<matplotlib.legend.Legend at 0x1469ccf8>

[Figure: training and validation accuracy per epoch, model with Dropout]

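The accuracy curves on the training set and the test set now stay close together, so overfitting has been suppressed. If the test-set performance is worse than it was without Dropout, the model may simply not have been trained long enough: Dropout randomly disables some units at every step, so each update trains only a sub-network, and the same number of epochs amounts to less effective training.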
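
To make "randomly disables some units" concrete, here is a minimal pure-Python sketch of inverted dropout, the scheme in which surviving activations are scaled by 1 / (1 - rate) during training. The function name `inverted_dropout` and the toy activations are assumptions made for illustration; this is not the code Keras runs internally.

```python
import random


def inverted_dropout(activations, rate, rng):
    """Zero each activation with probability `rate`; rescale the survivors.

    Scaling survivors by 1 / (1 - rate) keeps the expected value of each
    activation unchanged, so nothing needs to be rescaled at inference time.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    keep_prob = 1.0 - rate
    out = []
    for a in activations:
        if rng.random() < rate:
            out.append(0.0)            # this unit is dropped for this step
        else:
            out.append(a / keep_prob)  # this unit is kept and scaled up
    return out


rng = random.Random(0)                 # fixed seed so the sketch is repeatable
activations = [1.0] * 10000            # hypothetical layer output
dropped = inverted_dropout(activations, rate=0.5, rng=rng)

zero_fraction = dropped.count(0.0) / len(dropped)
mean_after = sum(dropped) / len(dropped)
print("fraction zeroed:", zero_fraction)
print("mean activation after dropout:", mean_after)
```

With rate=0.5 roughly half of the values are zeroed and the rest are doubled, so the mean stays near 1.0. Only about half of the units contribute to any single gradient step, which is why the same number of epochs can leave the Dropout model less trained.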

Suppressing overfitting by adding a regularization term

L1:
$$\mathrm{loss} = \lambda \cdot \sum_{i} |w_{i}| + \mathrm{oldloss}$$

L2:
$$\mathrm{loss} = \lambda \cdot \sum_{i} w_{i}^{2} + \mathrm{oldloss}$$
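
The two penalties can be checked by hand on a tiny example. The sketch below evaluates both formulas in plain Python. The weight values and the unregularized loss (`old_loss`) are made-up numbers for illustration, and `lam` is set to 0.005 only to match the coefficient used in the Keras code in this note.

```python
def l1_penalty(weights, lam):
    """lambda * sum(|w_i|), the L1 term from the formula above."""
    return lam * sum(abs(w) for w in weights)


def l2_penalty(weights, lam):
    """lambda * sum(w_i ** 2), the L2 term from the formula above."""
    return lam * sum(w * w for w in weights)


weights = [0.5, -1.0, 2.0]   # hypothetical weights of one layer
lam = 0.005                  # same coefficient as regularizers.l2(0.005)
old_loss = 0.30              # hypothetical unregularized loss

loss_l1 = old_loss + l1_penalty(weights, lam)
loss_l2 = old_loss + l2_penalty(weights, lam)
print("loss with L1 term:", loss_l1)
print("loss with L2 term:", loss_l2)
```

Here the absolute values sum to 3.5 and the squares sum to 5.25, so the penalties are 0.0175 and 0.02625. The L2 term grows faster for large weights, which is why it pushes hardest on the largest ones.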

from keras import regularizers
model = keras.Sequential()
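# Passing an L2 kernel_regularizer to a layer adds that layer's kernel weights to the regularization term in the loss; the argument (0.005) is the weight-decay coefficient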
model.add(layers.Dense(128, kernel_regularizer=regularizers.l2(0.005), input_dim=15, activation='relu'))
model.add(layers.Dense(128, kernel_regularizer=regularizers.l2(0.005), activation='relu'))
model.add(layers.Dense(128, kernel_regularizer=regularizers.l2(0.005), activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['acc']
)
history = model.fit(x_train, y_train, epochs=1000, validation_data=(x_test, y_test), verbose=0)
model.evaluate(x_train, y_train)
489/489 [==============================] - 0s 33us/step
[0.3962598642932369, 0.8404907977897941]
model.evaluate(x_test, y_test)
164/164 [==============================] - 0s 43us/step
[0.7071414546268743, 0.6951219512195121]
plt.plot(history.epoch, history.history.get('val_acc'), c='g', label='validation acc')
plt.plot(history.epoch, history.history.get('acc'), c='b', label='train acc')
plt.legend()
<matplotlib.legend.Legend at 0x18208208>

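[Figure: training and validation accuracy per epoch, model with L2 regularization]

The poor performance here may be because the model was not trained long enough, or because the network's hyperparameters were poorly chosen.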
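
One rough way to compare the three runs is the gap between training and test accuracy. The sketch below computes that gap from the [loss, acc] pairs printed by model.evaluate in this note. The dictionary layout and the names `results` and `gaps` are my own, and each entry comes from a single run, so treat the comparison as indicative only.

```python
# [loss, acc] pairs copied from the model.evaluate outputs quoted in this note.
results = {
    "no regularization": {
        "train": [0.05114140091863878, 0.9754601228212774],
        "test": [1.233076281663848, 0.8658536585365854],
    },
    "dropout": {
        "train": [0.2712044728070437, 0.8793456033938746],
        "test": [0.50074572170653, 0.8170731707317073],
    },
    "l2": {
        "train": [0.3962598642932369, 0.8404907977897941],
        "test": [0.7071414546268743, 0.6951219512195121],
    },
}

# Generalization gap: training accuracy minus test accuracy.
gaps = {name: r["train"][1] - r["test"][1] for name, r in results.items()}

for name, gap in gaps.items():
    print(f"{name}: train acc - test acc = {gap:.4f}")

smallest = min(gaps, key=gaps.get)
print("smallest gap:", smallest)
```

On these numbers the Dropout model has the smallest gap and the L2 model the largest. The shuffle at the top of the note is not seeded, so another run may produce a different split and different figures.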
