tensorflow2.0 Keras自定义层与网络
使用keras通常的写法都是先调Sequential类(仅用于层的线性堆叠),如:
nn = Sequential([layers.Dense(128,activation='relu),
layers.Dense(64,activation='relu),
layers.Dense(10)])
有时用已有的layer并不能满足我们的需要,比如在层中加入:
x = x-1
这时就不得不加载自定义的Layer或Model,以适应各种需求,自定义层继承于keras.layers.Layer,自定义网络继承于keras.Model.
自定义层或者网络时需定义 init 方法以及 call 方法:
class MyDenseLayer(layers.Layer):
def __init__(self, input_dim, output_dim):
super(MyDenseLayer, self).__init__()
# add_variable已定义在父类中,用于添加trainable的变量
self.kernel = self.add_variable('weight', [input_dim, output_dim])
self.bias = self.add_variable('bias', [output_dim])
#这里使用inputs作为输入参数,training用于控制训练或测试
def