#coding=utf-8
import tensorflow as tf
from tensorflow.keras.layers import Dense, ReLU, Layer
class DeepCrossLater(Layer):
    """Cross-network layer applied to an embedded integer-id input.

    The ids are looked up in a trainable embedding table, then ``cross_num``
    cross steps are applied.  With ``x0`` the embedded input as a column
    vector, each step computes::

        x_{l+1} = (x0 @ x_l^T) @ w_l + b_l + x0

    ``call`` evaluates this recurrence with three equivalent tensor layouts,
    prints each result for comparison, and returns the last one.
    """

    def __init__(self, dim_stack, name=None, cross_num=3, vocab_size=10):
        """
        :param dim_stack: A scalar. Width of the embedding and of every cross
            weight / bias vector.
        :param name: Optional layer name; also used as the prefix of the
            cross weight and bias variable names.
        :param cross_num: A scalar. Number of cross steps (default 3).
        :param vocab_size: A scalar. Number of rows in the embedding table
            (default 10).
        """
        super(DeepCrossLater, self).__init__(name=name)
        self.relu = ReLU()  # kept as a public attribute; not used by call()
        self.cross_num = cross_num

        # The embedding width must equal dim_stack, otherwise the
        # (dim_stack, 1) cross weights cannot be multiplied in call().
        self.onehot_embedding = self.add_weight(
            shape=(vocab_size, dim_stack),
            initializer=tf.initializers.glorot_normal())

        # NOTE(review): recent Keras releases may reject '/' in variable
        # names -- confirm against the installed version before upgrading.
        self.cross_weight = [
            self.add_weight(shape=(dim_stack, 1),
                            initializer=tf.initializers.glorot_normal(),
                            name="{}/cross_net_{}_weight".format(name, i))
            for i in range(cross_num)
        ]
        # Attribute name (including its spelling) is kept for compatibility.
        self.bais_weight = [
            self.add_weight(shape=(dim_stack, 1),
                            initializer=tf.initializers.glorot_normal(),
                            name="{}/cross_net_{}_bias".format(name, i))
            for i in range(cross_num)
        ]

    def call(self, inputs, **kwargs):
        """Embed ``inputs`` and run the cross steps.

        :param inputs: Integer id tensor.  The rank-3 transposes below imply a
            rank-2 input; the demo uses shape (batch, 1) -- TODO confirm no
            other shapes are expected by callers.
        :return: Result of the third formulation, laid out as a row vector
            per sample (same layout as the embedded input).
        """
        x0_row = tf.nn.embedding_lookup(self.onehot_embedding, inputs)
        # Column layout of x0; loop-invariant, so computed once.
        x0_col = tf.transpose(x0_row, [0, 2, 1])
        steps = list(zip(self.cross_weight, self.bais_weight))

        # ---- Formulation one: keep xl as a row vector between steps.
        xl = x0_row
        for weight, bias in steps:
            # x0^T @ xl: column times row gives the square outer product.
            xl = tf.matmul(x0_row, xl, transpose_a=True)
            xl = tf.matmul(xl, weight)  # back to a column vector
            xl = xl + bias + x0_col
            xl = tf.transpose(xl, [0, 2, 1])  # row layout for the next step
        print("method1 xl:", xl)

        # ---- Formulation two: keep xl as a column vector between steps and
        # let matmul transpose both operands; transpose once at the end.
        xl = x0_col
        for weight, bias in steps:
            xl = tf.matmul(x0_row, xl, transpose_a=True, transpose_b=True)
            xl = tf.matmul(xl, weight)
            xl = xl + bias + x0_col
        xl = tf.transpose(xl, [0, 2, 1])
        print("method2 xl:", xl)

        # ---- Formulation three: pre-transposed x0, no transpose flags.
        xl = x0_row
        for weight, bias in steps:
            xl = tf.matmul(x0_col, xl)  # outer product
            xl = tf.matmul(xl, weight)  # column vector
            xl = xl + x0_col + bias
            xl = tf.transpose(xl, [0, 2, 1])  # row layout for the next step
        print("method3 xl:", xl)
        return xl
if __name__ == '__main__':
    # Smoke test: two samples with one integer id each, i.e. shape (2, 1).
    # Named sample_ids rather than `input` to avoid shadowing the builtin.
    sample_ids = tf.constant([[1],
                              [2]], dtype=tf.int32)
    print("input:", sample_ids.shape)
    ru_ins = DeepCrossLater(5, "CrossNet")
    res = ru_ins(sample_ids)
    print("res:", res)
# deep_cross_layer
# (Blog page footer) Latest recommended article published 2023-03-03 17:05:32