The VGG16 architecture stacks 13 convolutional layers (all with 3×3 kernels) in five blocks of 64, 128, 256, 512 and 512 filters, each block ending in 2×2 max pooling, followed by three fully connected layers. The model definition below mirrors this structure layer by layer.
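For comparison, the canonical VGG16 also ships with Keras; the snippet below is only a reference sketch (the 64×64 input size and 11 classes are the values used later in this post) that prints the official layer summary next to the hand-written model:
from tensorflow.keras.applications import VGG16 as ReferenceVGG16  # alias avoids clashing with the class defined below
ReferenceVGG16(weights=None, input_shape=(64, 64, 3), classes=11).summary()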
Hands-on code
import os
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Dense, Flatten, MaxPool2D, Dropout
# Workaround on Windows: point Qt to PySide2's platform plugins so matplotlib's Qt backend can start
import PySide2
dirname = os.path.dirname(PySide2.__file__)
plugin_path = os.path.join(dirname, 'plugins', 'platforms')
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = plugin_path
# Prepare the data
data_dir = r"F:\danzi\数据"  # raw string so the backslashes in the Windows path are kept literally
data_dir = pathlib.Path(data_dir)  # this directory contains the class folders c1-c10
# image_count = len(list(data_dir.glob("*/*")))  # number of image files found
# print("Total number of images:", image_count)
# Hyperparameter settings
batch_size = 8
image_height = 64
image_width = 64
epoch = 50
# Split the data into training and validation subsets
train_dst = keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,  # optional float between 0 and 1: fraction of the images held out for validation
    subset="training",
    seed=123,  # random seed for shuffling; must be identical in both calls so the split is consistent
    image_size=(image_height, image_width),
    batch_size=batch_size
)
test_dst = keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(image_height, image_width),
    batch_size=batch_size
)
# Get the class names of the training set
train_dst_label = train_dst.class_names
print(train_dst_label)
AUTOTUNE = tf.data.experimental.AUTOTUNE  # lets tf.data tune the number of parallel calls to the available CPU dynamically
train_dst = train_dst.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
test_dst = test_dst.cache().prefetch(buffer_size=AUTOTUNE)  # the validation set does not need to be shuffled
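The pipeline above feeds the raw 0-255 pixel values straight into the network. A common refinement, not part of the original script, is to rescale the inputs to [0, 1]; a minimal sketch using the Rescaling preprocessing layer (tf.keras.layers.experimental.preprocessing.Rescaling in the TF 2.x versions that provide keras.preprocessing.image_dataset_from_directory, tf.keras.layers.Rescaling in newer releases) would be:
normalization = tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)
train_dst = train_dst.map(lambda x, y: (normalization(x), y))
test_dst = test_dst.map(lambda x, y: (normalization(x), y))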
# Model definition
class VGG16(Model):
def __init__(self):
super(VGG16, self).__init__()
        # padding="SAME" zero-pads the input so the spatial size is preserved; "VALID" applies no padding
        """
        2 convolutional layers with 64 filters each
        """
self.c1 = Conv2D(filters=64, kernel_size=(3, 3), padding="SAME")
        # Batch normalization keeps the activations (and hence the gradients) in a healthy range, mitigating vanishing gradients and noticeably speeding up convergence
self.b1 = BatchNormalization()
self.a1 = Activation('relu')
        # 64 filters
self.c2 = Conv2D(filters=64, kernel_size=(3, 3), padding="SAME")
self.b2 = BatchNormalization()
self.a2 = Activation('relu')
self.p1 = MaxPool2D(pool_size=(2, 2), strides=2)
self.d1 = Dropout(0.2)
"""
2个包含128个卷积核的卷积层
"""
self.c3 = Conv2D(filters=128, kernel_size=(3, 3), padding="SAME")
self.b3 = BatchNormalization()
self.a3 = Activation('relu')
self.c4 = Conv2D(filters=128, kernel_size=(3, 3), padding="SAME")
self.b4 = BatchNormalization()
self.a4 = Activation('relu')
self.p2 = MaxPool2D(pool_size=(2, 2), strides=2)
self.d2 = Dropout(0.2)
"""
3个包含256个卷积核的卷积层
"""
self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding="SAME")
self.b5 = BatchNormalization()
self.a5 = Activation('relu')
self.c6 = Conv2D(filters=256, kernel_size=(3, 3), padding="SAME")
self.b6 = BatchNormalization()
self.a6 = Activation('relu')
self.c7 = Conv2D(filters=256, kernel_size=(3, 3), padding="SAME")
self.b7 = BatchNormalization()
self.a7 = Activation('relu')
self.p3 = MaxPool2D(pool_size=(2, 2), strides=2)
self.d3 = Dropout(0.2)
"""
6个包含512个卷积核的卷积层
"""
self.c8 = Conv2D(filters=512, kernel_size=(3, 3), padding="SAME")
self.b8 = BatchNormalization()
self.a8 = Activation('relu')
self.c9 = Conv2D(filters=512, kernel_size=(3, 3), padding="SAME")
self.b9 = BatchNormalization()
self.a9 = Activation('relu')
self.c10 = Conv2D(filters=512, kernel_size=(3, 3), padding="SAME")
self.b10 = BatchNormalization()
self.a10 = Activation('relu')
self.p4 = MaxPool2D(pool_size=(2, 2), strides=2)
self.d4 = Dropout(0.2)
self.c11 = Conv2D(filters=512, kernel_size=(3, 3), padding="SAME")
self.b11 = BatchNormalization()
self.a11 = Activation('relu')
self.c12 = Conv2D(filters=512, kernel_size=(3, 3), padding="SAME")
self.b12 = BatchNormalization()
self.a12 = Activation('relu')
self.c13 = Conv2D(filters=512, kernel_size=(3, 3), padding="SAME")
self.b13 = BatchNormalization()
self.a13 = Activation('relu')
self.p5 = MaxPool2D(pool_size=(2, 2), strides=2)
self.d5 = Dropout(0.2)
"""
2个包含4096个神经元的全连接层
"""
# Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。
self.flatten = Flatten()
# FC层在keras中叫做Dense层,在pytorch中叫Linear层
self.f1 = Dense(4096, activation='relu')
self.d6 = Dropout(0.2)
self.f2 = Dense(4096, activation='relu')
self.d7 = Dropout(0.2)
"""
1个包含11个神经元的全连接层
"""
self.f3 = Dense(11, activation='softmax')
def call(self, x):
x = self.c1(x)
x = self.b1(x)
x = self.a1(x)
x = self.c2(x)
x = self.b2(x)
x = self.a2(x)
x = self.p1(x)
x = self.d1(x)
x = self.c3(x)
x = self.b3(x)
x = self.a3(x)
x = self.c4(x)
x = self.b4(x)
x = self.a4(x)
x = self.p2(x)
x = self.d2(x)
x = self.c5(x)
x = self.b5(x)
x = self.a5(x)
x = self.c6(x)
x = self.b6(x)
x = self.a6(x)
x = self.c7(x)
x = self.b7(x)
x = self.a7(x)
x = self.p3(x)
x = self.d3(x)
x = self.c8(x)
x = self.b8(x)
x = self.a8(x)
x = self.c9(x)
x = self.b9(x)
x = self.a9(x)
x = self.c10(x)
x = self.b10(x)
x = self.a10(x)
x = self.p4(x)
x = self.d4(x)
x = self.c11(x)
x = self.b11(x)
x = self.a11(x)
x = self.c12(x)
x = self.b12(x)
x = self.a12(x)
x = self.c13(x)
x = self.b13(x)
x = self.a13(x)
x = self.p5(x)
x = self.d5(x)
x = self.flatten(x)
x = self.f1(x)
x = self.d6(x)
x = self.f2(x)
x = self.d7(x)
y = self.f3(x)
return y
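A quick sanity check, not part of the original script, is to push one dummy batch through a fresh instance and confirm the output shape is (batch, 11); it is only a sketch and can stay commented out:
# print(VGG16()(tf.zeros((1, image_height, image_width, 3))).shape)  # expected: (1, 11)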
model = VGG16()
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
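# image_dataset_from_directory yields integer class labels and the last Dense layer already applies
# softmax, so sparse categorical cross-entropy is used with from_logits=False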
model.compile(optimizer=opt, loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
history = model.fit(train_dst, validation_data=test_dst, epochs=epoch)
model.save_weights('VGG16_model.h5')
model.summary()
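Reloading the saved weights later requires the subclassed model to be built first so that its variables exist; a minimal sketch, assuming the same input size as above, looks like this:
restored = VGG16()
restored(tf.zeros((1, image_height, image_width, 3)))  # create the variables with a dummy batch before loading
restored.load_weights('VGG16_model.h5')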
# Inspect the trainable variables
print(model.trainable_variables)
file = open('./weight.txt', 'w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.numpy()) + '\n')
file.close()
############# show ##################
# Plot training and validation accuracy and loss
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(121)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title("VGG16 bs=6 Accuracy")
plt.legend()
plt.subplot(122)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title("VGG16 bs=6 Loss")
plt.legend()
plt.show()
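As a final qualitative check (a sketch, not in the original post), the trained model can be run on one batch from the validation set and the predicted class indices compared with the true labels:
for images, labels in test_dst.take(1):
    preds = np.argmax(model.predict(images), axis=1)  # index of the most probable class for each image
    print("predicted class indices:", preds)
    print("true class indices:     ", labels.numpy())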