目录
摘要
本周阅读了一篇关于使用GNN进行时间序列预测的论文。该论文模型的主要实现了在捕获时间序列不同尺度上的时间依赖关系外,还捕获了在不同尺度上变量之间的内部联系,例如空间上的依赖关系。同时还编写了论文提供的相关代码,了解了论文模型的定义。
ABSTRACT
This week, We read a paper on using Graph Neural Networks (GNN) for time series prediction. The main implementation of the model in the paper not only captures temporal dependencies at different scales in time series but also captures internal relationships between variables at different scales, such as spatial dependencies. Additionally, We wrote code based on the provided paper, gaining a better understanding of the model's definition and implementation.
1 论文标题
Multi-Scale Adaptive Graph Neural Network for Multivariate Time Series Forecasting
2 论文摘要
多变量时间序列(MTS)预测在智能应用的自动化和优化中扮演着重要角色。这是一项具有挑战性的任务,因为我们需要考虑到复杂的变量内部依赖关系和变量之间的依赖关系。现有研究仅通过单个变量之间的依赖来学习时间模式。然而,在许多实际应用中存在着多尺度时间模式,单一变量之间的依赖使得模型更倾向于学习突出且共享的时间模式。本文提出了一个名为多尺度自适应图神经网络(MAGNN)来解决上述问题。MAGNN利用多尺度金字塔网络在不同时间尺度上保持潜在的时间依赖性。由于不同时间尺度下变量之间可能存在不同类型的依赖关系,因此设计了一个自适应图学习模块来推断特定于尺度的变量间依赖关系,而无需预先定义先验知识。针对多尺度特征表示和特定尺度下的变量间依赖关系,引入了多尺度时间图神经网络以联合建模内部和外部依赖关系。随后,我们开发了一个融合模块来有效促进不同时间尺度之间融合,并自动捕捉贡献重要性较高的时间模式。在六个真实数据集上进行实验证明,在各种设置下MAGNN都优于最先进方法。
3 过去方案
多变量时间序列(Multivariate time series, MTS)广泛存在于各种实际场景中,例如城市交通流量、股票市场股票价格以及城市街区家庭用电等。MTS预测是一种基于历史观测时间序列组预测未来趋势的方法,在近年来得到了广泛研究。它具有广泛应用,例如根据预测的交通流量提前规划更优行驶路线,通过对近期股票市场进行预测设计投资策略。准确的MTS预测是一项具有挑战性任务,因为需要同时考虑变量内依赖关系(即一个时间序列内的时间依赖关系)和变量间依赖关系(即单个变量的预测值受其他变量影响)。为解决这一问题,传统方法如向量自回归(VAR)、时间正则化矩阵分解(TRMF)、向量自回归移动平均(VARMA)和高斯过程(GP)往往依赖严格平稳假设,并无法捕捉变量之间非线性依赖关系。深度神经网络在建模非平稳和非线性依赖关系方面表现出卓越性能。特别是循环神经网络(RNN)的两种变体——长短期记忆网络(LSTM)和门控循环单元(GRU),以及时间卷积网络(TCN)在时序建模方面取得显著成果。现有研究引入跳跃连接、注意力机制和记忆网络等策略来同时捕捉长期与短期时序依赖关系。然而这些研究主要集中于时序依赖建模,并将MTS输入视为向量处理,假设单个变量的预测值受所有其他变量影响,这并不合理也难以满足实际应用需求。例如,在街道交通流量中邻近街道对其影响较大而远处街道影响较小。因此明确地对成对变量间的依赖关系进行建模至关重要。
图(Graph)是一种抽象数据类型,用于表示节点之间的关系。图神经网络(GNN)能够有效地捕获节点的高层表示,并利用成对依赖关系,被认为是处理图数据的一种有效方法。从图建模角度考虑MTS预测问题时,可以将MTS中的变量视为图中的节点,而变量之间的成对依赖关系则可看作边。最近有研究利用了图神经网络(GNN)来建模MTS,并充分利用了丰富的结构信息(即特征节点和加权边)。这些研究通过堆叠GNN和时间卷积模块来学习时间模式,并取得了良好结果。然而,现有研究仅考虑单一时间尺度上的时序依赖关系,在反映现实世界中许多场景变化方面存在不足。
首先,在事实上真实MTS中隐藏着更复杂的时间模式,包括每天、每周、每月以及其他特定周期性模式等。例如,在下图中展示了4户居民两周内电力消耗情况,其中存在着短期和长期重复模式(即每天和每周)。这些多尺度时间模式为MTS建模提供了丰富信息;如果我们在不同时间尺度上分别学习时间模式并直接连接它们以获取最终表示,则无法捕获跨尺度关系或者注意到具有贡献性的时间模式。因此,在准确预测MTS方面需要学习一种能够全面反映各种多尺度时间模式特征表达方式。
其次,现有研究通过学习共享邻接矩阵来表示变量间丰富的依赖关系,使得模型在学习一类突出的共享时态模式时存在偏差。事实上,不同类型的时态模式往往受到不同变量间依赖关系的影响,在建模不同的时态模式时需要区分变量间依赖关系。例如,在对一个家庭的短期电力消费模式进行建模时,可能需要更多地关注其邻居的电力消费。因为短期模式的动态往往受到共同事件的影响,例如输电线路故障会降低街区的电力消耗,而突然出现寒冷天气会增加电力消耗。在对一个家庭的长期电力消费模式进行建模时,则可能需要更多地关注具有相似生活习惯(如工作和睡眠时间) 的家庭,因为这些家庭每天和每周都有相似时间规律。因此,在对这些多尺度时间模式进行建模时,需要充分考虑复杂变量间依赖关系。
4 论文方案
本文提出一种用于MTS预测的通用框架——多尺度自适应图神经网络(MAGNN),以解决上述问题。引入多尺度金字塔网络,对不同时间尺度的时间序列进行分层分解。然后,设计了一个自适应图学习模块来自动推断端到端框架中特定尺度的图结构,能够充分探索不同时间尺度下丰富且隐式的变量间依赖关系。接着,在框架中融入多尺度时序图神经网络,对每个时间尺度上的变量内依赖关系和变量间依赖关系进行建模。最后,设计了一个尺度融合模块,以自动考虑特定尺度表示的重要性,并捕获跨尺度的相关性。总而言之,论文的贡献如下:
① 提出了MAGNN,学习一种可以全面反映多尺度时间模式和特定尺度变量间依赖关系的时间表示。
② 设计了一个自适应图学习模块来探索不同时间尺度下丰富和隐式的变量间依赖关系,以及一个尺度融合模块来促进跨这些特定尺度的时间表示的协作,并自动捕获贡献的时间模式的重要性。
③ 在6个真实的MTS基准数据集上进行了广泛的实验。实验结果表明,该方法的性能优于现有的一些方法。
4.1 问题描述
本文主要研究MTS预测。在形式上,给定一个输入的时间序列数据,其中表示时间步t的值,N是变量维度,表示第t个时间步的第i个变量的值,MTS预测的目的是预测未来在时间步t+h的值,其中h表示需要预测的未来时间步数。这个问题可以表述为:。式中F为映射函数,θ表示所有可学习的参数。
然后,存在如下几种关于MTS预测的定义:
定义1:MTS数据用图结构表示。图被定义为G=(V,E),其中V表示节点集,|V|=N,E是边集。假设,将第i个变量视为第i个节点,的值是的特征,每条边的都表明和之间存在变量间的依赖关系。
定义2:加权邻接矩阵。图的加权邻接矩阵是一种用于存储边权重的数学表示方法,其中如果,;如果,。对于没有任何先验知识的纯MTS数据,需要学习多图的加权邻接矩阵来表示丰富且隐式的变量间依赖关系。据此,MTS预测公式可修改为:。其中表示能被GNN用于MTS预测的图集。
4.2 图神经网络(GNN)
图神经网络(GNN)是一类应用于图的深度神经网络。图可以是不规则的,图中无序节点的大小可能是可变的,图中的节点可能有不同数量的邻居。GNN可以很容易地在图域中计算,这克服CNN的局限性。根据实现原理,图神经网络可以分为两类:基于谱的方法和基于空间的方法。基于谱的方法从图信号处理的角度出发,通过引入filter来定义图卷积。图卷积操作可以解释为从图信号中去除噪声。基于空间的方法通过信息传播定义图卷积,结合中心节点的表示及其邻居节点的表示,从而获得该节点的更新表示。论文方法中应用的图卷积操作可以定义为:。
其中G=(V,E,A)是一个带加权邻接矩阵的图,x是节点的表示,σ是一个激活函数,θ是可学习的参数矩阵,是具有自连接的邻接矩阵,是的对角度矩阵,。通过将图卷积操作多层堆叠,可以聚合多阶的邻居信息。
多尺度GNN,又称分层GNN,通常在细粒度图的基础上分层构建粗粒度图。MAGNN关注时间维度的尺度,与一般的多尺度GNN非常不同,后者主要关注空间维度的尺度。MAGNN引入了一个多尺度金字塔网络,将原始时间序列转换为从较小尺度到较大尺度的特征表示,在该网络上,它学习每个尺度下具有相同大小的特定尺度图,并对每个图使用公式定义的基本GNN。
4.3 论文模型
论文模型MAGNN的框架如下所示,该框架由4个主要部分组成:a)多尺度金字塔网络,以保留不同时间尺度下的时序层次结构;B)自适应图学习模块,自动推断变量间的依赖关系;C)多尺度时序图神经网络,用于捕获各种特定尺度的时序模式;D)尺度融合模块,有效促进跨时间尺度的融合。
MAGNN框架的4个主要部分组成具体作用如下:(a)两个并行的卷积神经网络和每层的逐点相加将特征表示从较小尺度分层变换到较大尺度;(b)自适应图学习模块将节点嵌入和尺度嵌入作为输入,并输出特定尺度的邻接矩阵;(c)将每个尺度特定的特征表示和邻接矩阵输入到时序图神经网络(TGNN)中,以获得尺度特定的表示。(d)加权融合特定尺度表示以捕获贡献的时间模式。最终的多尺度表示被送入包括两个卷积神经网络的输出模块以获得预测值。
5 相关代码
MAGNN模型定义如下:
from layer import *
# from AGCRN import *
import torch
class magnn(nn.Module):
def __init__(self, gcn_depth, num_nodes, device, dropout=0.3, subgraph_size=20, node_dim=40, conv_channels=32, gnn_channels=32, scale_channels=16, end_channels=128, seq_length=12, in_dim=1, out_dim=12, layers=3, propalpha=0.05, tanhalpha=3, single_step=True):
super(magnn, self).__init__()
self.num_nodes = num_nodes
self.dropout = dropout
self.device = device
self.single_step = single_step
self.filter_convs = nn.ModuleList()
self.gate_convs = nn.ModuleList()
self.scale_convs = nn.ModuleList()
self.gconv1 = nn.ModuleList()
self.gconv2 = nn.ModuleList()
self.norm = nn.ModuleList()
self.seq_length = seq_length
self.layer_num = layers
self.gc = graph_constructor(num_nodes, subgraph_size, node_dim, self.layer_num, device)
if self.single_step:
self.kernel_set = [7, 6, 3, 2]
else:
self.kernel_set = [3, 2, 2]
self.scale_id = torch.autograd.Variable(torch.randn(self.layer_num, device=self.device), requires_grad=True)
# self.scale_id = torch.arange(self.layer_num).to(device)
self.lin1 = nn.Linear(self.layer_num, self.layer_num)
self.idx = torch.arange(self.num_nodes).to(device)
self.scale_idx = torch.arange(self.num_nodes).to(device)
self.scale0 = nn.Conv2d(in_channels=in_dim, out_channels=scale_channels, kernel_size=(1, self.seq_length), bias=True)
self.multi_scale_block = multi_scale_block(in_dim, conv_channels, self.num_nodes, self.seq_length, self.layer_num, self.kernel_set)
# self.agcrn = nn.ModuleList()
length_set = []
length_set.append(self.seq_length-self.kernel_set[0]+1)
for i in range(1, self.layer_num):
length_set.append( int( (length_set[i-1]-self.kernel_set[i])/2 ) )
for i in range(self.layer_num):
"""
RNN based model
"""
# self.agcrn.append(AGCRN(num_nodes=self.num_nodes, input_dim=conv_channels, hidden_dim=scale_channels, num_layers=1) )
self.gconv1.append(mixprop(conv_channels, gnn_channels, gcn_depth, dropout, propalpha))
self.gconv2.append(mixprop(conv_channels, gnn_channels, gcn_depth, dropout, propalpha))
self.scale_convs.append(nn.Conv2d(in_channels=conv_channels,
out_channels=scale_channels,
kernel_size=(1, length_set[i])))
self.gated_fusion = gated_fusion(scale_channels, self.layer_num)
# self.output = linear(self.layer_num*self.hidden_dim, out_dim)
self.end_conv_1 = nn.Conv2d(in_channels=scale_channels,
out_channels=end_channels,
kernel_size=(1,1),
bias=True)
self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
out_channels=out_dim,
kernel_size=(1,1),
bias=True)
def forward(self, input, idx=None):
seq_len = input.size(3)
assert seq_len==self.seq_length, 'input sequence length not equal to preset sequence length'
scale = self.multi_scale_block(input, self.idx)
# self.scale_weight = self.lin1(self.scale_id)
self.scale_set = [1, 0.8, 0.6, 0.5]
adj_matrix = self.gc(self.idx, self.scale_idx, self.scale_set)
outputs = self.scale0(F.dropout(input, self.dropout, training=self.training))
out = []
out.append(outputs)
for i in range(self.layer_num):
"""
RNN-based model
"""
# output = self.agcrn[i](scale[i].permute(0, 3, 2, 1), adj_matrix) # B T N D
# output = output.permute(0, 3, 2, 1)
output = self.gconv1[i](scale[i], adj_matrix[i])+self.gconv2[i](scale[i], adj_matrix[i].transpose(1,0))
scale_specific_output = self.scale_convs[i](output)
out.append(scale_specific_output)
# concatenate
# outputs = outputs + scale_specific_output
# mean-pooling
# outputs = torch.mean(torch.stack(out), dim=0)
out0 = torch.cat(out, dim=1)
out1 = torch.stack(out, dim = 1)
if self.single_step:
outputs = self.gated_fusion(out0, out1)
x = F.relu(outputs)
x = F.relu(self.end_conv_1(x))
x = self.end_conv_2(x)
return x, adj_matrix
网络层定义如下:
from __future__ import division
import torch
import torch.nn as nn
from torch.nn import init
import numbers
import torch.nn.functional as F
import numpy as np
class nconv(nn.Module):
def __init__(self):
super(nconv,self).__init__()
def forward(self,x, A):
x = torch.einsum('ncvl,vw->ncwl',(x,A))
return x.contiguous()
class dy_nconv(nn.Module):
def __init__(self):
super(dy_nconv,self).__init__()
def forward(self,x, A):
x = torch.einsum('ncvl,nvwl->ncwl',(x,A))
return x.contiguous()
class linear(nn.Module):
def __init__(self,c_in,c_out,bias=True):
super(linear,self).__init__()
self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0,0), stride=(1,1), bias=bias)
def forward(self,x):
return self.mlp(x)
class layer_block(nn.Module):
def __init__(self, c_in, c_out, k_size):
super(layer_block, self).__init__()
self.conv_output = nn.Conv2d(c_in, c_out, kernel_size=(1, 1), stride=(1, 2))
self.conv_output1 = nn.Conv2d(c_in, c_out, kernel_size=(1, k_size), stride=(1, 1), padding=(0, int( (k_size-1)/2 ) ) )
self.output = nn.MaxPool2d(kernel_size=(1,3), stride=(1,2), padding=(0,1))
self.conv_output1 = nn.Conv2d(c_in, c_out, kernel_size=(1, k_size), stride=(1, 1) )
self.output = nn.MaxPool2d(kernel_size=(1,3), stride=(1,2))
self.relu = nn.ReLU()
def forward(self, input):
conv_output = self.conv_output(input) # shape (B, D, N, T)
conv_output1 = self.conv_output1(input)
output = self.output(conv_output1)
return self.relu( output+conv_output[...,-output.shape[3]:] )
# return self.relu( conv_output )
class multi_scale_block(nn.Module):
def __init__(self, c_in, c_out, num_nodes, seq_length, layer_num, kernel_set, layer_norm_affline=True):
super(multi_scale_block, self).__init__()
self.seq_length = seq_length
self.layer_num = layer_num
self.norm = nn.ModuleList()
self.scale = nn.ModuleList()
for i in range(self.layer_num):
self.norm.append(nn.BatchNorm2d(c_out, affine=False))
# # self.norm.append(LayerNorm((c_out, num_nodes, int(self.seq_length/2**i)),elementwise_affine=layer_norm_affline))
# self.norm.append(LayerNorm((c_out, num_nodes, length_set[i]),elementwise_affine=layer_norm_affline))
self.start_conv = nn.Conv2d(c_in, c_out, kernel_size=(1, 1), stride=(1, 1))
self.scale.append(nn.Conv2d(c_out, c_out, kernel_size=(1, kernel_set[0]), stride=(1, 1)))
for i in range(1, self.layer_num):
self.scale.append(layer_block(c_out, c_out, kernel_set[i]))
def forward(self, input, idx): # input shape: B D N T
self.idx = idx
scale = []
scale_temp = input
scale_temp = self.start_conv(scale_temp)
# scale.append(scale_temp)
for i in range(self.layer_num):
scale_temp = self.scale[i](scale_temp)
# scale_temp = self.norm[i](scale_temp)
# scale_temp = self.norm[i](scale_temp, self.idx)
# scale.append(scale_temp[...,-self.k:])
scale.append(scale_temp)
return scale
class top_down_path(nn.Module):
def __init__(self, c_in, c_out_1, c_out_2, c_out_3, c_out_4):
super(top_down_path, self).__init__()
self.down1 = nn.Conv2d(c_in, c_out_1, kernel_size=(1, 1), stride=(1, 1))
self.down2 = nn.Conv2d(c_out_1, c_out_2, kernel_size=(1, 7), stride=(1, 2), padding=(0, 2))
self.down3 = nn.Conv2d(c_out_2, c_out_3, kernel_size=(1, 6), stride=(1, 2))
self.down4 = nn.Conv2d(c_out_3, c_out_4, kernel_size=(1, 3), stride=(1, 2))
self.up3 = nn.ConvTranspose2d(c_out_4, c_out_3, kernel_size=(1,3), stride=(1,2))
# self.up3 = nn.MaxUnpool2d()
self.up2 = nn.ConvTranspose2d(c_out_3, c_out_2, kernel_size=(1,6), stride=(1,2), output_padding=(0,1))
self.up1 = nn.ConvTranspose2d(c_out_2, c_out_1, kernel_size=(1,7), stride=(1,2))
def forward(self, input):
down_1 = self.down1(input)
down_2 = self.down2(down_1)
down_3 = self.down3(down_2)
down_4 = self.down4(down_3)
up_3 = self.up3(down_4)
output_3 = down_3 + up_3
up_2 = self.up2(output_3)
output_2 = down_2 + up_2
up_1 = self.up3(output_2)
output_1 = down_1[:,:,:,1:] + up_1
return down_4, output_3, output_2, output_1
class gated_fusion(nn.Module):
def __init__(self, skip_channels, layer_num, ratio=1):
super(gated_fusion, self).__init__()
# self.reduce = torch.mean(x,dim=2,keepdim=True)
self.dense1 = nn.Linear(in_features=skip_channels*(layer_num+1), out_features=(layer_num+1)*ratio, bias=False)
self.dense2 = nn.Linear(in_features=(layer_num+1)*ratio, out_features=(layer_num+1), bias=False)
def forward(self, input1, input2):
se = torch.mean(input1, dim=2, keepdim=False)
se = torch.squeeze(se)
se = F.relu(self.dense1(se))
se = F.sigmoid(self.dense2(se))
se = torch.unsqueeze(se, -1)
se = torch.unsqueeze(se, -1)
se = torch.unsqueeze(se, -1)
x = torch.mul(input2, se)
x = torch.mean(x, dim=1, keepdim=False)
return x
class prop(nn.Module):
def __init__(self,c_in,c_out,gdep,dropout,alpha):
super(prop, self).__init__()
self.nconv = nconv()
self.mlp = linear(c_in,c_out)
self.gdep = gdep
self.dropout = dropout
self.alpha = alpha
def forward(self,x,adj):
adj = adj + torch.eye(adj.size(0)).to(x.device)
d = adj.sum(1)
h = x
dv = d
a = adj / dv.view(-1, 1)
for i in range(self.gdep):
h = self.alpha*x + (1-self.alpha)*self.nconv(h,a)
ho = self.mlp(h)
return ho
class mixprop(nn.Module):
def __init__(self,c_in,c_out,gdep,dropout,alpha):
super(mixprop, self).__init__()
self.nconv = nconv()
self.mlp = linear((gdep+1)*c_in,c_out)
self.gdep = gdep
self.dropout = dropout
self.alpha = alpha
def forward(self,x,adj):
adj = adj + torch.eye(adj.size(0)).to(x.device)
d = adj.sum(1)
h = x
out = [h]
a = adj / d.view(-1, 1)
for i in range(self.gdep):
h = self.alpha*x + (1-self.alpha)*self.nconv(h,a)
out.append(h)
ho = torch.cat(out,dim=1)
ho = self.mlp(ho)
return ho
class graph_constructor(nn.Module):
def __init__(self, nnodes, k, dim, layer_num, device, alpha=3):
super(graph_constructor, self).__init__()
self.nnodes = nnodes
self.layers = layer_num
self.emb1 = nn.Embedding(nnodes, dim)
self.emb2 = nn.Embedding(nnodes, dim)
self.lin1 = nn.ModuleList()
self.lin2 = nn.ModuleList()
for i in range(layer_num):
self.lin1.append(nn.Linear(dim,dim))
self.lin2.append(nn.Linear(dim,dim))
self.device = device
self.k = k
self.dim = dim
self.alpha = alpha
def forward(self, idx, scale_idx, scale_set):
nodevec1 = self.emb1(idx)
nodevec2 = self.emb2(idx)
adj_set = []
for i in range(self.layers):
nodevec1 = torch.tanh(self.alpha*self.lin1[i](nodevec1*scale_set[i]))
nodevec2 = torch.tanh(self.alpha*self.lin2[i](nodevec2*scale_set[i]))
a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0))
adj0 = F.relu(torch.tanh(self.alpha*a))
mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device)
mask.fill_(float('0'))
s1,t1 = adj0.topk(self.k,1)
mask.scatter_(1,t1,s1.fill_(1))
# print(mask)
adj = adj0*mask
adj_set.append(adj)
return adj_set
class graph_constructor_full(nn.Module):
def __init__(self, nnodes, k, dim, layer_num, device, alpha=3):
super(graph_constructor_full, self).__init__()
self.nnodes = nnodes
self.layers = layer_num
self.emb1 = nn.Embedding(nnodes, dim)
self.emb2 = nn.Embedding(nnodes, dim)
self.lin1 = nn.ModuleList()
self.lin2 = nn.ModuleList()
for i in range(self.layers):
self.lin1.append(nn.Linear(dim,dim))
self.lin2.append(nn.Linear(dim,dim))
self.device = device
self.k = k
self.dim = dim
self.alpha = alpha
self.static_feat = static_feat
def forward(self, idx, scale_idx, scale_set):
nodevec1 = self.emb1(idx)
nodevec2 = self.emb2(idx)
adj_set = []
for i in range(self.layers):
nodevec1 = torch.tanh(self.alpha*self.lin1[i](nodevec1*scale_set[i]))
nodevec2 = torch.tanh(self.alpha*self.lin2[i](nodevec2*scale_set[i]))
a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0))
adj0 = F.relu(torch.tanh(self.alpha*a))
adj_set.append(adj0)
return adj_set
class graph_constructor_one(nn.Module):
def __init__(self, nnodes, k, dim, layer_num, device, alpha=3, static_feat=None):
super(graph_constructor_one, self).__init__()
self.nnodes = nnodes
self.layers = layer_num
self.emb1 = nn.Embedding(nnodes, dim)
self.emb2 = nn.Embedding(nnodes, dim)
self.lin1 = nn.ModuleList()
self.lin2 = nn.ModuleList()
self.lin1 = nn.Linear(dim,dim)
self.lin2 = nn.Linear(dim,dim)
self.device = device
self.k = k
self.dim = dim
self.alpha = alpha
self.static_feat = static_feat
def forward(self, idx, scale_idx, scale_set):
nodevec1 = self.emb1(idx)
nodevec2 = self.emb2(idx)
adj_set = []
nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1))
nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2))
a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0))
adj0 = F.relu(torch.tanh(self.alpha*a))
mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device)
mask.fill_(float('0'))
s1,t1 = adj0.topk(self.k,1)
mask.scatter_(1,t1,s1.fill_(1))
adj = adj0*mask
return adj