整体内容与前文类似；FM 部分二阶交叉项的数学原理可参考以下文章：
简单理解FM公式 - 知乎 (zhihu.com): https://zhuanlan.zhihu.com/p/354994307
class FM(nn.Module):
    """Factorization Machine scoring layer.

    Computes, for every row ``x`` of the input,

        y = w0 + sum_i w1[i] * x_i + sum_{i<j} <w2[i], w2[j]> * x_i * x_j

    where the pairwise term is evaluated in O(n * k) with the identity

        sum_{i<j} <v_i, v_j> x_i x_j
            = 0.5 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2 ].

    Args:
        num_embeddings: number of input columns ``n`` (width of ``x``).
        embedding_dim: latent factor size ``k``.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Parameters are created in the order w0, w1, w2 so that seeded
        # initialisation consumes the RNG identically.
        self.w0 = nn.Parameter(torch.zeros(1))  # global bias
        # NOTE(review): torch.rand samples U[0, 1), so all weights start
        # non-negative; a small zero-mean init is more common for FMs.
        self.w1 = nn.Parameter(torch.rand(num_embeddings, 1))  # linear weights
        self.w2 = nn.Parameter(torch.rand(num_embeddings, embedding_dim))  # latent factors

    def forward(self, x):
        """Score a 2-D input ``x`` of shape (batch, num_embeddings).

        Returns a tensor of shape (batch, 1).
        """
        linear_term = torch.mm(x, self.w1)
        # (sum_i v_i x_i)^2, per latent dimension
        square_of_sum = torch.mm(x, self.w2).pow(2)
        # sum_i v_i^2 x_i^2, per latent dimension
        sum_of_square = torch.mm(x.pow(2), self.w2.pow(2))
        pairwise_term = 0.5 * (square_of_sum - sum_of_square).sum(dim=1, keepdim=True)
        return self.w0 + linear_term + pairwise_term
class DNN(nn.Module):
    """Multilayer perceptron with a ReLU after every linear layer.

    Args:
        hidden_units: sequence of layer widths ``[d0, d1, ..., dm]``; each
            consecutive pair ``(d_i, d_{i+1})`` becomes one ``nn.Linear``.
            A single-element sequence yields no layers (identity forward).
    """

    def __init__(self, hidden_units):
        super().__init__()
        layer_shapes = zip(hidden_units, hidden_units[1:])
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out, bias=True) for n_in, n_out in layer_shapes
        )

    def forward(self, x):
        """Apply each linear layer followed by ReLU, in order."""
        hidden = x
        for linear in self.layers:
            hidden = F.relu(linear(hidden))
        return hidden
class DeepFM(nn.Module):
    """DeepFM model: an FM layer and an MLP applied to the same input vector.

    The input vector is the concatenation of one embedding per sparse feature
    followed by the raw dense features. The FM output and the projected MLP
    output are summed and passed through a sigmoid.

    Args:
        features_info: 3-tuple ``(dense_features, sparse_features,
            sparse_features_nunique)``. The first two are sequences of feature
            names. The third maps each sparse feature name to the number of
            embedding rows (vocabulary size) for that feature.
        hidden_units: widths of the DNN hidden layers. The input width is
            prepended internally; the caller's list is NOT modified.
        embedding_dim: embedding size of every sparse feature; also used as
            the FM latent-factor size.
    """

    def __init__(self, features_info, hidden_units, embedding_dim):
        super().__init__()
        # Unpack feature metadata.
        self.dense_features, self.sparse_features, self.sparse_features_nunique = features_info
        # Counts of dense and sparse features.
        self.__dense_features_nums = len(self.dense_features)
        self.__sparse_features_nums = len(self.sparse_features)
        # One embedding table per sparse feature, keyed "embed_<name>".
        self.embeddings = nn.ModuleDict({
            "embed_" + key: nn.Embedding(num_embeds, embedding_dim)
            for key, num_embeds in self.sparse_features_nunique.items()
        })
        # Width of the concatenated [embeddings | dense] input vector.
        stack_dim = self.__dense_features_nums + self.__sparse_features_nums * embedding_dim
        # Build a fresh list rather than calling hidden_units.insert(0, ...):
        # mutating the caller's list made a second DeepFM built from the same
        # list receive an extra (or mis-sized) input layer.
        dnn_units = [stack_dim, *hidden_units]
        self.fm = FM(stack_dim, embedding_dim)
        self.dnn = DNN(dnn_units)
        self.dnn_last_linear = nn.Linear(dnn_units[-1], 1, bias=False)

    def forward(self, x):
        """Return sigmoid(FM logit + DNN logit) with shape (batch, 1).

        ``x`` is a 2-D tensor. Its first ``len(dense_features)`` columns are
        dense values; the remaining columns hold one id per sparse feature, in
        ``sparse_features`` order. The ids are cast to ``long`` before the
        embedding lookup.
        """
        n_dense = self.__dense_features_nums
        dense_inputs, sparse_inputs = x[:, :n_dense], x[:, n_dense:]
        sparse_inputs = sparse_inputs.long()
        embedding_feas = torch.cat(
            [
                self.embeddings["embed_" + key](sparse_inputs[:, idx])
                for idx, key in enumerate(self.sparse_features)
            ],
            dim=-1,
        )
        input_feas = torch.cat([embedding_feas, dense_inputs], dim=-1)
        fm = self.fm(input_feas)
        dnn = self.dnn_last_linear(self.dnn(input_feas))
        # torch.sigmoid is the preferred spelling of the deprecated F.sigmoid.
        return torch.sigmoid(fm + dnn)