Graph Neural Network Recommendation in Practice (with Python Code): Reproducing a CIKM 2020 Paper

1. Overview

I recently read a CIKM 2020 paper on POI recommendation, Learning Graph-Based Geographical Latent Representation for Point-of-Interest Recommendation, and wanted to see how it works in practice. The authors did not release an implementation, so I attempted to reproduce the paper myself; the code is shared here for reference.

Original paper: https://dl.acm.org/doi/pdf/10.1145/3340531.3411905

Parts of the implementation also draw on the LightGCN source code; a link to that code can be found in the LightGCN paper:

LightGCN: https://dl.acm.org/doi/pdf/10.1145/3397271.3401063

In the code, the Loader class handles data loading, the GGLR class implements the outgoing/ingoing propagation over the POI-POI graph, the User class implements the user-POI graph part, and GPR ties GGLR and User together. Because the model updates the full POI-POI graph at every step, it is memory-hungry; a machine or server with plenty of RAM is recommended.
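For orientation, the layer-wise propagation that GGLR implements (my notation, read directly off the forward pass in the code below rather than the paper's exact symbols) is

$$p^{(k+1)} = \mathrm{ReLU}\left(\tilde{A}_{out}\, p^{(k)} W_{out} + b_1\right), \qquad q^{(k+1)} = \mathrm{ReLU}\left(\tilde{A}_{in}\, q^{(k)} W_{in} + b_2\right)$$

where $p^{(k)}$ and $q^{(k)}$ are the outgoing and ingoing POI embeddings at layer $k$, and $\tilde{A}_{out} = D_{out}^{-1}A$, $\tilde{A}_{in} = D_{in}^{-1}A^{\top}$ are the row-normalized POI-POI transition graphs built in Loader.get_poi_graph.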

This post does not discuss the model itself in more depth; interested readers can consult the paper. My abilities are limited, so if you spot mistakes in the reproduction, corrections are very welcome.

2. Code

The complete code is below. Since the data cannot be attached here, both code and data are available in the GitHub repository: GitHub - fanglin1/GPR_POI_Recommendation: (Unofficial implement)A graph method for point-of-interest recommendation

One detail deviates slightly from the paper: the frequency loss uses MAE rather than MSE. The geographical influence computed earlier involves an exp function, which makes the loss jump to inf very easily, and squaring the error under MSE makes this worse. If you want to use MSE, you can put an upper and lower bound on the geographical influence; in my code the bound is applied to the influence term right before computing the predicted frequency (the freq variable). Feel free to experiment with either.
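To make the overflow concrete, here is a tiny standalone sketch (toy values for a, b, c, which are learned parameters in the real model; not part of the training script):

import torch

d = torch.tensor([1.0, 10.0, 100.0])           # pairwise distances in km
a, b, c = 1.0, 1.0, 1.0                        # toy parameters
effect = a * d.pow(b) * torch.exp(c * d)       # exp(100) overflows float32
print(effect)                                  # tensor([2.7183e+00, 2.2026e+05, inf])
print(torch.clamp(effect, min=-1e9, max=1e9))  # the bound used in GPR.forward

Squaring such values for MSE overflows even earlier, which is why MAE is the safer default here.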

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from scipy.sparse import csr_matrix
from time import time
import scipy.sparse as sp
import torch
from tqdm import tqdm
import torch.nn as nn
import os


class Loader():
    def __init__(self, dataset_name='Gowalla', path='./dataset/', save_path='./data_save/distance_graph'):

        path = path + dataset_name + "/"
        save_path = save_path + dataset_name + '.npy'
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        size_path = path + dataset_name + '_data_size.txt'
        self.check_in_file = path + dataset_name + '_checkins.txt'
        self.poi_geo_file = path + dataset_name + '_poi_coos.txt'
        train_file = path + dataset_name + '_train.txt'
        test_file = path + dataset_name + '_test.txt'
        val_file = path + dataset_name + '_tune.txt'
        self.save_path = save_path
        trainUniqueUsers, trainItem, trainUser = [], [], []
        testUniqueUsers, testItem, testUser = [], [], []
        valUniqueUsers, valItem, valUser = [], [], []
        with open(size_path) as f:
            size_lis = f.readline().strip('\n').split('\t')
            self.user_num = int(size_lis[0])
            self.poi_num = int(size_lis[1])
        # read the train/test/validation splits
        with open(train_file) as f:
            for l in f.readlines():
                if len(l) > 0:
                    l = l.strip('\n').split('\t')
                    uid = int(l[0])
                    poi = int(l[1])
                    freq = int(l[2])
                    if uid not in trainUniqueUsers:
                        trainUniqueUsers.append(uid)
                    trainUser.extend([uid] * int(freq))
                    trainItem.extend([poi] * int(freq))
        self.trainUniqueUsers = np.array(trainUniqueUsers)
        self.trainUser = np.array(trainUser)
        self.trainItem = np.array(trainItem)
        self.trainSize = len(self.trainItem)

        with open(test_file) as f:
            for l in f.readlines():
                if len(l) > 0:
                    l = l.strip('\n').split('\t')
                    uid = int(l[0])
                    poi = int(l[1])
                    freq = int(l[2])
                    if uid not in testUniqueUsers:
                        testUniqueUsers.append(uid)
                    testUser.extend([uid] * int(freq))
                    testItem.extend([poi] * int(freq))
        self.testUniqueUsers = np.array(testUniqueUsers)
        self.testUser = np.array(testUser)
        self.testItem = np.array(testItem)

        with open(val_file) as f:
            for l in f.readlines():
                if len(l) > 0:
                    l = l.strip('\n').split('\t')
                    uid = int(l[0])
                    poi = int(l[1])
                    freq = int(l[2])
                    if uid not in valUniqueUsers:
                        valUniqueUsers.append(uid)
                    valUser.extend([uid] * int(freq))
                    valItem.extend([poi] * int(freq))
        self.valUniqueUsers = np.array(valUniqueUsers)
        self.valUser = np.array(valUser)
        self.valItem = np.array(valItem)

        print('Data loaded; building graphs')
        # user-POI matrix; duplicate (user, poi) pairs are summed by csr_matrix,
        # so entries equal the visit frequency, not just 1
        self.UserPoiNet = csr_matrix((np.ones(len(self.trainUser)), (self.trainUser, self.trainItem)),
                                     shape=(self.user_num, self.poi_num))

        self.sparse_user_poi_graph = self.get_graph().to(self.device)
        print('User-POI graph built; building train/test/validation sets')

        # list of arrays: element i holds the POIs visited by user i
        self.user_pois = self.getUserPois(list(range(self.user_num)))
        self.test_data = self.build_test(testItem, testUser)
        self.val_data = self.build_test(valItem, valUser)
        print('Building directed POI graphs and the POI distance matrix')
        train_check_in = self.split_data()

        self.directed_poi_graph, self.directed_in_graph, self.target_graph = self.get_poi_graph(train_check_in)
        if os.path.exists(save_path):
            is_saved = True
        else:
            # make sure the save directory exists before np.save is called later
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            is_saved = False
        if is_saved:
            print('Loading cached distance matrix')
            self.distance_graph = torch.FloatTensor(np.load(self.save_path)).to(self.device)
        else:
            print('Computing distance matrix')
            self.distance_graph = self.calculate_poi_distance(saved=is_saved)
        print('Data initialization complete')

    # get the POIs each user has visited in the training set
    def getUserPois(self, users):
        posItems = []
        for user in users:
            # print(user,self.UserItemNet[user].shape)
            posItems.append(self.UserPoiNet[user].nonzero()[1])
        return posItems

    # build the test/validation dict: user -> list of ground-truth POIs
    def build_test(self, poi_data, user_data):

        test_data = {}
        for i, poi in enumerate(poi_data):
            user = user_data[i]
            if test_data.get(user):
                test_data[user].append(poi)
            else:
                test_data[user] = [poi]
        return test_data

    # build the undirected user-POI graph as a sparse tensor
    def get_graph(self):
        adj_mat = sp.dok_matrix((self.user_num, self.poi_num), dtype=np.float32)
        adj_mat = adj_mat.tolil()
        R = self.UserPoiNet.tolil()
        adj_mat[:self.user_num, :self.poi_num] = R

        adj_mat = adj_mat.tocsr()

        Graph = self._convert_sp_mat_to_sp_tensor(adj_mat)

        return Graph

    # convert a scipy sparse matrix to a torch sparse tensor
    def _convert_sp_mat_to_sp_tensor(self, X):
        coo = X.tocoo().astype(np.float32)
        row = torch.Tensor(coo.row).long()
        col = torch.Tensor(coo.col).long()
        index = torch.stack([row, col])
        data = torch.FloatTensor(coo.data)
        return torch.sparse.FloatTensor(index, data, torch.Size(coo.shape))

    # vectorized haversine distance between two sets of (lat, lon) points in radians
    def haversine_distances_fast(self, X, Y):

        dlon = Y[:, 1] - X[:, 1][:, np.newaxis]
        dlat = Y[:, 0] - X[:, 0][:, np.newaxis]
        a = np.sin(dlat / 2) ** 2 + np.cos(X[:, 0])[:, np.newaxis] * np.cos(Y[:, 0]) * np.sin(dlon / 2) ** 2
        c = 2 * np.arcsin(np.sqrt(a))
        km = 6367 * c
        return km
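    # note: both inputs must already be in radians, with rows ordered as (lat, lon);
    # quick toy check (not in the original script): identical points give 0 km,
    # e.g. haversine_distances_fast(np.zeros((1, 2)), np.zeros((1, 2))) -> [[0.]]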

    # compute the POI-POI distance matrix (loaded from disk by __init__ if already saved)
    def calculate_poi_distance(self, saved=False):
        geo = pd.read_csv(self.poi_geo_file, header=None, sep='\t',
                          names=['poi_id', 'lat', 'lon']).sort_values(
            by='poi_id', ascending=True)[['lat', 'lon']].astype(np.float32)

        # all coordinates in radians, one (lat, lon) row per POI
        coords = np.radians(geo[['lat', 'lon']].values)
        dist_matrix = self.haversine_distances_fast(coords, coords)
        # clip to [0.01, 100] km so zero/huge distances cannot blow up the
        # geographical influence term later
        dist_matrix = np.clip(dist_matrix, 0.01, 100)

        if not saved:
            np.save(self.save_path, dist_matrix)
        return torch.FloatTensor(dist_matrix).to(self.device)

    # split the raw check-in data chronologically per user
    def split_data(self):
        data_df1 = pd.read_csv(self.check_in_file, sep='\t', header=None, names=['user_id', 'poi_id', 'time'])
        data_df1['time'] = pd.to_datetime(data_df1['time'], unit='s')
        users = data_df1['user_id'].unique()
        total_train = []

        for user in tqdm(users):
            user_df = data_df1[data_df1['user_id'] == user].sort_values(['time'], ascending=[True]).reset_index(
                drop=True)
            user_data_num = user_df.shape[0]
            # the first 70% of each user's check-ins (in time order) form the training portion
            train_df = user_df[:int(user_data_num * 0.7)]
            total_train.append(train_df)
        total_train = pd.concat(total_train, axis=0).reset_index(drop=True)
        return total_train

    # build the directed POI-POI transition graphs from consecutive check-ins
    def get_poi_graph(self, data_train):
        graph = np.zeros((self.poi_num, self.poi_num))
        users = data_train['user_id'].unique()
        for user in tqdm(users):
            user_data = data_train.loc[data_train['user_id'] == user].sort_values('time', ascending=True).reset_index(
                drop=True)
            index = user_data.index
            for i in range(1, len(index)):
                # only consecutive check-ins less than one day apart form an edge
                time_interval = user_data.loc[index[i], 'time'] - user_data.loc[index[i - 1], 'time']
                if time_interval.days < 1:
                    poi1 = user_data.loc[index[i - 1], 'poi_id']
                    poi2 = user_data.loc[index[i], 'poi_id']
                    if poi1 != poi2:
                        graph[poi1, poi2] += 1
        # out-degrees, used for row normalization D^-1 A; zeros replaced with 1 to avoid division by zero
        D1 = graph.sum(axis=1).reshape(-1, 1)
        D1[D1 == 0] = 1
        print('graph shape', graph.shape, 'degree shape', D1.shape)
        # the in-going graph is the transpose of the out-going graph
        in_graph = graph.T
        D2 = in_graph.sum(axis=1).reshape(-1, 1)
        D2[D2 == 0] = 1
        norm_graph = self._convert_sp_mat_to_sp_tensor(csr_matrix(graph / D1)).coalesce().to(self.device)
        norm_in_graph = self._convert_sp_mat_to_sp_tensor(csr_matrix(in_graph / D2)).coalesce().to(self.device)

        return norm_graph, norm_in_graph, torch.FloatTensor(graph).to(self.device)


loader = Loader()

import torch.nn.functional as F


class GGLR(nn.Module):
    def __init__(self, d, out_graph, in_graph):
        super(GGLR, self).__init__()
        self.out_weight = nn.Parameter(torch.FloatTensor(d, d))
        self.in_weight = nn.Parameter(torch.FloatTensor(d, d))
        self.bias1 = nn.Parameter(torch.FloatTensor(d))
        self.bias2 = nn.Parameter(torch.FloatTensor(d))
        self.out_g = out_graph
        self.in_g = in_graph

    def forward(self, x1, x2):
        # one propagation step: sparse graph @ (embeddings @ weight) + bias, then ReLU
        k_emb_outgoing = torch.sparse.mm(self.out_g, torch.mm(x1, self.out_weight)) + self.bias1
        k_emb_ingoing = torch.sparse.mm(self.in_g, torch.mm(x2, self.in_weight)) + self.bias2
        return F.relu(k_emb_outgoing), F.relu(k_emb_ingoing)
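
# Shape sanity check for one GGLR layer (hypothetical toy sizes; note the weight
# parameters stay uninitialized until the xavier init at the bottom of the script,
# so only the shapes are meaningful here):
# g = torch.eye(4).to_sparse()
# layer = GGLR(8, g, g)
# out_e, in_e = layer(torch.randn(4, 8), torch.randn(4, 8))  # both of shape (4, 8)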


class User(nn.Module):
    def __init__(self, dim, user_poi):
        super(User, self).__init__()
        self.user_weight = nn.Parameter(torch.FloatTensor(dim, dim))
        self.poi_weight = nn.Parameter(torch.FloatTensor(dim, dim))
        self.bias = nn.Parameter(torch.FloatTensor(dim))
        self.adj = user_poi

    def forward(self, all_out_going_embs, user_embs, selected_u):
        # gather the rows of the user-POI graph for the users in this batch
        selected_users_adj = self.adj.index_select(0, selected_u)
        poi_message = torch.sparse.mm(selected_users_adj, torch.mm(all_out_going_embs, self.poi_weight))
        user_message = torch.mm(user_embs, self.user_weight)
        user_update_embs = poi_message + user_message + self.bias
        return F.relu(user_update_embs)


class GPR(nn.Module):
    def __init__(self, data, latent_dim, layer_num):
        super(GPR, self).__init__()
        self.user_num = data.user_num
        self.poi_num = data.poi_num
        self.dim = latent_dim
        self.layer_num = layer_num
        self.adj = data.sparse_user_poi_graph.to(loader.device)
        self.target_adj = data.target_graph

        # out-going and in-going POI transition graphs
        self.out_going_adj = data.directed_poi_graph
        self.in_going_adj = data.directed_in_graph
        # POI propagation layers (GGLR)
        self.poi_module = nn.ModuleList(
            [GGLR(self.dim, self.out_going_adj, self.in_going_adj) for _ in range(self.layer_num)])
        # user update layers
        self.user_module = nn.ModuleList([User(self.dim, self.adj) for _ in range(self.layer_num)])

        # user and POI embeddings
        self.user_embedding = nn.Embedding(num_embeddings=self.user_num, embedding_dim=self.dim)
        self.ingoing_poi_embedding = nn.Embedding(num_embeddings=self.poi_num, embedding_dim=self.dim)
        self.outgoing_poi_embedding = nn.Embedding(num_embeddings=self.poi_num, embedding_dim=self.dim)

        # learnable parameters of the geographical influence f(d) = a * d^b * exp(c * d)
        self.a = nn.Parameter(torch.FloatTensor(1))
        self.b = nn.Parameter(torch.FloatTensor(1))
        self.c = nn.Parameter(torch.FloatTensor(1))
        self.distance_graph = data.distance_graph
        self.distance_weight = nn.Parameter(torch.FloatTensor(self.dim, self.dim))

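    # pairwise BPR ranking loss: softplus(neg - pos) equals -log(sigmoid(pos - neg)),
    # the standard BPR objective; reg_loss is the L2 penalty on the batch embeddings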
    def bpr_loss(self, users_emb, pos_emb, neg_emb):

        reg_loss = (1 / 2) * (users_emb.norm(2).pow(2) +
                              pos_emb.norm(2).pow(2) +
                              neg_emb.norm(2).pow(2)) / float(len(users_emb))
        pos_scores = torch.mul(users_emb, pos_emb)
        pos_scores = torch.sum(pos_scores, dim=1)
        neg_scores = torch.mul(users_emb, neg_emb)
        neg_scores = torch.sum(neg_scores, dim=1)

        loss = torch.mean(torch.nn.functional.softplus(neg_scores - pos_scores))

        return loss, reg_loss

    def mean_square_loss(self, pre, labels):
        # despite the name this uses MAE (L1Loss) instead of MSE; with the exp-based
        # geographical term, squared errors overflow to inf far too easily (see Section 2)
        loss_fun = torch.nn.L1Loss()
        return loss_fun(pre, labels)

    def test_rating(self, users):
        all_user_embs = self.user_embedding.weight
        all_ingoing_poi_embs = self.ingoing_poi_embedding.weight
        all_outgoing_poi_embs = self.outgoing_poi_embedding.weight
        selected_users = all_user_embs[users]
        all_layer_user_embs = []
        all_layer_in_embs = []
        for i in range(self.layer_num):
            all_outgoing_poi_embs, all_ingoing_poi_embs = self.poi_module[i](all_outgoing_poi_embs,
                                                                             all_ingoing_poi_embs)
            selected_users = self.user_module[i](all_outgoing_poi_embs, selected_users, users)
            all_layer_user_embs.append(selected_users)
            all_layer_in_embs.append(all_ingoing_poi_embs)

        concat_user = torch.concat(all_layer_user_embs, 1)

        concat_in = torch.concat(all_layer_in_embs, 1)
        user_scores = torch.matmul(concat_user, concat_in.T)
        return user_scores

    def forward(self, users, pos_poi, neg_poi):
        all_user_embs = self.user_embedding.weight
        all_ingoing_poi_embs = self.ingoing_poi_embedding.weight
        all_outgoing_poi_embs = self.outgoing_poi_embedding.weight
        selected_users = all_user_embs[users]

        all_layer_user_embs = []
        all_layer_out_embs = []
        all_layer_in_embs = []
        for i in range(self.layer_num):
            all_outgoing_poi_embs, all_ingoing_poi_embs = self.poi_module[i](all_outgoing_poi_embs,
                                                                             all_ingoing_poi_embs)
            selected_users = self.user_module[i](all_outgoing_poi_embs, selected_users, users)
            all_layer_user_embs.append(selected_users)
            all_layer_out_embs.append(all_outgoing_poi_embs)
            all_layer_in_embs.append(all_ingoing_poi_embs)

        # geographical influence f(d) = a * d^b * exp(c * d), clamped so the exp
        # term cannot overflow to inf
        physical_geography_effect = self.a * (self.distance_graph.pow(self.b)) * torch.exp(self.c * self.distance_graph)
        physical_geography_effect = torch.clamp(physical_geography_effect, min=-1e9, max=1e9)
        # predicted visit frequency between every pair of POIs
        freq = physical_geography_effect * (
            torch.mm(torch.mm(all_outgoing_poi_embs, self.distance_weight), all_ingoing_poi_embs.T))

        # frequency reconstruction loss (MAE, see mean_square_loss) against the observed transition counts
        mse_loss = self.mean_square_loss(freq, self.target_adj)
        # concatenate the embeddings from all layers
        concat_user = torch.concat(all_layer_user_embs, 1)
        concat_out = torch.concat(all_layer_out_embs, 1)
        concat_in = torch.concat(all_layer_in_embs, 1)
        # embeddings of the sampled positive and negative POIs
        concat_in_pos = concat_in[pos_poi]
        concat_in_neg = concat_in[neg_poi]
        bpr_loss, bpr_reg_loss = self.bpr_loss(concat_user, concat_in_pos, concat_in_neg)

        # user-POI preference scores
        user_scores = torch.matmul(concat_user, concat_in.T)

        return user_scores, mse_loss, bpr_loss, bpr_reg_loss

from torch import optim
import torch

def get_pos_neg_data(dataset):
    all_data = dataset.user_pois
    users = np.random.randint(0, dataset.user_num, dataset.trainSize)
    samples = []
    for user in users:
        user_pois = all_data[user]
        if len(user_pois) == 0:
            continue
        poi_index = np.random.randint(0, len(user_pois))
        # positive sample: a POI the user actually visited
        pos_poi_id = user_pois[poi_index]
        # negative sample: keep drawing until we hit a POI the user never visited
        while True:
            neg_poi_id = np.random.randint(0, dataset.poi_num)
            if neg_poi_id in user_pois:
                continue
            else:
                break
        samples.append((user, pos_poi_id, neg_poi_id))
    return np.array(samples)
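# Each row of the returned array is (user_id, positive_poi_id, negative_poi_id);
# negatives are drawn uniformly until one misses the user's visited set (standard BPR sampling).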

def shuffle_data(*args):
    if len(set(len(x) for x in args)) != 1:
        raise ValueError('input arrays have inconsistent lengths')
    # use a random permutation (np.arange would keep the original order and not shuffle at all)
    shuffle_indices = np.random.permutation(len(args[0]))
    return tuple(x[shuffle_indices] for x in args)

def split_batch(*args, **kwargs):
    batch_size = kwargs['batch_size']
    # at test time only the user list needs batching
    if len(args) == 1:
        tensor = args[0]
        for i in range(0, len(tensor), batch_size):
            yield tensor[i:i + batch_size]
    # at train time the (user, positive, negative) triples are batched together
    else:
        for i in range(0, len(args[0]), batch_size):
            yield tuple(x[i:i + batch_size] for x in args)

# standard recommendation metrics

def getLabel(test_data, pred_data):
    r = []
    for i in range(len(test_data)):
        groundTrue = test_data[i]
        predictTopK = pred_data[i]
        pred = list(map(lambda x: x in groundTrue, predictTopK))
        pred = np.array(pred).astype("float")
        r.append(pred)
    return np.array(r).astype('float')

# precision and recall at k
def RecallPrecision_ATk(ground, r, k):
    right_pred = r[:, :k].sum(1)
    precis_n = k
    recall_n = np.array([len(ground[i]) for i in range(len(ground))])
    recall = np.sum(right_pred/recall_n)
    precis = np.sum(right_pred)/precis_n
    return {'recall': recall, 'precision': precis}

# NDCG at k
def NDCGatK_r(test_data,r,k):
    """
    Normalized Discounted Cumulative Gain
    rel_i = 1 or 0, so 2^{rel_i} - 1 = 1 or 0
    """
    assert len(r) == len(test_data)
    pred_data = r[:, :k]

    test_matrix = np.zeros((len(pred_data), k))
    for i, items in enumerate(test_data):
        length = k if k <= len(items) else len(items)
        test_matrix[i, :length] = 1
    max_r = test_matrix
    idcg = np.sum(max_r * 1./np.log2(np.arange(2, k + 2)), axis=1)
    dcg = pred_data*(1./np.log2(np.arange(2, k + 2)))
    dcg = np.sum(dcg, axis=1)
    idcg[idcg == 0.] = 1.
    ndcg = dcg/idcg
    ndcg[np.isnan(ndcg)] = 0.
    return np.sum(ndcg)
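# Worked example (toy numbers, not from the script): with k = 3, hit vector r = [1, 0, 1]
# and 2 relevant items, DCG = 1/log2(2) + 1/log2(4) = 1.5 and
# IDCG = 1/log2(2) + 1/log2(3) ≈ 1.63, giving NDCG ≈ 0.92.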

def test_one_batch(X):
    K_num = [20]
    sorted_items = X[0].numpy()
    groundTrue = X[1]
    r = getLabel(groundTrue, sorted_items)
    pre, recall, ndcg = [], [], []
    for k in K_num:
        # only k items are ranked, so the cutoff must be k (not k + 1)
        res = RecallPrecision_ATk(groundTrue, r, k)
        pre.append(res['precision'])
        recall.append(res['recall'])
        ndcg.append(NDCGatK_r(groundTrue, r, k))
    return {'recall': np.array(recall),
            'precision': np.array(pre),
            'ndcg': np.array(ndcg)}

def test_performance(dataset, rec_model, b_size, topk):
    results = {'precision': np.zeros(len(topk)),
               'recall': np.zeros(len(topk)),
               'ndcg': np.zeros(len(topk))}
    user_data = dataset.test_data
    rec_model.eval()
    with torch.no_grad():
        users_id = list(user_data.keys())
        user_lis = []
        rating_lis = []
        ground_truth = []
        # ceil division so the assert below also holds when the user count divides evenly
        total_batch = (len(users_id) + b_size - 1) // b_size
        for batch_users in tqdm(split_batch(users_id, batch_size=b_size)):
            groundTrue = [user_data[u] for u in batch_users]
            batch_users_tensor = torch.Tensor(batch_users).long().to(loader.device)

            u_rating = rec_model.test_rating(batch_users_tensor)
            _,rating_k = torch.topk(u_rating, k=max(topk))

            del u_rating
            user_lis.append(batch_users)
            rating_lis.append(rating_k.cpu())
            ground_truth.append(groundTrue)
        assert total_batch == len(user_lis)
        X = zip(rating_lis, ground_truth)

        pre_results = []
        for x in X:
            pre_results.append(test_one_batch(x))

        for result in pre_results:
            results['recall'] += result['recall']
            results['precision'] += result['precision']
            results['ndcg'] += result['ndcg']
        results['recall'] /= float(len(users_id))
        results['precision'] /= float(len(users_id))
        results['ndcg'] /= float(len(users_id))

        print(results)
        return results

# start training
epochs = 50
bpr_batch_size = 2048
test_batch_size = 256
learning_rate = 0.001
layer_num = 3
latent_dim = 32
regular1 = 0.2
regular2 = 1e-4
max_K = [20]
GPR_net = GPR(loader,latent_dim,layer_num)
GPR_net.to(loader.device)
for p in GPR_net.parameters():
    # xavier init for weight matrices; uniform init for biases and the scalars a, b, c
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
    else:
        nn.init.uniform_(p)

optimizer = optim.Adam(GPR_net.parameters(), lr=learning_rate)
for epoch in range(epochs):
    s = get_pos_neg_data(loader)
    user = torch.tensor(s[:, 0]).long().to(loader.device)
    pos = torch.tensor(s[:, 1]).long().to(loader.device)
    neg = torch.tensor(s[:, 2]).long().to(loader.device)
    loss_ = 0
    shuffle_user, shuffle_pos, shuffle_neg = shuffle_data(user, pos, neg)
    print('start test')
    per_result = test_performance(loader, GPR_net, test_batch_size, max_K)
    # switch back to training mode after evaluation
    GPR_net.train()
    print('start train epoch:', epoch)
    batch_num = user.shape[0] // bpr_batch_size + 1
    for batch_index, (batch_u, batch_pos, batch_neg) in enumerate(
            split_batch(shuffle_user, shuffle_pos, shuffle_neg, batch_size=bpr_batch_size)):
        optimizer.zero_grad()
        u_scores, batch_mse_loss, batch_bpr_loss, batch_bpr_reg_loss = GPR_net(batch_u, batch_pos, batch_neg)
        total_loss = batch_mse_loss + regular1 * batch_bpr_reg_loss + regular2 * batch_bpr_loss
        # accumulate with .item() so the running total does not keep the autograd graph alive
        loss_ += total_loss.item()
        total_loss.backward()
        optimizer.step()
        if batch_index % 10 == 0:
            print(batch_index + 1, "/", batch_num, batch_mse_loss.item(), batch_bpr_loss.item(), batch_bpr_reg_loss.item())

    print(epoch, 'epoch loss:', loss_)
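
If you want to keep the trained model between runs, a minimal sketch (not part of the original script; the file path is arbitrary):

torch.save(GPR_net.state_dict(), './data_save/GPR_net.pt')
# later: GPR_net.load_state_dict(torch.load('./data_save/GPR_net.pt'))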
