TrustGeo代码理解(四)model.py( TrustGeo的核心源代码)

本文详细解析TrustGeo的model.py,包括模型结构、输入处理、参数计算及前向传播过程。重点关注TrustGeo类的定义,如__init__()、输入处理函数、输出计算函数及模型的前向传播步骤。通过对图视图和属性视图的处理,模型能够输出一系列参数,用于后续的训练和评估。
摘要由CSDN通过智能技术生成

代码链接:https://github.com/ICDM-UESTC/TrustGeo

 ├── lib # 包含模型(model)实现文件
    │        |── layers.py # 注意力机制的代码。
    │        |── model.py # TrustGeo的核心源代码。
    │        |── sublayers.py # layer.py的支持文件。
    │        |── utils.py # 辅助函数,包括视图融合的代码

一、导入各种模块和神经网络类

from math import gamma
from re import L
from .layers import *
import torch
import torch.nn as nn
import torch.nn.functional as Func
import numpy as np

这段代码是一个 Python 模块,包含了一些导入语句和定义了一个神经网络模型的类。

该块代码与RIPGeo部分一致

RIPGeo代码理解(四)model.py( RIPGeo的核心源代码)-CSDN博客

不同之处在于:

1、from math import gamma:导入了 gamma 函数,这是 Python 标准库中 math 模块中的一个函数,用于计算伽玛函数。
2、from re import L:导入了 L,这看起来是一个导入错误。通常来说,应该是导入正则表达式相关的模块,比如 import re。不过,这行可能是一个错误,可能需要修改。(好像没什么用)
3、import torch.nn.functional as Func:导入了 PyTorch 库中的相关模块。functional 模块包含了一些与神经网络相关的函数。
4、import numpy as np:导入了 NumPy 库,NumPy 是一个用于科学计算的 Python 库,提供了大量用于数组操作的函数。

二、TrustGeo类定义(NN模型)

class TrustGeo(nn.Module):
    def __init__(self, dim_in):
        super(TrustGeo, self).__init__()
        self.dim_in = dim_in
        self.dim_z = dim_in + 2

        # TrustGeo
        self.att_attribute = SimpleAttention(temperature=self.dim_z ** 0.5,
                                             d_q_in=self.dim_in,
                                             d_k_in=self.dim_in,
                                             d_v_in=self.dim_in + 2,
                                             d_q_out=self.dim_z,
                                             d_k_out=self.dim_z,
                                             d_v_out=self.dim_z)


        # calculate A
        self.gamma_1 = nn.Parameter(torch.ones(1, 1))
        self.gamma_2 = nn.Parameter(torch.ones(1, 1))
        self.gamma_3 = nn.Parameter(torch.ones(1, 1))
        self.alpha = nn.Parameter(torch.ones(1, 1))
        self.beta = nn.Parameter(torch.zeros(1, 1))

        # transform in Graph
        self.w_1 = nn.Linear(self.dim_in + 2, self.dim_in + 2)
        self.w_2 = nn.Linear(self.dim_in + 2, self.dim_in + 2)


        # higher-order evidence
        # graph view 
        self.out_layer_graph_view = nn.Linear(self.dim_z*2, 5)
        # attribute view 
        self.out_layer_attri_view = nn.Linear(self.dim_in, 5)
    

    # for output mu, v, alpha, beta
    def evidence(self, x):
        return Func.softplus(x)

    def trans(self, gamma1, gamma2, logv, logalpha, logbeta):
        v = self.evidence(logv)
        alpha = self.evidence(logalpha) + 1
        beta = self.evidence(logbeta)
        return gamma1, gamma2, v, alpha, beta
    

    def forward(self, lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay, add_noise=0):
        """
        :param lm_X: feature of landmarks [..., 30]: 14 attribute + 16 measurement
        :param lm_Y: location of landmarks [..., 2]: longitude + latitude
        :param tg_X: feature of targets [..., 30]
        :param tg_Y: location of targets [..., 2]
        :param lm_delay: delay from landmark to the common router [..., 1]
        :param tg_delay: delay from target to the common router [..., 1]
        :return:
        """

  

        N1 = lm_Y.size(0)
        N2 = tg_Y.size(0)
        ones = torch.ones(N1 + N2 + 1).cuda()
        lm_feature = torch.cat((lm_X, lm_Y), dim=1)
        tg_feature_0 = torch.cat((tg_X, torch.zeros(N2, 2).cuda()), dim=1)
        router_0 = torch.mean(lm_feature, dim=0, keepdim=True)
        all_feature_0 = torch.cat((lm_feature, tg_feature_0, router_0), dim=0)

        '''
        star-GNN
        properties:
     
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值