在创建自定义网络层类时,需要继承自layers.Layer
基类;创建自定义的网络类,需要继承自 keras.Model
基类,这样产生的自定义类才能够方便的利用Layer/Model 基类提供的参数管理功能,同时也能够与其他的标准网络层类交互使用。
一、自定义网络层
对于自定义的网络层,需要实现初始化 __inti__
方法和前向传播逻辑call
方法
假设我们需要一个没有偏置的全连接层,即bias 为0,同时固定激活函数为ReLU 函数。
class MyDense(layers.Layer):
# 自定义网络层
def __init__(self, inp_dim, outp_dim):
super(MyDense, self).__init__()
# 创建权值张量并添加到类管理列表中,设置为需要优化
self.kernel = self.add_variable('w', [inp_dim, outp_dim],trainable=True)
net = MyDense(4,3) #创建输入为4,输出为3节点的自定义层
print(net.variables,net.trainable_variables)
通过修改为self.kernel = self.add_variable('w', [inp_dim, outp_dim], trainable=False)
,我们可以设置张量不需要被优化,此时再来观测张量的管理状态:
看出此时张量并不会被 trainable_variables管理
完成自定义类的初始化工作后,可以设计自定义类的前项运算逻辑
def call(self, inputs, training=None):
# 实现自定义类的前向计算逻辑
# X@W
out = inputs @ self.kernel
# 执行激活函数运算
out = tf.nn.relu(out)
return out
如上所示,自定义类的前向运算逻辑需要实现在call(inputs, training)
函数中,其中inputs 代表输入,由用户在调用时传入;training 参数用于指定模型的状态:training 为True 时执行训练模式,training 为False 时执行测试模式,默认参数为None,即测试模式。由于全连接层的训练模式和测试模式逻辑一致,此处不需要额外处理。对于部份测试模式和训练模式不一致的网络层,需要根据training 参数来设计需要执行的逻辑。
二、自定义网络
在完成了我们自定义的全连接层类之后,我们基于上述的“无偏置的全连接层”来实
现MNIST 手写数字图片模型的创建。
自定义的类可以和其他标准类一样,通过Sequential 容器方便地包裹成一个网络模
型:
network = keras.Sequential([MyDense(784, 256), # 使用自定义的层
MyDense(256, 128),
MyDense(128, 64),
MyDense(64, 32),
MyDense(32, 10)])
network.build(input_shape=(None, 28*28))
network.summary()
更普遍地,我们可以继承基类来实现任意逻辑的自定义网络类。下面我们来创建自定义网络类,首先创建并继承Model 基类,分布创建对应的网络层对象:
# 自定义网络类,继承自Model 基类
def __init__(self):
super(MyModel, self).__init__()
# 完成网络内需要的网络层的创建工作
self.fc1 = MyDense(28*28, 256)
self.fc2 = MyDense(256, 128)
self.fc3 = MyDense(128, 64)
self.fc4 = MyDense(64, 32)
self.fc5 = MyDense(32, 10)
#然后实现自定义网络的前向运算逻辑:
def call(self, inputs, training=None):
# 自定义前向运算逻辑
x = self.fc1(inputs)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)
x = self.fc5(x)
return x
这个例子可以直接使用第一种方式通过Sequential 容器包裹。但是由于Sequential 在前向传播是依次调用每个网络层的前向传播函数,灵活性一般,而自定义网络的前向逻辑可以任意定制,两者各有优缺点