搭建模型
resnet18由4个基础的block组成,每个block有两层网络,block的结构如下。右边的箭头表示将x与f(x)相加,这样可以让经过一个block之后的效果至少不会比没有经过的差。这也是为什么resnet网络能够达到这么深的原因。
BasicBlock
1.每个block由两层网络组成,每层网络由卷积层、batchnormalization层和激活函数组成。
2.当stride等于1时,经过block的数据的size没有变化,直接相加。不等于1时size变化,需要一次下采样来让x的size与f(x)保持一致。
class BasicBlock(layers.Layer):
    """One residual basic block: two 3x3 conv+BN layers plus a skip connection.

    When ``stride != 1`` the main branch f(x) changes spatial size, so the
    identity branch x is downsampled with a 1x1 conv to keep x and f(x)
    addable; otherwise x passes through unchanged.
    """

    def __init__(self, filters_num, stride=1):
        """
        :param filters_num: number of conv filters (output channels)
        :param stride: stride of the first conv; != 1 shrinks the feature map
        """
        super(BasicBlock, self).__init__()
        # First layer: conv + BN + ReLU; may change spatial size via stride.
        self.conv1 = layers.Conv2D(filters_num, (3, 3), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        # Second layer: conv + BN; always keeps the spatial size (stride 1).
        self.conv2 = layers.Conv2D(filters_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()
        if stride != 1:
            # Downsample x with a 1x1 conv so its size matches f(x).
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filters_num, (1, 1), strides=stride))
        else:
            # Size unchanged by the block: identity mapping is enough.
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        """Forward pass: out = relu(f(inputs) + downsample(inputs)).

        BUGFIX: the ``training`` flag is now forwarded to both
        BatchNormalization layers so they correctly switch between batch
        statistics (training) and moving statistics (inference).
        """
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        # Bring the identity branch to the same size as f(x).
        identity = self.downsample(inputs)
        # f(x) + x; applying ReLU after the addition (before would also work).
        output = layers.add([out, identity])
        output = tf.nn.relu(output)
        return output
ResNet
1.第一层是数据处理层,把输入的channel变为64传进下一层。
2.每一个block只有在第一层才可能有channel的变化。
3.layer_dims是一个一维列表,每个值代表对应的stage里堆叠多少个BasicBlock。
class ResNet(keras.Model):
    """ResNet assembled from four stages of BasicBlocks plus a stem and head."""

    def __init__(self, layer_dims, num_classes=100):
        """
        :param layer_dims: list of four ints; each value is how many
            BasicBlocks are stacked in the corresponding stage
        :param num_classes: size of the final Dense (logits) layer
        """
        super(ResNet, self).__init__()
        # Stem: bring the input up to 64 channels before the residual stages.
        self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')])
        # Four residual stages; stages 2-4 halve the spatial size via stride 2
        # in their first block (the only place a stage changes channels/size).
        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], strides=(2, 2))
        self.layer3 = self.build_resblock(256, layer_dims[2], strides=(2, 2))
        self.layer4 = self.build_resblock(512, layer_dims[3], strides=(2, 2))
        # Global average pooling: [b, h, w, 512] -> [b, 512].
        self.avgpool = layers.GlobalAvgPool2D()
        # Classification head producing raw logits.
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training=None, mask=None):
        """Forward pass.

        BUGFIX: the ``training`` flag is now propagated into every
        sub-network so their BatchNormalization layers behave correctly
        in train vs. inference mode.
        """
        out = self.stem(inputs, training=training)
        out = self.layer1(out, training=training)
        out = self.layer2(out, training=training)
        out = self.layer3(out, training=training)
        out = self.layer4(out, training=training)
        out = self.avgpool(out)
        out = self.fc(out)
        return out

    def build_resblock(self, filter_num, block, strides=1):
        """Stack ``block`` BasicBlocks into one stage.

        :param filter_num: channel count for every block in the stage
        :param block: number of BasicBlocks to stack
        :param strides: stride of the stage's first block; only that first
            block may change the feature-map size
        :return: a Sequential containing the stacked blocks
        """
        res_block = Sequential()
        # Only the first block of a stage may downsample.
        res_block.add(BasicBlock(filter_num, strides))
        for _ in range(1, block):
            res_block.add(BasicBlock(filter_num, stride=1))
        return res_block
resnet18
def resnet18():
    """Factory for a ResNet-18: four stages, two BasicBlocks per stage."""
    return ResNet([2, 2, 2, 2])
加载数据集
depth代表有几类,我用的是tf2内置的cifar100
def preprocess(x, y):
    """Map step for the tf.data pipeline.

    Scales image pixels into [0, 1] as float32 and casts the (one-hot)
    label tensor to int32. Returns the transformed (image, label) pair.
    """
    image = tf.cast(x, dtype=tf.float32) / 255
    label = tf.cast(y, dtype=tf.int32)
    return image, label
def get_db(batch_size=128, shuffle=10000):
    """Load CIFAR-100 and build the train/test tf.data pipelines.

    :param batch_size: batch size used by both pipelines
    :param shuffle: shuffle-buffer size for the training pipeline
    :return: (train_dataset, test_dataset)
    """
    (x, y), (x_test, y_test) = datasets.cifar100.load_data()
    # Labels arrive as shape (N, 1); squeeze to (N,) before one-hot encoding.
    y = tf.squeeze(y)
    y_test = tf.squeeze(y_test)
    # depth=100: CIFAR-100 has 100 classes.
    y = tf.one_hot(y, depth=100)
    y_test = tf.one_hot(y_test, depth=100)
    db_train = tf.data.Dataset.from_tensor_slices((x, y))
    db_train = db_train.map(preprocess).shuffle(shuffle).batch(batch_size=batch_size)
    # FIX: the evaluation set is no longer shuffled — shuffling it allocated
    # a 10k-element buffer and made batch order nondeterministic while
    # bringing no benefit to metric computation.
    db_t = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    db_t = db_t.map(preprocess).batch(batch_size=batch_size)
    return db_train, db_t
训练保存评估模型
在自动保存模型中,用period会报警说已经弃用,建议使用save_freq,save_freq是隔几个batch保存一次模型,period是隔几个epoch保存一次模型。
if __name__ == '__main__':
    # Load the train/test datasets.
    db, db_test = get_db()
    model = resnet18()
    model.build(input_shape=(None, 32, 32, 3))
    # Print the model summary.
    model.summary()
    # FIX: `lr` is deprecated in tf.keras optimizers; use `learning_rate`.
    model.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    # Auto-save weights with a ModelCheckpoint callback; the epoch number is
    # formatted into the file name.
    checkpoint_path = "weight/resnet18-{epoch:04d}.ckpt"
    # FIX: `period` is deprecated (it triggered the runtime warning);
    # save_freq='epoch' saves once per epoch, which is what period=1 did.
    # An integer save_freq would instead count batches, not epochs.
    cp_callback = callbacks.ModelCheckpoint(
        checkpoint_path, verbose=1, save_weights_only=True, save_freq='epoch')
    # Train the model.
    model.fit(db, epochs=10, validation_data=db_test, validation_freq=1, callbacks=[cp_callback])
    # Evaluate the trained model.
    score = model.evaluate(db_test)
    print('Test score:', score[0])
    print('Test accuracy:', score[1])
    # Save the final weights.
    model.save_weights('resnet18.ckpt')
    # Rebuild a fresh model, load the saved weights, and evaluate again to
    # confirm the weights round-trip correctly.
    test_model = resnet18()
    test_model.build(input_shape=(None, 32, 32, 3))
    test_model.load_weights('resnet18.ckpt')
    test_model.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                       loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                       metrics=['accuracy'])
    score = test_model.evaluate(db_test)
    print('Test score:', score[0])
    print('Test accuracy:', score[1])
输出
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) (None, 30, 30, 64) 2048
_________________________________________________________________
sequential_1 (Sequential) (None, 30, 30, 64) 148736
_________________________________________________________________
sequential_2 (Sequential) (None, 15, 15, 128) 526976
_________________________________________________________________
sequential_4 (Sequential) (None, 8, 8, 256) 2102528
_________________________________________________________________
sequential_6 (Sequential) (None, 4, 4, 512) 8399360
_________________________________________________________________
global_average_pooling2d (Gl multiple 0
_________________________________________________________________
dense (Dense) multiple 51300
=================================================================
Total params: 11,230,948
Trainable params: 11,223,140
Non-trainable params: 7,808
_________________________________________________________________
WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.
Epoch 1/10
2020-07-04 10:48:25.869669: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10
2020-07-04 10:48:26.270991: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7
391/391 [==============================] - ETA: 0s - loss: 3.3788 - accuracy: 0.1892
391/391 [==============================] - 30s 76ms/step - loss: 3.3788 - accuracy: 0.1892 - val_loss: 5.8297 - val_accuracy: 0.0498
Epoch 2/10
390/391 [============================>.] - ETA: 0s - loss: 2.4293 - accuracy: 0.3623
391/391 [==============================] - 28s 72ms/step - loss: 2.4293 - accuracy: 0.3623 - val_loss: 3.5513 - val_accuracy: 0.2126
Epoch 3/10
390/391 [============================>.] - ETA: 0s - loss: 1.9221 - accuracy: 0.4764
391/391 [==============================] - 28s 73ms/step - loss: 1.9217 - accuracy: 0.4766 - val_loss: 2.8024 - val_accuracy: 0.3319
Epoch 4/10
390/391 [============================>.] - ETA: 0s - loss: 1.5254 - accuracy: 0.5679
391/391 [==============================] - 28s 73ms/step - loss: 1.5254 - accuracy: 0.5679 - val_loss: 2.0562 - val_accuracy: 0.4607
Epoch 5/10
390/391 [============================>.] - ETA: 0s - loss: 1.1593 - accuracy: 0.6600
391/391 [==============================] - 28s 73ms/step - loss: 1.1590 - accuracy: 0.6600 - val_loss: 2.3511 - val_accuracy: 0.4156
Epoch 6/10
390/391 [============================>.] - ETA: 0s - loss: 0.7823 - accuracy: 0.7653
391/391 [==============================] - 28s 72ms/step - loss: 0.7825 - accuracy: 0.7652 - val_loss: 2.2684 - val_accuracy: 0.4614
Epoch 7/10
390/391 [============================>.] - ETA: 0s - loss: 0.4402 - accuracy: 0.8674
391/391 [==============================] - 28s 73ms/step - loss: 0.4401 - accuracy: 0.8674 - val_loss: 2.3768 - val_accuracy: 0.4819
Epoch 8/10
390/391 [============================>.] - ETA: 0s - loss: 0.2215 - accuracy: 0.9355
391/391 [==============================] - 28s 73ms/step - loss: 0.2215 - accuracy: 0.9355 - val_loss: 2.3600 - val_accuracy: 0.4959
Epoch 9/10
390/391 [============================>.] - ETA: 0s - loss: 0.1309 - accuracy: 0.9646
391/391 [==============================] - 28s 72ms/step - loss: 0.1310 - accuracy: 0.9646 - val_loss: 2.6867 - val_accuracy: 0.4810
Epoch 10/10
390/391 [============================>.] - ETA: 0s - loss: 0.1630 - accuracy: 0.9509
391/391 [==============================] - 28s 72ms/step - loss: 0.1631 - accuracy: 0.9508 - val_loss: 3.1384 - val_accuracy: 0.4432
79/79 [==============================] - 1s 19ms/step - loss: 3.1384 - accuracy: 0.4432
Test score: 3.1384365558624268
Test accuracy: 0.4431999921798706
79/79 [==============================] - 2s 19ms/step - loss: 3.1384 - accuracy: 0.4432
Test score: 3.138436794281006
Test accuracy: 0.4431999921798706