这是DataWhale组队学习任务三的内容,本次学习的模型是DeepFM。
前面两次学习了DeepCrossing和Wide&Deep模型,这两个深度模型中DeepCrossing是只使用了深度模型,而Wide&Deep则是将深度模型和线性模型结合起来,让模型同时拥有了泛化和记忆能力。Wide&Deep模型的wide部分通过人工判断特征组合进行输入,必需十分熟悉业务才能很好的做出判断。在模型的输出部分直接将wide部分的低阶特征和deep部分学习的高阶特征组合,这样并不能很好的结合高阶特征和低阶特征。
DeepFM
DeepFM模型在Wide&Deep上做出了进一步的改进,主要是将Wide&Deep的wide部分替换成了FM。FM模型利用两个特征embedding做内积,在传统的线性模型的基础上有了更大的改进,可以学习到高阶的特征。而模型另一边还是deep部分。模型整体结构图如下所示:
可以看到,模型的FM部分和deep部分都是用到了embedding,最下面是数值型稀疏特征的onehot向量,随后对onehot向量做embedding,稀疏embedding中的每一个小块对应的是一个onehot向量。FM和deep部分的输出最后通过一个sigmoid函数得到预测结果。
模型实现
def DeepFM(linear_feature_columns, dnn_feature_columns):
# 构建输入层,即所有特征对应的Input()层,这里使用字典的形式返回,方便后续构建模型
dense_input_dict, sparse_input_dict = build_input_layers(linear_feature_columns + dnn_feature_columns)
# 将linear部分的特征中sparse特征筛选出来,后面用来做1维的embedding
linear_sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), linear_feature_columns))
# 构建模型的输入层,模型的输入层不能是字典的形式,应该将字典的形式转换成列表的形式
# 注意:这里实际的输入与Input()层的对应,是通过模型输入时候的字典数据的key与对应name的Input层
input_layers = list(dense_input_dict.values()) + list(sparse_input_dict.values())
# linear_logits由两部分组成,分别是dense特征的logits和sparse特征的logits
linear_logits = get_linear_logits(dense_input_dict, sparse_input_dict, linear_sparse_feature_columns)
# 构建维度为k的embedding层,这里使用字典的形式返回,方便后面搭建模型
# embedding层用户构建FM交叉部分和DNN的输入部分
embedding_layers = build_embedding_layers(dnn_feature_columns, sparse_input_dict, is_linear=False)
# 将输入到dnn中的所有sparse特征筛选出来
dnn_sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns))
fm_logits = get_fm_logits(sparse_input_dict, dnn_sparse_feature_columns, embedding_layers) # 只考虑二阶项
# 将所有的Embedding都拼起来,一起输入到dnn中
dnn_logits = get_dnn_logits(sparse_input_dict, dnn_sparse_feature_columns, embedding_layers)
# 将linear,FM,dnn的logits相加作为最终的logits
output_logits = Add()([linear_logits, fm_logits, dnn_logits])
# 这里的激活函数使用sigmoid
output_layers = Activation("sigmoid")(output_logits)
model = Model(input_layers, output_layers)
return model
代码实现如上所示。
可以看到,相比于Wide&Deep模型,多了一个FM部分,FM部分的输入是embedding。
代码运行结果如下:
参考资料:
https://github.com/datawhalechina/team-learning-rs/blob/master/DeepRecommendationModel/DeepFM.md
https://github.com/datawhalechina/team-learning-rs/blob/master/RecommendationSystemFundamentals/04%20FM.md