# coding=utf-8
import itertools
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.layers import Embedding, Dense, Dropout, Input
from tensorflow.keras.layers import Layer
class AFM_Layer(Layer):
    """Attentional Factorization Machine (AFM) pooling layer.

    Looks up an embedding for each sparse feature id, forms the
    element-wise product of every unordered pair of feature embeddings
    (f*(f-1)/2 pairs for f features), and pools the pair interactions by
    plain sum, mean, or an attention-weighted sum.
    """

    def __init__(self, feature_size, embedding_dim, mode="sum", name=None,
                 att_vector=8, activation='relu', dropout=0.8):
        """
        Args:
            feature_size: vocabulary size of the embedding table.
            embedding_dim: dimension of each feature embedding.
            mode: pooling mode — one of "sum", "avg", "att".
            name: layer name; also used to name the embedding weight.
            att_vector: hidden units of the attention MLP (mode="att" only).
            activation: activation of the attention MLP hidden layer.
            dropout: dropout rate. NOTE(review): 0.8 drops 80% of units —
                confirm this is intended and not a keep-probability.

        Raises:
            ValueError: if mode is not "sum", "avg", or "att".
        """
        # Bug fix: the original dropped `name`, so the Keras layer name
        # argument was silently ignored.
        super(AFM_Layer, self).__init__(name=name)
        # Bug fix: the original returned None from call() for an unknown
        # mode; fail fast at construction instead.
        if mode not in ("sum", "avg", "att"):
            raise ValueError(
                "mode must be one of 'sum', 'avg', 'att', got {!r}".format(mode))
        self.embedding_weight = \
            self.add_weight(shape=(feature_size, embedding_dim),
                            initializer=tf.initializers.glorot_normal(),
                            name="{}/embedding_weight".format(name))
        self.mode = mode
        if self.mode == "att":
            # Attention MLP: project each pair interaction to att_vector
            # dims, then reduce it to one scalar score per pair.
            self.attention_W = Dense(att_vector, activation=activation, use_bias=True)
            self.attention_dense = Dense(1, activation=None)
        self.dropout = Dropout(dropout)  # NOTE: currently unused in call()
        self.dense = Dense(1, activation=None)  # NOTE: currently unused in call()

    def call(self, inputs, **kwargs):
        """Compute the pooled pairwise interactions.

        Args:
            inputs: int tensor of feature ids, shape (batch_size, feature_size).

        Returns:
            Tensor of shape (batch_size, embedding_dim).
        """
        # Bug fix: the original read len(inputs[1]) — the length of the
        # *second batch row* — as the feature count, which crashes for
        # batch_size < 2 and does not work in graph mode. The static
        # trailing dimension is the feature count.
        num_features = inputs.shape[-1]
        # (batch_size, feature_size, embedding_dim)
        emb_res = tf.nn.embedding_lookup(self.embedding_weight, inputs)
        # Index lists selecting every unordered feature pair, in order.
        row = []
        col = []
        for r, c in itertools.combinations(range(num_features), 2):
            row.append(r)
            col.append(c)
        p = tf.gather(emb_res, row, axis=1)  # (batch, f*(f-1)/2, k)
        q = tf.gather(emb_res, col, axis=1)  # (batch, f*(f-1)/2, k)
        bi_interaction = p * q  # element-wise pair interactions
        if self.mode == "sum":
            return tf.reduce_sum(bi_interaction, axis=1)   # (batch, k)
        if self.mode == "avg":
            return tf.reduce_mean(bi_interaction, axis=1)  # (batch, k)
        # mode == "att": score each pair, softmax-normalize the scores
        # over the pair axis, then take the weighted sum of interactions.
        a = self.attention_W(bi_interaction)   # (batch, pairs, att_vector)
        a = self.attention_dense(a)            # (batch, pairs, 1)
        a_score = tf.nn.softmax(a, axis=1)     # normalize over pairs
        return tf.reduce_sum(bi_interaction * a_score, axis=1)  # (batch, k)
if __name__ == '__main__':
    # Smoke test: batch of 4 samples, 5 sparse feature ids each.
    inputs = tf.constant([[1, 2, 3, 5, 6],
                          [8, 2, 3, 68, 11],
                          [20, 2, 2, 3, 6],
                          [39, 2, 99, 23, 34]], dtype=tf.int32)
    print("inputs", inputs)
    # Bug fix: the original passed mode="att" to __call__, where it was
    # silently swallowed by call(**kwargs) — the layer actually ran with
    # its construction-time default mode="sum". The mode must be chosen
    # when the layer is built.
    afm_layer = AFM_Layer(100, 10, mode="att", name="AFM_layer")
    res = afm_layer(inputs)
    print("res:", res)
# (Scrape artifact, commented out so the file parses — blog footer:
#  "latest recommended article published 2022-09-14 17:31:14")