RIPGeo代码理解(四)model.py( RIPGeo的核心源代码)

本文详细解析RIPGeo的model.py,该模型基于PyTorch实现,涉及图神经网络(GNN)的步骤,包括特征处理、邻接矩阵构建、属性相似度计算等,用于学习地标与目标之间的关系。
摘要由CSDN通过智能技术生成

 代码链接:RIPGeo代码实现

├── lib # 包含模型(model)实现文件
    │        |── layers.py # 注意力机制的代码。
    │        |── model.py # TrustGeo的核心源代码。
    │        |── sublayers.py # layer.py的支持文件。
    │        |── utils.py # 辅助函数。

一、导入各种模块和神经网络类

from .layers import *
import torch
import torch.nn as nn

这段代码是一个 Python 模块,包含了一些导入语句和定义了一个神经网络模型的类。

from .layers import *:导入了当前模块所在目录中的 layers 模块中的所有内容。* 表示导入所有的内容。

二、RIPGeo类定义(NN模型)

class RIPGeo(nn.Module):
    def __init__(self, dim_in, dim_z, dim_med, dim_out, collaborative_mlp=True):
        super(RIPGeo, self).__init__()

        # RIPGeo
        self.att_attribute = SimpleAttention1(temperature=dim_z ** 0.5,
                                             d_q_in=dim_in,
                                             d_k_in=dim_in,
                                             d_v_in=dim_in + 2,
                                             d_q_out=dim_z,
                                             d_k_out=dim_z,
                                             d_v_out=dim_z)

        if collaborative_mlp:
            self.pred = SimpleAttention2(temperature=dim_z ** 0.5,
                                        d_q_in=dim_in * 3 + 4,
                                        d_k_in=dim_in,
                                        d_v_in=2,
                                        d_q_out=dim_z,
                                        d_k_out=dim_z,
                                        d_v_out=2,
                                        drop_last_layer=False)

        else:
            self.pred = nn.Sequential(
                nn.Linear(dim_z, dim_med),
                nn.ReLU(),
                nn.Linear(dim_med, dim_out)
            )

        self.collaborative_mlp = collaborative_mlp

        # calculate A
        self.gamma_1 = nn.Parameter(torch.ones(1, 1))
        self.gamma_2 = nn.Parameter(torch.ones(1, 1))
        self.gamma_3 = nn.Parameter(torch.ones(1, 1))
        self.alpha = nn.Parameter(torch.ones(1, 1))
        self.beta = nn.Parameter(torch.zeros(1, 1))

        # transform in Graph
        self.w_1 = nn.Linear(dim_in + 2, dim_in + 2)
        self.w_2 = nn.Linear(dim_in + 2, dim_in + 2)

    def forward(self, lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay):
        """
        :param lm_X: feature of landmarks [..., 30]: 14 attribute + 16 measurement
        :param lm_Y: location of landmarks [..., 2]: longitude + latitude
        :param tg_X: feature of targets [..., 30]
        :param tg_Y: location of targets [..., 2]
        :param lm_delay: delay from landmark to the common router [..., 1]
        :param tg_delay: delay from target to the common router [..., 1]
        :return:
        """

        N1 = lm_Y.size(0)
        N2 = tg_Y.size(0)
        ones = torch.ones(N1 + N2 + 1).cuda()
        lm_feature = torch.cat((lm_X, lm_Y), dim=1)
        tg_feature_0 = torch.cat((tg_X, torch.zeros(N2, 2).cuda()), dim=1)
        router_0 = torch.mean(lm_feature, dim=0, keepdim=True)
        all_feature_0 = torch.cat((lm_feature, tg_feature_0, router_0), dim=0)

        '''
        star-GNN
        properties:
        1. single directed graph: feature of <landmarks> will never be updated.
        2. the target IP will receive from surrounding landmarks from two ways: 
            (1) attribute similarity-based one-hop propagation;
            (2) delay measurement-based two-hop propagation via the common router;
        '''
        # GNN-step 1
        adj_matrix_0 = torch.diag(ones)
        delay_score = torch.exp(-self.gamma_1 * (self.alpha * lm_delay + self.beta))

        rou2tar_score_0 = torch.exp(-self.gamma_2 * (self.alpha * tg_delay + self.beta)).reshape(N2)

        # feature
        _, attribute_score = self.att_attribute(tg_X, lm_X, lm_feature)
        attribute_score = torch.exp(attribute_score)

        adj_matrix_0[N1:N1 + N2, :N1] = attribute_score
        adj_matrix_0[-1, :N1] = delay_score
        adj_matrix_0[N1:N1 + N2:, -1] = rou2tar_score_0

        degree_0 = torch.sum(adj_matrix_0, dim=1)
        degree_reverse_0 = 1.0 / degree_0
        degree_matrix_reverse_0 = torch.diag(degree_reverse_0)

        degree_mul_adj_0 = degree_matrix_reverse_0 @ adj_matrix_0
        step_1_all_feature = self.w_1(degree_mul_adj_0 @ all_feature_0)

        tg_feature_1 = step_1_all_feature[N1:N1 + N2, :]
        router_1 = step_1_all_feature[-1, :].reshape(1, -1)

        # GNN-step 2
        adj_matrix_1 = torch.diag(ones)
        rou2tar_score_1 = torch.exp(-self.gamma_3 * (self.alpha * tg_delay + self.beta)).reshape(N2)
        adj_matrix_1[N1:N1 + N2:, -1] = rou2tar_score_1

        all_feature_1 = torch.cat((lm_featur
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值