[1905.07854] KGAT: Knowledge Graph Attention Network for Recommendation (arxiv.org)
LunaBlack/KGAT-pytorch (github.com)
目录
3.2 Attentive Embedding Propagation Layers
1、背景
CF方法,基于相似用户或者相似商品属性推荐,无法利用属性等各种side information,例如u1和u2相似,则可能会推荐i2
基于特征的SL模型,例如FM/NFM/Wide&Deep可以利用side-information,i1和i2有相同属性e1则推荐i2。
但是基于特征的SL模型单独的建模每个实例,没有建模实例之间的交互,无法从集体行为提取有用信息。例如u1很难对u1很难建模
尽管e1是连接导演和演员字段的桥梁。因此,我们认为这些方法没有充分探索高阶连通性,并且没有触及组合高阶关系
为了解决基于特征的SL模型的局限性,可以将知识图与用户-项目图的混合结构称为协同知识图(CKG),但是也有挑战:1)与目标用户具有高阶关系的节点随着订单规模的增加而急剧增加,这会给模型带来计算过载;2)高阶关系对预测的贡献不平等,这需要模型仔细加权(或选择)它们。
已经有一些基于CKG模型进行推荐的方法:
(1)基于路径的方法
提取携带高阶信息的路径并将其输入到预测模型,但是第一阶段的路径选择对性能影响很大,而且定义有效元路径需要领域知识,工作过量很大。
(2)基于正则化的方法
基于正则化的方法设计额外损失项捕获KG结构,正则化推荐模型。联合训练推荐和KGC两个任务,两个任务间共享item embedding。这些方法不是直接将高阶关系插入为推荐而优化的模型中,而是以隐式的方式对它们进行编码。由于缺乏显式建模,既不能保证捕获远程连接,也不能解释高阶建模的结果。
考虑到上述局限性,作者开发一个能够以高效、显式和端到端方式利用KG中的高阶信息的模型。
2、任务定义
User-Item Bipartite Graph:在推荐场景中,通常有历史用户项目交互(例如,购买和点击),将交互数据表示为用户-物品二部图。
Knowledge Graph:除了user-item之间的交互外,还有物品的side-information(例如,物品属性和外部知识)。通常,这些辅助数据由真实世界的实体和额外的知识组成。作者将side-information组成知识图,
并有一个实体和item的对齐的集合,
Collaborative Knowledge Graph:定义了CKG的概念,它将用户行为和商品知识编码为一个统一的关系图首先将每个用户行为表示为三元组(u, interaction,i),其中y_ui = 1表示为用户u与项目i之间的附加关系interaction。然后基于entity-item对齐集,将user-item图与KG无缝集成为统一图
制定本文要解决的推荐任务:
•输入:协作知识图G,包括用户项二部图G1和知识图G2。
•输出:预测函数,预测用户u采用物品i的概率为。
3、模型
最主要的是将user-item交互也融入KG中计算
3.1 Embedding layer
使用TransR建模,首先将头实体eh和尾实体er利用由特定于关系的投影矩阵Wr投影到关系所在的空间,然后再计算三元组得分(投影的头实体+关系得到的向量,和投影的尾实体向量越相似越好,g越小越好)
损失采用对比学习方法
3.2 Attentive Embedding Propagation Layers
(1)权重计算
选择tanh作为非线性激活函数。这使得注意力得分依赖于关系r空间中eh和et之间的距离,为更接近的实体传播更多信息,为简单,只使用内积计算
接着使用softmax
最终的注意力得分能够建议应该给予哪些邻居节点更多的关注来捕获协作信号。在进行前向传播时,注意流会提示需要关注的部分数据,这可以视为推荐背后的解释。
(2)消息传递
为了表征实体h的一阶连通性结构,计算了h的自我网络的线性组合(自我网络,是h为头实体的三元组的集合)
(3)聚合
最后一个阶段是将实体自己本身的表征eh和它的自我网络表征e_Nh聚合为实体h的新表征——更正式地说,
- GCN Aggregator:将两个表征向量求和,并经过一个非线性激活函数
- GraphSage Aggregator :将两个表征向量拼接,经过一个非线性函数
- Bi-Interaction Aggregator:设计了两种函数,求和,以及两个特征向量元素积,并求和再通过一个非线性函数
(4)高阶传播
以上展示的是一阶传播和聚合的例子,很容易可以推广到高阶。
在第l步中,递归地将实体的表示表示为:
实体h在自我网络内传播的信息定义如下:
3.3 Model Prediction
执行L层后,得到用户节点u的多个表示,即;与项目节点i类似,得到。由于第l层的输出是图1中根于u(或i)的l的树结构深度的消息聚合,因此不同层的输出强调的是不同阶次的连通性信息。因此,采用层聚合机制,将每一步的表示concatence成单个向量
这样一来,不仅可以通过进行嵌入传播操作来丰富初始嵌入,还可以通过调整L来控制传播强度。
最后,对用户表征与物品表征进行内积,从而预测其匹配得分
3.4 Optimization
使用BPR损失优化推荐模型,它假设观察到的交互,这表明更多的用户偏好,应该被赋予比未观察到的更高的预测值:
(u,i)是观察到的真实的交互,(u,j)是未观察到(负样本)交互。
最后的损失函数,联合嵌入损失和推荐系统损失以及正则化
在训练时,交替优化KG嵌入损失和CF推荐损失。
4、部分代码解读
4.1 数据集
最后的CKG图由user-item交互二部图以及补充item信息的KG组成。
- train.txt/test.txt
训练数据集,由user id 和此user交互的itemID list组成。测试集和训练集中出现的交互为positive sample,没有观察到的交互作为negative sample。
- user_list.txt
由原来的user id,已经映射到CKG dataset中的id组成org_id remap_id
- item_list.txt
由原来的item id,已经映射到CKG dataset中的id,以及item在freebase中对应的id组成org_id remap_id freebase_id
- entity_list.txt
表明KG中的实体,由原来的在freebase中的entity id,已经映射到CKG dataset中的id组成org_id remap_id
- relation_list.txt
表明KG中的relation,由原来的在freebase中的relation id,已经映射到CKG dataset中的id组成org_id remap_id
4.2 数据集的处理
- 将kg添加逆关系,并对关系重新编号,做法是+2;将user-item交互图融入kg中,将user重新编号user id+实体总数,将user-item编码为0,将item-user编码为1
- 采样一个batch_size的数据,包含bath_size的user,以及为每一个user采样user-item交互的正样例,负样例
- 对产生的CKG图{h:(r,t)}进行采样生成负例正例。
loader_base.py
def load_cf(self, filename):
"""
函数说明:对user-item交互矩阵进行处理
Return:
(user, item) - user和其作用的item
user_dict - {user-id:[item1,item2,..],}
"""
user = []
item = []
user_dict = dict()
lines = open(filename, 'r').readlines()
for l in lines:
tmp = l.strip()
inter = [int(i) for i in tmp.split()]
if len(inter) > 1:
user_id, item_ids = inter[0], inter[1:]
item_ids = list(set(item_ids))
for item_id in item_ids:
user.append(user_id)
item.append(item_id)
user_dict[user_id] = item_ids
user = np.array(user, dtype=np.int32)
item = np.array(item, dtype=np.int32)
return (user, item), user_dict
def statistic_cf(self):
"""
获取user、item、训练集、测试集总数
"""
self.n_users = max(max(self.cf_train_data[0]), max(self.cf_test_data[0])) + 1
self.n_items = max(max(self.cf_train_data[1]), max(self.cf_test_data[1])) + 1
self.n_cf_train = len(self.cf_train_data[0])
self.n_cf_test = len(self.cf_test_data[0])
def load_kg(self, filename):
"""
读取最后的CKG数据,返回dataframe形式
"""
kg_data = pd.read_csv(filename, sep=' ', names=['h', 'r', 't'], engine='python')
kg_data = kg_data.drop_duplicates()
return kg_data
def sample_pos_items_for_u(self, user_dict, user_id, n_sample_pos_items):
"""
对user-item交互正样本进行采样
"""
pos_items = user_dict[user_id]
n_pos_items = len(pos_items)
sample_pos_items = []
while True:
if len(sample_pos_items) == n_sample_pos_items:
break
pos_item_idx = np.random.randint(low=0, high=n_pos_items, size=1)[0]
pos_item_id = pos_items[pos_item_idx]
if pos_item_id not in sample_pos_items:
sample_pos_items.append(pos_item_id)
return sample_pos_items
def sample_neg_items_for_u(self, user_dict, user_id, n_sample_neg_items):
"""
为user-item交互采样负样例
"""
pos_items = user_dict[user_id]
sample_neg_items = []
while True:
if len(sample_neg_items) == n_sample_neg_items:
break
neg_item_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
if neg_item_id not in pos_items and neg_item_id not in sample_neg_items:
sample_neg_items.append(neg_item_id)
return sample_neg_items
def generate_cf_batch(self, user_dict, batch_size):
"""
采样batch_size的user,并对对这些user采样正样本,负样本
"""
exist_users = user_dict.keys()
if batch_size <= len(exist_users):
batch_user = random.sample(exist_users, batch_size)
else:
batch_user = [random.choice(exist_users) for _ in range(batch_size)]
batch_pos_item, batch_neg_item = [], []
for u in batch_user:
# 为每一个采样的user生成一个正样例和一个负样例
batch_pos_item += self.sample_pos_items_for_u(user_dict, u, 1)
batch_neg_item += self.sample_neg_items_for_u(user_dict, u, 1)
batch_user = torch.LongTensor(batch_user)
batch_pos_item = torch.LongTensor(batch_pos_item)
batch_neg_item = torch.LongTensor(batch_neg_item)
return batch_user, batch_pos_item, batch_neg_item
def sample_pos_triples_for_h(self, kg_dict, head, n_sample_pos_triples):
"""
为融合user-item交互的CKG图采样正例
"""
pos_triples = kg_dict[head]
n_pos_triples = len(pos_triples)
sample_relations, sample_pos_tails = [], []
while True:
if len(sample_relations) == n_sample_pos_triples:
break
pos_triple_idx = np.random.randint(low=0, high=n_pos_triples, size=1)[0]
tail = pos_triples[pos_triple_idx][0]
relation = pos_triples[pos_triple_idx][1]
if relation not in sample_relations and tail not in sample_pos_tails:
sample_relations.append(relation)
sample_pos_tails.append(tail)
return sample_relations, sample_pos_tails
def sample_neg_triples_for_h(self, kg_dict, head, relation, n_sample_neg_triples, highest_neg_idx):
"""
为融合user-item交互的CKG图采样负例
"""
pos_triples = kg_dict[head]
sample_neg_tails = []
while True:
if len(sample_neg_tails) == n_sample_neg_triples:
break
tail = np.random.randint(low=0, high=highest_neg_idx, size=1)[0]
if (tail, relation) not in pos_triples and tail not in sample_neg_tails:
sample_neg_tails.append(tail)
return sample_neg_tails
def generate_kg_batch(self, kg_dict, batch_size, highest_neg_idx):
"""为训练集CKG中每一个头实体采样一个正例的(r,t),一个负例的t"""
exist_heads = kg_dict.keys()
if batch_size <= len(exist_heads):
batch_head = random.sample(exist_heads, batch_size)
else:
batch_head = [random.choice(exist_heads) for _ in range(batch_size)]
batch_relation, batch_pos_tail, batch_neg_tail = [], [], []
for h in batch_head:
relation, pos_tail = self.sample_pos_triples_for_h(kg_dict, h, 1)
batch_relation += relation
batch_pos_tail += pos_tail
neg_tail = self.sample_neg_triples_for_h(kg_dict, h, relation[0], 1, highest_neg_idx)
batch_neg_tail += neg_tail
batch_head = torch.LongTensor(batch_head)
batch_relation = torch.LongTensor(batch_relation)
batch_pos_tail = torch.LongTensor(batch_pos_tail)
batch_neg_tail = torch.LongTensor(batch_neg_tail)
return batch_head, batch_relation, batch_pos_tail, batch_neg_tail
loader_kgat.py
def construct_data(self, kg_data):
"""
函数说明:创建逆边,并把user-item交互图融入,创建CKG
"""
# add inverse kg data
n_relations = max(kg_data['r']) + 1
inverse_kg_data = kg_data.copy()
inverse_kg_data = inverse_kg_data.rename({'h': 't', 't': 'h'}, axis='columns')
inverse_kg_data['r'] += n_relations
kg_data = pd.concat([kg_data, inverse_kg_data], axis=0, ignore_index=True, sort=False)
# re-map user id
kg_data['r'] += 2
self.n_relations = max(kg_data['r']) + 1
self.n_entities = max(max(kg_data['h']), max(kg_data['t'])) + 1
self.n_users_entities = self.n_users + self.n_entities
# re-map user id = user-item中的id + num_entities
self.cf_train_data = (np.array(list(map(lambda d: d + self.n_entities, self.cf_train_data[0]))).astype(np.int32), self.cf_train_data[1].astype(np.int32))
self.cf_test_data = (np.array(list(map(lambda d: d + self.n_entities, self.cf_test_data[0]))).astype(np.int32), self.cf_test_data[1].astype(np.int32))
self.train_user_dict = {k + self.n_entities: np.unique(v).astype(np.int32) for k, v in self.train_user_dict.items()}
self.test_user_dict = {k + self.n_entities: np.unique(v).astype(np.int32) for k, v in self.test_user_dict.items()}
# add interactions to kg data
# 将user-item交互数据融入kg中user交互item的关系编码为0,item-user交互编码为1
cf2kg_train_data = pd.DataFrame(np.zeros((self.n_cf_train, 3), dtype=np.int32), columns=['h', 'r', 't'])
cf2kg_train_data['h'] = self.cf_train_data[0]
cf2kg_train_data['t'] = self.cf_train_data[1]
inverse_cf2kg_train_data = pd.DataFrame(np.ones((self.n_cf_train, 3), dtype=np.int32), columns=['h', 'r', 't'])
inverse_cf2kg_train_data['h'] = self.cf_train_data[1]
inverse_cf2kg_train_data['t'] = self.cf_train_data[0]
self.kg_train_data = pd.concat([kg_data, cf2kg_train_data, inverse_cf2kg_train_data], ignore_index=True)
self.n_kg_train = len(self.kg_train_data)
# construct kg dict
h_list = []
t_list = []
r_list = []
self.train_kg_dict = collections.defaultdict(list)
self.train_relation_dict = collections.defaultdict(list)
for row in self.kg_train_data.iterrows():
h, r, t = row[1]
h_list.append(h)
t_list.append(t)
r_list.append(r)
self.train_kg_dict[h].append((t, r))
self.train_relation_dict[r].append((h, t))
self.h_list = torch.LongTensor(h_list)
self.t_list = torch.LongTensor(t_list)
self.r_list = torch.LongTensor(r_list)
4.3 模型
- 权重的计算:内积计算相似性,越相似的尾实体,则应该传递更多消息,权重应该更大
- 消息的传递和聚合
聚合ego-netework嵌入的加权和以及自身嵌入
- 损失函数:包含CF的损失和KGC的损失,以及参数的正则化部分
- 预测
将多层的消息传递聚合结果拼接起来,然后进行内积运算,得到用户点击某物品的概率
KGAT.py
import torch
import torch.nn as nn
import torch.nn.functional as F
def _L2_loss_mean(x):
return torch.mean(torch.sum(torch.pow(x, 2), dim=1, keepdim=False) / 2.)
class Aggregator(nn.Module):
def __init__(self, in_dim, out_dim, dropout, aggregator_type):
super(Aggregator, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.dropout = dropout
self.aggregator_type = aggregator_type
self.message_dropout = nn.Dropout(dropout)
self.activation = nn.LeakyReLU()
if self.aggregator_type == 'gcn':
self.linear = nn.Linear(self.in_dim, self.out_dim) # W in Equation (6)
nn.init.xavier_uniform_(self.linear.weight)
elif self.aggregator_type == 'graphsage':
self.linear = nn.Linear(self.in_dim * 2, self.out_dim) # W in Equation (7)
nn.init.xavier_uniform_(self.linear.weight)
elif self.aggregator_type == 'bi-interaction':
self.linear1 = nn.Linear(self.in_dim, self.out_dim) # W1 in Equation (8)
self.linear2 = nn.Linear(self.in_dim, self.out_dim) # W2 in Equation (8)
nn.init.xavier_uniform_(self.linear1.weight)
nn.init.xavier_uniform_(self.linear2.weight)
else:
raise NotImplementedError
def forward(self, ego_embeddings, A_in):
"""
ego_embeddings: (n_users + n_entities, in_dim)
A_in: (n_users + n_entities, n_users + n_entities), torch.sparse.FloatTensor
"""
# Equation (3)
side_embeddings = torch.matmul(A_in, ego_embeddings)
if self.aggregator_type == 'gcn':
# Equation (6) & (9)
embeddings = ego_embeddings + side_embeddings
embeddings = self.activation(self.linear(embeddings))
elif self.aggregator_type == 'graphsage':
# Equation (7) & (9)
embeddings = torch.cat([ego_embeddings, side_embeddings], dim=1)
embeddings = self.activation(self.linear(embeddings))
elif self.aggregator_type == 'bi-interaction':
# Equation (8) & (9)
sum_embeddings = self.activation(self.linear1(ego_embeddings + side_embeddings))
bi_embeddings = self.activation(self.linear2(ego_embeddings * side_embeddings))
embeddings = bi_embeddings + sum_embeddings
embeddings = self.message_dropout(embeddings) # (n_users + n_entities, out_dim)
return embeddings
class KGAT(nn.Module):
def __init__(self, args,
n_users, n_entities, n_relations, A_in=None,
user_pre_embed=None, item_pre_embed=None):
super(KGAT, self).__init__()
self.use_pretrain = args.use_pretrain
self.n_users = n_users
self.n_entities = n_entities
self.n_relations = n_relations
self.embed_dim = args.embed_dim
self.relation_dim = args.relation_dim
self.aggregation_type = args.aggregation_type
self.conv_dim_list = [args.embed_dim] + eval(args.conv_dim_list)
self.mess_dropout = eval(args.mess_dropout)
self.n_layers = len(eval(args.conv_dim_list))
self.kg_l2loss_lambda = args.kg_l2loss_lambda
self.cf_l2loss_lambda = args.cf_l2loss_lambda
self.entity_user_embed = nn.Embedding(self.n_entities + self.n_users, self.embed_dim)
self.relation_embed = nn.Embedding(self.n_relations, self.relation_dim)
self.trans_M = nn.Parameter(torch.Tensor(self.n_relations, self.embed_dim, self.relation_dim))
if (self.use_pretrain == 1) and (user_pre_embed is not None) and (item_pre_embed is not None):
other_entity_embed = nn.Parameter(torch.Tensor(self.n_entities - item_pre_embed.shape[0], self.embed_dim))
nn.init.xavier_uniform_(other_entity_embed)
entity_user_embed = torch.cat([item_pre_embed, other_entity_embed, user_pre_embed], dim=0)
self.entity_user_embed.weight = nn.Parameter(entity_user_embed)
else:
nn.init.xavier_uniform_(self.entity_user_embed.weight)
nn.init.xavier_uniform_(self.relation_embed.weight)
nn.init.xavier_uniform_(self.trans_M)
self.aggregator_layers = nn.ModuleList()
for k in range(self.n_layers):
self.aggregator_layers.append(Aggregator(self.conv_dim_list[k], self.conv_dim_list[k + 1], self.mess_dropout[k], self.aggregation_type))
# A是邻接矩阵
self.A_in = nn.Parameter(torch.sparse.FloatTensor(self.n_users + self.n_entities, self.n_users + self.n_entities))
if A_in is not None:
self.A_in.data = A_in
self.A_in.requires_grad = False
def calc_cf_embeddings(self):
"""
计算多层的消息传递和聚合
"""
ego_embed = self.entity_user_embed.weight
all_embed = [ego_embed]
for idx, layer in enumerate(self.aggregator_layers):
ego_embed = layer(ego_embed, self.A_in)
norm_embed = F.normalize(ego_embed, p=2, dim=1)
all_embed.append(norm_embed)
# Equation (11)
all_embed = torch.cat(all_embed, dim=1) # (n_users + n_entities, concat_dim)
return all_embed
def calc_cf_loss(self, user_ids, item_pos_ids, item_neg_ids):
"""
user_ids: (cf_batch_size)
item_pos_ids: (cf_batch_size)
item_neg_ids: (cf_batch_size)
"""
all_embed = self.calc_cf_embeddings() # (n_users + n_entities, concat_dim)
user_embed = all_embed[user_ids] # (cf_batch_size, concat_dim)
item_pos_embed = all_embed[item_pos_ids] # (cf_batch_size, concat_dim)
item_neg_embed = all_embed[item_neg_ids] # (cf_batch_size, concat_dim)
# Equation (12)
pos_score = torch.sum(user_embed * item_pos_embed, dim=1) # (cf_batch_size)
neg_score = torch.sum(user_embed * item_neg_embed, dim=1) # (cf_batch_size)
# Equation (13)
# cf_loss = F.softplus(neg_score - pos_score)
cf_loss = (-1.0) * F.logsigmoid(pos_score - neg_score)
cf_loss = torch.mean(cf_loss)
l2_loss = _L2_loss_mean(user_embed) + _L2_loss_mean(item_pos_embed) + _L2_loss_mean(item_neg_embed)
loss = cf_loss + self.cf_l2loss_lambda * l2_loss
return loss
def calc_kg_loss(self, h, r, pos_t, neg_t):
"""
h: (kg_batch_size)
r: (kg_batch_size)
pos_t: (kg_batch_size)
neg_t: (kg_batch_size)
"""
r_embed = self.relation_embed(r) # (kg_batch_size, relation_dim)
W_r = self.trans_M[r] # (kg_batch_size, embed_dim, relation_dim)
h_embed = self.entity_user_embed(h) # (kg_batch_size, embed_dim)
pos_t_embed = self.entity_user_embed(pos_t) # (kg_batch_size, embed_dim)
neg_t_embed = self.entity_user_embed(neg_t) # (kg_batch_size, embed_dim)
r_mul_h = torch.bmm(h_embed.unsqueeze(1), W_r).squeeze(1) # (kg_batch_size, relation_dim)
r_mul_pos_t = torch.bmm(pos_t_embed.unsqueeze(1), W_r).squeeze(1) # (kg_batch_size, relation_dim)
r_mul_neg_t = torch.bmm(neg_t_embed.unsqueeze(1), W_r).squeeze(1) # (kg_batch_size, relation_dim)
# Equation (1)
pos_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_pos_t, 2), dim=1) # (kg_batch_size)
neg_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_neg_t, 2), dim=1) # (kg_batch_size)
# Equation (2)
# kg_loss = F.softplus(pos_score - neg_score)
kg_loss = (-1.0) * F.logsigmoid(neg_score - pos_score)
kg_loss = torch.mean(kg_loss)
l2_loss = _L2_loss_mean(r_mul_h) + _L2_loss_mean(r_embed) + _L2_loss_mean(r_mul_pos_t) + _L2_loss_mean(r_mul_neg_t)
loss = kg_loss + self.kg_l2loss_lambda * l2_loss
return loss
def update_attention_batch(self, h_list, t_list, r_idx):
"""
更新注意力权重
"""
r_embed = self.relation_embed.weight[r_idx]
W_r = self.trans_M[r_idx]
h_embed = self.entity_user_embed.weight[h_list]
t_embed = self.entity_user_embed.weight[t_list]
# Equation (4)
r_mul_h = torch.matmul(h_embed, W_r)
r_mul_t = torch.matmul(t_embed, W_r)
v_list = torch.sum(r_mul_t * torch.tanh(r_mul_h + r_embed), dim=1)
return v_list
def update_attention(self, h_list, t_list, r_list, relations):
device = self.A_in.device
rows = []
cols = []
values = []
for r_idx in relations:
index_list = torch.where(r_list == r_idx)
batch_h_list = h_list[index_list]
batch_t_list = t_list[index_list]
batch_v_list = self.update_attention_batch(batch_h_list, batch_t_list, r_idx)
rows.append(batch_h_list)
cols.append(batch_t_list)
values.append(batch_v_list)
rows = torch.cat(rows)
cols = torch.cat(cols)
values = torch.cat(values)
indices = torch.stack([rows, cols])
shape = self.A_in.shape
A_in = torch.sparse.FloatTensor(indices, values, torch.Size(shape))
# Equation (5)
A_in = torch.sparse.softmax(A_in.cpu(), dim=1)
self.A_in.data = A_in.to(device)
def calc_score(self, user_ids, item_ids):
"""
user_ids: (n_users)
item_ids: (n_items)
计算user点击item的得分
"""
all_embed = self.calc_cf_embeddings() # (n_users + n_entities, concat_dim)
user_embed = all_embed[user_ids] # (n_users, concat_dim)
item_embed = all_embed[item_ids] # (n_items, concat_dim)
# Equation (12)
cf_score = torch.matmul(user_embed, item_embed.transpose(0, 1)) # (n_users, n_items)
return cf_score
def forward(self, *input, mode):
if mode == 'train_cf':
return self.calc_cf_loss(*input)
if mode == 'train_kg':
return self.calc_kg_loss(*input)
if mode == 'update_att':
return self.update_attention(*input)
if mode == 'predict':
return self.calc_score(*input)
4.4 模型训练
主要包括交替训练CF与KGC两个任务,并在每次交替训练后更新消息传递的权重。
main_kgat.py
def train(args):
# seed
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
log_save_id = create_log_id(args.save_dir)
logging_config(folder=args.save_dir, name='log{:d}'.format(log_save_id), no_console=False)
logging.info(args)
# GPU / CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load data
data = DataLoaderKGAT(args, logging)
if args.use_pretrain == 1:
user_pre_embed = torch.tensor(data.user_pre_embed)
item_pre_embed = torch.tensor(data.item_pre_embed)
else:
user_pre_embed, item_pre_embed = None, None
# construct model & optimizer
model = KGAT(args, data.n_users, data.n_entities, data.n_relations, data.A_in, user_pre_embed, item_pre_embed)
if args.use_pretrain == 2:
model = load_model(model, args.pretrain_model_path)
model.to(device)
logging.info(model)
cf_optimizer = optim.Adam(model.parameters(), lr=args.lr)
kg_optimizer = optim.Adam(model.parameters(), lr=args.lr)
# initialize metrics
best_epoch = -1
best_recall = 0
Ks = eval(args.Ks)
k_min = min(Ks)
k_max = max(Ks)
epoch_list = []
metrics_list = {k: {'precision': [], 'recall': [], 'ndcg': []} for k in Ks}
# train model
for epoch in range(1, args.n_epoch + 1):
time0 = time()
model.train()
# train cf
time1 = time()
cf_total_loss = 0
n_cf_batch = data.n_cf_train // data.cf_batch_size + 1
# 交替训练CF与KGC
for iter in range(1, n_cf_batch + 1):
time2 = time()
# 采样一个cf_batch_size的user list,并为user list中的每一个user采样一个正样例和负样例。
cf_batch_user, cf_batch_pos_item, cf_batch_neg_item = data.generate_cf_batch(data.train_user_dict, data.cf_batch_size)
cf_batch_user = cf_batch_user.to(device)
cf_batch_pos_item = cf_batch_pos_item.to(device)
cf_batch_neg_item = cf_batch_neg_item.to(device)
cf_batch_loss = model(cf_batch_user, cf_batch_pos_item, cf_batch_neg_item, mode='train_cf')
if np.isnan(cf_batch_loss.cpu().detach().numpy()):
logging.info('ERROR (CF Training): Epoch {:04d} Iter {:04d} / {:04d} Loss is nan.'.format(epoch, iter, n_cf_batch))
sys.exit()
cf_batch_loss.backward()
cf_optimizer.step()
cf_optimizer.zero_grad()
cf_total_loss += cf_batch_loss.item()
if (iter % args.cf_print_every) == 0:
logging.info('CF Training: Epoch {:04d} Iter {:04d} / {:04d} | Time {:.1f}s | Iter Loss {:.4f} | Iter Mean Loss {:.4f}'.format(epoch, iter, n_cf_batch, time() - time2, cf_batch_loss.item(), cf_total_loss / iter))
logging.info('CF Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s | Iter Mean Loss {:.4f}'.format(epoch, n_cf_batch, time() - time1, cf_total_loss / n_cf_batch))
# train kg
time3 = time()
kg_total_loss = 0
n_kg_batch = data.n_kg_train // data.kg_batch_size + 1
for iter in range(1, n_kg_batch + 1):
time4 = time()
kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail = data.generate_kg_batch(data.train_kg_dict, data.kg_batch_size, data.n_users_entities)
kg_batch_head = kg_batch_head.to(device)
kg_batch_relation = kg_batch_relation.to(device)
kg_batch_pos_tail = kg_batch_pos_tail.to(device)
kg_batch_neg_tail = kg_batch_neg_tail.to(device)
kg_batch_loss = model(kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail, mode='train_kg')
if np.isnan(kg_batch_loss.cpu().detach().numpy()):
logging.info('ERROR (KG Training): Epoch {:04d} Iter {:04d} / {:04d} Loss is nan.'.format(epoch, iter, n_kg_batch))
sys.exit()
kg_batch_loss.backward()
kg_optimizer.step()
kg_optimizer.zero_grad()
kg_total_loss += kg_batch_loss.item()
if (iter % args.kg_print_every) == 0:
logging.info('KG Training: Epoch {:04d} Iter {:04d} / {:04d} | Time {:.1f}s | Iter Loss {:.4f} | Iter Mean Loss {:.4f}'.format(epoch, iter, n_kg_batch, time() - time4, kg_batch_loss.item(), kg_total_loss / iter))
logging.info('KG Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s | Iter Mean Loss {:.4f}'.format(epoch, n_kg_batch, time() - time3, kg_total_loss / n_kg_batch))
# 交替训练完一次更新注意力权重
# update attention
time5 = time()
# h_list/t_list/r_list是CKG图中所有的头实体、关系、尾实体列表
h_list = data.h_list.to(device)
t_list = data.t_list.to(device)
r_list = data.r_list.to(device)
relations = list(data.laplacian_dict.keys())
model(h_list, t_list, r_list, relations, mode='update_att')
logging.info('Update Attention: Epoch {:04d} | Total Time {:.1f}s'.format(epoch, time() - time5))
logging.info('CF + KG Training: Epoch {:04d} | Total Time {:.1f}s'.format(epoch, time() - time0))
# evaluate cf
if (epoch % args.evaluate_every) == 0 or epoch == args.n_epoch:
time6 = time()
_, metrics_dict = evaluate(model, data, Ks, device)
logging.info('CF Evaluation: Epoch {:04d} | Total Time {:.1f}s | Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
epoch, time() - time6, metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))
epoch_list.append(epoch)
for k in Ks:
for m in ['precision', 'recall', 'ndcg']:
metrics_list[k][m].append(metrics_dict[k][m])
best_recall, should_stop = early_stopping(metrics_list[k_min]['recall'], args.stopping_steps)
if should_stop:
break
if metrics_list[k_min]['recall'].index(best_recall) == len(epoch_list) - 1:
save_model(model, args.save_dir, epoch, best_epoch)
logging.info('Save model on epoch {:04d}!'.format(epoch))
best_epoch = epoch
# save metrics
metrics_df = [epoch_list]
metrics_cols = ['epoch_idx']
for k in Ks:
for m in ['precision', 'recall', 'ndcg']:
metrics_df.append(metrics_list[k][m])
metrics_cols.append('{}@{}'.format(m, k))
metrics_df = pd.DataFrame(metrics_df).transpose()
metrics_df.columns = metrics_cols
metrics_df.to_csv(args.save_dir + '/metrics.tsv', sep='\t', index=False)
# print best metrics
best_metrics = metrics_df.loc[metrics_df['epoch_idx'] == best_epoch].iloc[0].to_dict()
logging.info('Best CF Evaluation: Epoch {:04d} | Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
int(best_metrics['epoch_idx']), best_metrics['precision@{}'.format(k_min)], best_metrics['precision@{}'.format(k_max)], best_metrics['recall@{}'.format(k_min)], best_metrics['recall@{}'.format(k_max)], best_metrics['ndcg@{}'.format(k_min)], best_metrics['ndcg@{}'.format(k_max)]))