VGG的网络架构
import tensorflow as tf
def vgg_block(num_conv,num_filters):
#序列模型
blk=tf.keras.models.Sequential()
#遍历卷积层
for _ in range(num_conv):
#设置卷积层
blk.add(tf.keras.layers.Conv2D(num_filters,kernel_size=3,padding='same',activation='relu'))
#设置池化层
blk.add(tf.keras.layers.MaxPool2D(pool_size=2,strides=2))
return blk
def vgg(conv_arch):
#序列模型
net=tf.keras.models.Sequential()
#生成卷积部分
for (num_conv,num_filters) in conv_arch:
net.add(vgg_block(num_conv,num_filters))
#全连接层
net.add(tf.keras.models.Sequential([
#展平
tf.keras.layers.Flatten(),
#全连接层
tf.keras.layers.Dense(4096,activation='relu'),
#随机失活
tf.keras.layers.Dropout(0.5),
#全连接层
tf.keras.layers.Dense(4096,activation='relu'),
#随机失活
tf.keras.layers.Dropout(0.5),
#输出层
tf.keras.layers.Dense(10,activation='softmax')
]))
return net
#卷积层的参数
conv_arch=((2,64),(2,128),(3,256),(3,512),(3,512))
net=vgg(conv_arch)
x=tf.random.uniform((1,224,224,1))
y=net(x)
net.summary()
手写数字识别
数据读取
见alex net
模型编译
见alex net
模型训练
见alex net
模型评估
见alex net