在创建自定义网络层类时,需要继承自layers.Layer 基类;创建自定义的网络类,需要继承自 keras.Model 基类,这样产生的自定义类才能够方便的利用Layer/Model 基类提供的参数管理功能,同时也能够与其他的标准网络层类交互使用。
一:实现自定义网络层
实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:
> _init_()函数:参数初始化
>
> build()函数,可以获得输入张量的形状,权重变量初始化
>
> call()函数,构建网络结构,进行前向传播
实际上,你不必等到调用build()来创建网络结构,您也可以在_init_()中创建它们。 但是,在build()中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期的网络构建。 另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。
class MyDense(tf.keras.layers.Layer):
def __init__(self, n_outputs):
super(MyDense, self).__init__()
self.n_outputs = n_outputs
def build(self, input_shape):
self.kernel = self.add_variable('kernel',
shape=[int(input_shape[-1]),
self.n_outputs])
def call(self, input):
return tf.matmul(input, self.kernel)
layer = MyDense(10)
print(layer(tf.ones([6, 5])))
print(layer.trainable_variables)
使用的时候
//__init__(256)
mlayer=MyLayer(256)
//call(x)
mlayer(x)
其中input_shape是根据input来的
所以在重构build()的时候,layer的权重生成要与input数据shape来推导生成
二、自定义网络
机器学习模型中有很多是通过叠加不同的结构层组合而成的,如resnet的每个残差块就是“卷积+批标准化+残差连接”的组合。
在tensorflow2中要创建一个包含多个网络层的的结构,一般继承与tf.keras.Model类。
# 残差块
class ResnetBlock(tf.keras.Model):
def __init__(self, kernel_size, filters):
super(ResnetBlock, self).__init__(name='resnet_block')
# 每个子层卷积核数
filter1, filter2, filter3 = filters
# 三个子层,每层1个卷积加一个批正则化
# 第一个子层, 1*1的卷积
self.conv1 = tf.keras.layers.Conv2D(filter1, (1,1))
self.bn1 = tf.keras.layers.BatchNormalization()
# 第二个子层, 使用特点的kernel_size
self.conv2 = tf.keras.layers.Conv2D(filter2, kernel_size, padding='same')
self.bn2 = tf.keras.layers.BatchNormalization()
# 第三个子层,1*1卷积
self.conv3 = tf.keras.layers.Conv2D(filter3, (1,1))
self.bn3 = tf.keras.layers.BatchNormalization()
def call(self, inputs, training=False):
# 堆叠每个子层
x = self.conv1(inputs)
x = self.bn1(x, training=training)
x = self.conv2(x)
x = self.bn2(x, training=training)
x = self.conv3(x)
x = self.bn3(x, training=training)
# 残差连接
x += inputs
outputs = tf.nn.relu(x)
return outputs
resnetBlock = ResnetBlock(2, [6,4,9])
# 数据测试
print(resnetBlock(tf.ones([1,3,9,9])))
# 查看网络中的变量名
print([x.name for x in resnetBlock.trainable_variables])
从这个案例中可以发现call函数类似于torch中的forward函数
tf中的call等价于torch中forward