- 本质就是矩阵相乘 Amn * Bnp
- 这里会提取输入矩阵最后一维的 dim，比如说是 Amn 的 n
import keras
import tensorflow as tf
class Linear(keras.layers.Layer):
    """A minimal dense (fully connected) layer: ``output = inputs @ w + b``.

    Weights are created lazily in ``build`` once the input shape is known,
    following the standard Keras custom-layer pattern.
    """

    def __init__(self, units=32):
        super(Linear, self).__init__()
        # Number of output features: the "p" in A(m, n) @ B(n, p).
        self.units = units

    def build(self, input_shape):
        # The forward pass is a matrix product A(m, n) @ B(n, p), so the
        # kernel's first dimension must equal the input's last dimension
        # (the shared "n"), which input_shape[-1] extracts.
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        # Bias: one scalar per output feature, broadcast over the batch.
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        """Apply the affine transform ``inputs @ w + b``."""
        return tf.matmul(inputs, self.w) + self.b
# Smoke test: push a 2x2 batch of ones through a freshly built layer and
# print the resulting (2, 6) tensor.
x = tf.ones(shape=(2, 2))
linear_layer = Linear(units=6)
y = linear_layer(x)
print(y)
(2, 2)
tf.Tensor(
[[-0.0910622 -0.04033005 -0.0540841 -0.06019955 0.05445318 0.07133652]
[-0.0910622 -0.04033005 -0.0540841 -0.06019955 0.05445318 0.07133652]], shape=(2, 6), dtype=float32)