Tensorflow2自定义层和激活函数
网上大多数都是这种方法:https://blog.csdn.net/weixin_39875161/article/details/104678867
在网上搜了许多,都没找到我想要的自定义方法,所以自己摸索着写了个,也不知道对不对,反正参数是会更新。在此做个记录,以免以后忘记了。
带参数的自定义激活函数
import tensorflow as tf

from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export
class TestRelu(Layer):
    """ReLU-like activation with a learnable scale: ``max(alpha * x, 0)``.

    ``alpha`` is a trainable scalar created lazily in ``build``, so it is
    updated by the optimizer like any other model parameter.
    """

    def __init__(self):
        # Bug fix: the original called super(FlexRelu, self), which raises
        # NameError — this class is named TestRelu.
        super(TestRelu, self).__init__()

    def build(self, input_shape):
        # Trainable scalar, initialized from N(0, 0.5). Kept as a plain
        # tf.Variable (not add_weight) to preserve the original variable
        # name 'TRelu_alpha' that the usage code filters on.
        self.alpha = tf.Variable(
            tf.random.normal(shape=(1,), mean=0, stddev=0.5),
            dtype=tf.float32,
            name='TRelu_alpha')

    def call(self, inputs, **kwargs):
        # Bug fix: removed @keras_export and @dispatch.add_dispatch_support.
        # Those decorators register top-level TensorFlow API *functions*;
        # applying them to an instance method is incorrect. @tf.function on
        # call is also unnecessary — Keras traces call itself.
        x = inputs * self.alpha
        return tf.math.maximum(x, 0)
alpha是参数
带参数的自定义层
import tensorflow as tf
from tensorflow.keras import layers
class NewDropout(layers.Layer):
    """Shifted-ReLU layer with a learnable threshold: ``max(x - alpha, 0)``.

    NOTE(review): despite the name, this is not dropout — nothing is
    randomly zeroed. Values below the learned threshold ``alpha`` are
    clipped to zero.
    """

    def __init__(self):
        super(NewDropout, self).__init__()

    def build(self, input_shape):
        # Trainable scalar threshold, initialized from N(0, 0.5).
        self.alpha = tf.Variable(
            tf.random.normal(shape=(1,), mean=0, stddev=0.5),
            dtype=tf.float32,
            name='Dropout_alpha')

    def call(self, inputs, **kwargs):
        # Bug fix: removed the per-call debug print. Under tf.function it
        # fires only at trace time (misleading), and otherwise it spams
        # every forward pass; inspect model.trainable_variables instead.
        x = inputs - self.alpha
        return tf.where(x > 0, x, 0)
使用
# Demo script: train a small model containing the custom layers on random
# data, then inspect the learned alpha before and after training.
import numpy as np
from tensorflow.keras import utils
from tensorflow.keras.layers import Activation
from tensorflow.keras.models import Sequential

model = Sequential([
    # Bug fix: input_shape=(None, 100) declared 3-D samples, but the
    # training data below is (1000, 100) — each sample is a flat 100-vector.
    layers.Dense(32, input_shape=(100,)),
    TestRelu(),
    layers.Dense(32),
    NewDropout(),
    TestRelu(),
    layers.Dense(10),
    Activation('softmax'),
])
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Show alpha before training. NOTE(review): the exact variable name
# ('test_relu/TRelu_alpha:0') depends on Keras name scoping — confirm
# against model.trainable_variables if the filter matches nothing.
for param in model.trainable_variables:
    if param.name == 'test_relu/TRelu_alpha:0':
        # Bug fix: numpy is a method — the original printed the bound
        # method object instead of the value.
        print('11111', param.name, param.shape, param.numpy())

data = np.random.random((1000, 100))
labels = np.random.randint(10, size=(1000, 1))
one_hot_labels = utils.to_categorical(labels, num_classes=10)
model.fit(data, one_hot_labels, epochs=400, batch_size=32)

# Show alpha after training — it should have changed if gradients flow.
for param in model.trainable_variables:
    if param.name == 'test_relu/TRelu_alpha:0':
        print('111110000', param.name, param.shape, param.numpy())
现在还没看懂@tf.function, @keras_export('keras.backend.Trelu'), @dispatch.add_dispatch_support到底起什么作用,以后再看吧。
如果有人看到请留个言,告诉下,不胜感激!