残差网络结构如下图所示。
和自定义全连接网络一样,需要定义__init__,build,call函数: https://mp.csdn.net/console/editor/html/107713858
这里的输出 out 引用了上一步的 conv 结果,形成一个残差式连接:out = relu(conv) + conv
def call(self, input_tensor):
    """Forward pass: strided 2x2 convolution + bias, then a residual-style
    combination relu(x) + x of the pre-activation feature map.

    NOTE(review): reads self.weight and self.bias, which are expected to be
    created by the enclosing layer's build() — confirm against the class below.
    """
    features = tf.nn.conv2d(input_tensor, self.weight,
                            strides=[1, 2, 2, 1], padding='SAME')
    features = tf.nn.bias_add(features, self.bias)
    return tf.nn.relu(features) + features
# Constructor arguments: number of output channels, convolution kernel size.
class MyLayer(tf.keras.layers.Layer):
    """Custom convolutional layer with a residual-style output: relu(conv) + conv.

    The convolution uses stride 2 and SAME padding, so the spatial dimensions
    of the input are halved.
    """

    def __init__(self, channel_num, kernel_size):
        # channel_num: number of output channels of the convolution.
        # kernel_size: spatial size (height == width) of the kernel.
        self.kernel_size = kernel_size
        self.channel_num = channel_num
        super(MyLayer, self).__init__()

    def build(self, input_shape):
        # tf.nn.conv2d expects filters shaped [filter_h, filter_w, in_channels,
        # out_channels]. The original code swapped kernel_size/channel_num in
        # that shape and referenced the bare name `channel_num` (a NameError at
        # build time); both defects are fixed here.
        self.weight = tf.Variable(
            tf.random.normal([self.kernel_size, self.kernel_size,
                              input_shape[-1], self.channel_num]))
        self.bias = tf.Variable(tf.random.normal([self.channel_num]))
        super(MyLayer, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, input_tensor):
        # Strided convolution + bias, then combine the activated and raw maps.
        conv = tf.nn.conv2d(input_tensor, self.weight,
                            strides=[1, 2, 2, 1], padding='SAME')
        conv = tf.nn.bias_add(conv, self.bias)
        out = tf.nn.relu(conv) + conv
        return out
# Assemble the classification model around the custom layer.
input_xs = tf.keras.Input([28, 28, 1])
net = tf.keras.layers.Conv2D(32, 3, padding="SAME", activation=tf.nn.relu)(input_xs)
# Swap in the hand-written layer in place of the stock TensorFlow 2.0 conv layer.
net = MyLayer(32, 3)(net)
net = tf.keras.layers.BatchNormalization()(net)
net = tf.keras.layers.Conv2D(64, 3, padding="SAME", activation=tf.nn.relu)(net)
net = tf.keras.layers.MaxPool2D(strides=[1, 1])(net)
net = tf.keras.layers.Conv2D(128, 3, padding="SAME", activation=tf.nn.relu)(net)
flat = tf.keras.layers.Flatten()(net)
dense = tf.keras.layers.Dense(512, activation=tf.nn.relu)(flat)
logits = tf.keras.layers.Dense(10, activation=tf.nn.softmax)(dense)

model = tf.keras.Model(inputs=input_xs, outputs=logits)
print(model.summary())
model.compile(optimizer=tf.optimizers.Adam(1e-3),
              loss=tf.losses.categorical_crossentropy,
              metrics=['accuracy'])
# NOTE(review): train_dataset / test_dataset are assumed to be defined
# elsewhere in the original article — confirm before running.
model.fit(train_dataset, epochs=10)
model.save("./saver/model.h5")
score = model.evaluate(test_dataset)
print("last score:", score)