Tensorflow 2.1 自定义网络层的方法及其注意事项
具体定义方法在tensorflow官网 tf2.1 API 中定义的很清楚,但其中有需要注意的事项,是官方API没有写清楚的
定义方法
官网给出的定义方式:
class MyDenseLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs):
super(MyDenseLayer, self).__init__()
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_weight(