【MGT】代码解读之model-MGT

论文解读:《Meta Graph Transformer: A Novel Framework for Spatial–Temporal Traffic Prediction》

代码链接:https://github.com/lonicera-yx/MGT


壹、测试主框架

一、文件目录

在这里插入图片描述
文件夹MGT-main下包含4个子文件夹,一个数据压缩文件,一个main.py等等。
MGT.py位于子文件夹models中,包含了3个def和14个class

二、if 主函数测试MGT

if __name__ == '__main__':
    print(os.getcwd())
    #cfgs = yaml.safe_load(open('cfgs/HZMetro_MGT.yaml'))['model']
    cfgs = yaml.safe_load(open('../cfgs/HZMetro_MGT.yaml'))['model']
    model = MGT(cfgs)

    # dummy data 虚拟数据
    B, P, Q, N, C = 10, 4, 4, 80, 2
    # B:batch_size P:history  Q:feture N:Nodes C:Features
    M = 73, 2 # M is tuple
    eigenmaps_k = 8 #拉普拉斯特征映射降维方法的参数
    n = 3

    inputs = torch.randn(B, P, N, C, dtype=torch.float32)#(10,4,80,2)
    targets = torch.randn(B, Q, N, C, dtype=torch.float32)#(10,4,80,2)

    inputs_time0 = torch.randint(M[0], (B, P), dtype=torch.int64)#(10,4) max_int is 73
    targets_time0 = torch.randint(M[0], (B, Q), dtype=torch.int64)#(10,4)
    inputs_time1 = torch.randint(M[1], (B, P), dtype=torch.int64)#(10,4) max_int is 2
    targets_time1 = torch.randint(M[1], (B, Q), dtype=torch.int64)#(10,4)

    eigenmaps = torch.randn(N, eigenmaps_k, dtype=torch.float32)#(80,8)

    transition_matrices = torch.rand(n, N, N, dtype=torch.float32)#(3,80,80)

    extras = [inputs_time0, targets_time0, inputs_time1, targets_time1]
    statics = {'eigenmaps': eigenmaps, 'transition_matrices': transition_matrices}

    # forward
    outputs1 = model(inputs, targets, *extras, **statics) #*和**见注释1
    outputs2 = model(inputs, None, *extras, **statics)

注释1:
见博文《def 参数 及参数解构 》

贰、MGT

def __init__ 结构搭建

在这里插入图片描述

在原文中有MTG的各种变形(如下),我们不考虑这些,只考经典的MTG
在这里插入图片描述
所以,在MGT下,self.noTE=self.noSE=False.共包含5个层结构:时间嵌入层、空间嵌入层、时空嵌入层、编码器结构、解码器结构。

def forward流程图

流程图中的input是一个,为了美观,所以拆分为两个分别作为输入。

dict
list
B,P
B,Q
B,P,d_m
B,Q,d_m
N,k
N,d_m
B,P,N,d_m
B,Q,N,d_m
B,P,N,C
B,P,N,C
B,Q,N,C
特征映射矩阵
转移矩阵
input0
input1
target0
target1
inputs
inputs
targets
extra
时间嵌入层
z_input
z_target
statics
空间嵌入层
U
时空嵌入层
c_inputs
c_targets
encoder
en-out
encoder

叁、 三个嵌入层TE\SE\STE

1. TE

def _init__
在这里插入图片描述
在init中主要定义了一个不可优化的参数矩阵self.pe和两个层结构。第一个层结构含两个嵌入层,第二个层结构是一个全连接层。如下:
在这里插入图片描述
注释:

  1. self.register_buffer
  2. nn.Embedding

def forward
数据形状和层结构的搭建如下图所示,从而完成数据的时间嵌入.橙色是对于input来说,蓝色是对于target来说的。nn.Embedding,nn.linear等都是固定的层结构。
在这里插入图片描述
代码为:
在这里插入图片描述
注释:

  1. torch.Tensor.expand

2 SE

空间嵌入是将具有空间特征的矩阵进行线性变换即可。
在这里插入图片描述

3 STE

将SE后的(z_inputs,z_targets)和TE后的u,进行扩维,最后经过一个线性变换合并信息。
在这里插入图片描述

注释:

  1. torch.stack,沿一个新维度对输入张量序列进行连接,序列中所有张量应为相同形状;stack 函数返回的结果会新增一个维度,而stack()函数指定的dim参数,就是新增维度的(下标)位置。

肆、Encoder

一、Encoder

def __init__

在这里插入图片描述

def forward

在这里插入图片描述

二、EncoderLayer层

def __init__

类从cfgs中获得的变量,设置的class的属性
在这里插入图片描述
其中包括3个层结构:TSA,SSA和FFN
在这里插入图片描述

伍、时间\空间\时编码自注意力层

1.TSA层

当使用元学习的时候,包含三个层结构:MetaLearner(元学习),LayerNorm(层归一化),Linear(现象变换)
在这里插入图片描述
注释:

  1. nn.LayerNorm,《用法》《实现》

在这里插入图片描述

c=torch.randint(10,(num_weight_matrices, B, P, N, num_heads, d_k, d_model))

注释:

  1. torch.new_full,《案例讲解》,《精彩讲解》
  2. torch.triu, 《详例讲解》

2.SSA

当使用元学习的时候,包含4个结构:Meta_learner(元学习列表),Linear(线性变换),dropout, LayerNorm(层归一化)
在这里插入图片描述
输入数据为:inputs,c_inputs,transition_matrices
其 数据运行如下:
在这里插入图片描述
在这里插入图片描述

3.TEDA

当使用元学习的时候,包含3个层结构:MetaLearner, LayerNorm, Linear
在这里插入图片描述
在这里插入图片描述

陆、Decoder

一、Decoder

def __init__

在这里插入图片描述
在这里插入图片描述

def forward

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

二、DecoderLayer

def __init__

DecoderLayer在cfgs中获得变量和类属性
在这里插入图片描述
查看DecoderLayer中的层结构,我们知道MTG有如下的变体,但在这里我们只考虑MGT.
在这里插入图片描述
MGT中有4个层结构:
在这里插入图片描述
具体的形状如下
TSA层
在这里插入图片描述
SSA层
在这里插入图片描述
TEDA层
在这里插入图片描述
FFN层
在这里插入图片描述

def forward

在这里插入图片描述

捌、其他层

一、MetaLearner层

在这里插入图片描述
MetaLearner包含2个全连接层,形状如下:
在这里插入图片描述

二、FeedForward层

包含两个全连接层(Linear)和层归一化层(LayerNorm)
在这里插入图片描述

三、Projection

在这里插入图片描述

玖、三个多头函数

1. multihead_linear_transform

在这里插入图片描述

2. multihead_temporal_attention

在这里插入图片描述

3. multihead_spatial_attention

在这里插入图片描述

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值