【ASTGCN】Model Walkthrough (torch): Model Framework (Part 3)
The previous post walked through all the code for a single component; this post covers the framework that fuses the three components. For the code, see 【ASTGCN】.
The code here is not from the same source as the previous three posts, but the ideas and structure are similar. To stay consistent with the earlier code, this post makes some modest modifications.
Attention Mechanisms
1. Temporal_Attention_layer
X(B,N,F,T) --permute--> (B,T,F,N) ×U1(N,) --> (B,T,F) ×U2(F,N) --> lhs(B,T,N)
X(B,N,F,T) --permute--> (B,N,T,F) ×U3(F,) --> rhs(B,N,T)
product = lhs × rhs: (B,T,T)
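Collapsing this shape flow into one formula (essentially the temporal attention of the ASTGCN paper; note the softmax normalization runs along dim 1, matching the code below):

$$E = V_e \cdot \sigma\big( (X^{\top} U_1)\, U_2 \,(U_3\, X) + b_e \big) \in \mathbb{R}^{B \times T \times T}, \qquad E' = \operatorname{softmax}_{\text{dim}=1}(E)$$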
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Temporal_Attention_layer(nn.Module):
    def __init__(self, num_of_vertices, num_of_features, num_of_timesteps):
        super(Temporal_Attention_layer, self).__init__()
        # nn.Parameter registers the weights with the module, so they move with
        # .to(device) and are returned by model.parameters(); a plain
        # torch.randn(..., requires_grad=True).to(device) would create a
        # non-leaf copy that the optimizer never sees.
        self.U_1 = nn.Parameter(torch.randn(num_of_vertices))
        self.U_2 = nn.Parameter(torch.randn(num_of_features, num_of_vertices))
        self.U_3 = nn.Parameter(torch.randn(num_of_features))
        self.b_e = nn.Parameter(torch.randn(1, num_of_timesteps, num_of_timesteps))
        self.V_e = nn.Parameter(torch.randn(num_of_timesteps, num_of_timesteps))

    def forward(self, x):
        # x: (B, N, F, T)
        # lhs: (B,T,F,N) x U_1 -> (B,T,F); x U_2 -> (B,T,N)
        lhs = torch.matmul(torch.matmul(x.permute(0, 3, 2, 1), self.U_1), self.U_2)
        # rhs: (B,N,T,F) x U_3 -> (B,N,T)
        rhs = torch.matmul(x.permute(0, 1, 3, 2), self.U_3)
        # product: (B,T,N) @ (B,N,T) -> (B,T,T)
        product = torch.matmul(lhs, rhs)
        E = torch.matmul(self.V_e, torch.sigmoid(product + self.b_e))
        # normalization: softmax along dim 1, written out with the max trick
        E = E - torch.max(E, 1, keepdim=True)[0]
        exp = torch.exp(E)
        E_normalized = exp / torch.sum(exp, 1, keepdim=True)
        return E_normalized
```
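A quick shape check (the batch size 32 and the PEMS04-style sizes of 307 vertices / 1 feature / 12 steps are illustrative):

```python
layer = Temporal_Attention_layer(num_of_vertices=307, num_of_features=1,
                                 num_of_timesteps=12)
x = torch.randn(32, 307, 1, 12)   # (B, N, F, T)
print(layer(x).shape)             # torch.Size([32, 12, 12]): one (T, T) map per sample
```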
2. Spatial_Attention_layer
x(B,N,F,T) ×W1(T,) --> (B,N,F) ×W2(F,T) --> lhs(B,N,T)
x(B,N,F,T) --permute--> (B,T,N,F) ×W3(F,) --> rhs(B,T,N)
product = lhs × rhs: (B,N,N)
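In formula form (the spatial attention of the ASTGCN paper, again with the code's dim-1 softmax):

$$S = V_s \cdot \sigma\big( (X W_1)\, W_2 \,(W_3\, X)^{\top} + b_s \big) \in \mathbb{R}^{B \times N \times N}, \qquad S' = \operatorname{softmax}_{\text{dim}=1}(S)$$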
```python
class Spatial_Attention_layer(nn.Module):
    def __init__(self, num_of_vertices, num_of_features, num_of_timesteps):
        super(Spatial_Attention_layer, self).__init__()
        # registered as nn.Parameter for the same reason as in the temporal layer
        self.W_1 = nn.Parameter(torch.randn(num_of_timesteps))
        self.W_2 = nn.Parameter(torch.randn(num_of_features, num_of_timesteps))
        self.W_3 = nn.Parameter(torch.randn(num_of_features))
        self.b_s = nn.Parameter(torch.randn(1, num_of_vertices, num_of_vertices))
        self.V_s = nn.Parameter(torch.randn(num_of_vertices, num_of_vertices))

    def forward(self, x):
        # x: (B, N, F, T)
        # lhs: (B,N,F,T) x W_1 -> (B,N,F); x W_2 -> (B,N,T)
        lhs = torch.matmul(torch.matmul(x, self.W_1), self.W_2)
        # rhs: (B,T,N,F) x W_3 -> (B,T,N)
        rhs = torch.matmul(x.permute(0, 3, 1, 2), self.W_3)
        # product: (B,N,T) @ (B,T,N) -> (B,N,N)
        product = torch.matmul(lhs, rhs)
        S = torch.matmul(self.V_s, torch.sigmoid(product + self.b_s))
        # normalization: softmax along dim 1
        S_normalized = F.softmax(S, dim=1)
        return S_normalized
```
Chebyshev Graph Convolution
```python
class cheb_conv_with_SAt(nn.Module):
    def __init__(self, K, cheb_polynomials, in_channels, out_channels):
        super(cheb_conv_with_SAt, self).__init__()
        self.K = K
        self.cheb_polynomials = cheb_polynomials
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.DEVICE = cheb_polynomials[0].device
        # torch.FloatTensor allocates uninitialized memory; the training
        # script is expected to initialize these weights (e.g. xavier/uniform)
        self.Theta = nn.ParameterList(
            [nn.Parameter(torch.FloatTensor(in_channels, out_channels).to(self.DEVICE))
             for _ in range(K)])

    def forward(self, x, spatial_attention):
        batch_size, num_of_vertices, in_channels, num_of_timesteps = x.shape
        outputs = []
        for time_step in range(num_of_timesteps):
            # graph signal at one time step: (B, V, F)
            graph_signal = x[:, :, :, time_step]
            # accumulator over the K Chebyshev terms: (B, V, out_channels)
            output = torch.zeros(batch_size, num_of_vertices,
                                 self.out_channels).to(self.DEVICE)
            for k in range(self.K):
                # T_k: (V, V); broadcasts against spatial_attention (B, V, V)
                T_k = self.cheb_polynomials[k]
                T_k_with_at = T_k * spatial_attention
                theta_k = self.Theta[k]
                rhs = torch.matmul(T_k_with_at.permute(0, 2, 1), graph_signal)
                output = output + torch.matmul(rhs, theta_k)
            outputs.append(torch.unsqueeze(output, -1))
        return F.relu(torch.cat(outputs, dim=-1))
```
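For reference, cheb_polynomials is the list of the first K Chebyshev polynomials of the scaled graph Laplacian. A minimal sketch of how it can be built (the helper names scaled_laplacian / cheb_polynomial are my own, not from this repo; assumes a symmetric adjacency matrix):

```python
import numpy as np
import torch

def scaled_laplacian(adj):
    """~L = 2L / lambda_max - I, with L = D - A the combinatorial Laplacian."""
    D = np.diag(adj.sum(axis=1))
    L = D - adj
    lambda_max = np.linalg.eigvalsh(L).max()
    return 2 * L / lambda_max - np.identity(adj.shape[0])

def cheb_polynomial(L_tilde, K):
    """First K Chebyshev polynomials: T_0 = I, T_1 = ~L, T_k = 2 ~L T_{k-1} - T_{k-2}."""
    N = L_tilde.shape[0]
    polynomials = [np.identity(N), L_tilde.copy()]
    for _ in range(2, K):
        polynomials.append(2 * L_tilde @ polynomials[-1] - polynomials[-2])
    return polynomials[:K]

# e.g. K = 3 with a (307, 307) adjacency matrix adj_mx:
# cheb_polynomials = [torch.from_numpy(p).float() for p in
#                     cheb_polynomial(scaled_laplacian(adj_mx), K=3)]
```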
ASTGCN
1. ASTGCN_block
| Variable | Name in the previous post | Type | Example | Purpose |
|---|---|---|---|---|
| num_for_prediction | nb_predict_step = num_for_predict | int | 12 | how many future steps to predict; with 12 steps the model predicts 1 hour ahead |
| backbone | | dict | | per-block configuration (see the next table) |
| num_of_vertices | | int | 307 | number of vertices |
| num_of_features | in_channels | int | | size of the feature axis, F_in |
| num_of_timesteps | len_input | int | 12 | size of the time axis, T |
| _device | device=DEVICE | str | cpu | compute device |
backbone: dict

| Key | Value | Meaning |
|---|---|---|
| K | 3 | order of the Chebyshev polynomials |
| num_of_chev_filters | 64 | number of output channels of cheb_conv |
| num_of_time_filters | 64 | number of output channels of time_conv |
| time_conv_strides | 2 | stride of time_conv along the time axis |
| cheb_polynomials | from adj_mx | Chebyshev polynomials built from the node adjacency matrix |
```python
class ASTGCN_block(nn.Module):
    def __init__(self, DEVICE, backbone,
                 num_of_vertices, num_of_features, num_of_timesteps):
        """
        Parameters
        ----------
        backbone: dict, should have 5 keys:
            "K",
            "num_of_chev_filters",
            "num_of_time_filters",
            "time_conv_strides",
            "cheb_polynomials"
        """
        super(ASTGCN_block, self).__init__()
        K = backbone['K']
        nb_chev_filters = backbone['num_of_chev_filters']
        nb_time_filters = backbone['num_of_time_filters']
        time_conv_strides = backbone['time_conv_strides']
        cheb_polynomials = backbone["cheb_polynomials"]
        self.SAt = Spatial_Attention_layer(num_of_vertices,
                                           num_of_features, num_of_timesteps)
        self.cheb_conv_SAt = cheb_conv_with_SAt(K=K,
                                                cheb_polynomials=cheb_polynomials,
                                                in_channels=num_of_features,
                                                out_channels=nb_chev_filters)
        self.TAt = Temporal_Attention_layer(num_of_vertices,
                                            num_of_features, num_of_timesteps)
        # convolution along the time axis; the stride shrinks T
        self.time_conv = nn.Conv2d(in_channels=nb_chev_filters,
                                   out_channels=nb_time_filters,
                                   kernel_size=(1, 3),
                                   stride=(1, time_conv_strides),
                                   padding=(0, 1))
        # 1x1 convolution so the residual branch matches the main branch
        self.residual_conv = nn.Conv2d(in_channels=num_of_features,
                                       out_channels=nb_time_filters,
                                       kernel_size=(1, 1),
                                       stride=(1, time_conv_strides))
        self.ln = nn.LayerNorm(nb_time_filters)

    def forward(self, x):
        batch_size, num_of_vertices, num_of_features, num_of_timesteps = x.shape
        # temporal attention matrix: (B, T, T)
        temporal_At = self.TAt(x)
        # reweight the time axis with the temporal attention
        x_TAt = torch.matmul(x.reshape(batch_size, -1, num_of_timesteps),
                             temporal_At) \
                     .reshape(batch_size, num_of_vertices,
                              num_of_features, num_of_timesteps)
        # spatial attention matrix: (B, N, N)
        spatial_At = self.SAt(x_TAt)
        # Chebyshev graph convolution modulated by the spatial attention
        spatial_gcn = self.cheb_conv_SAt(x, spatial_At)
        # temporal convolution; Conv2d expects (B, C, N, T)
        time_conv_output = (self.time_conv(spatial_gcn.permute(0, 2, 1, 3))
                            .permute(0, 2, 1, 3))
        # residual shortcut implemented with a 1x1 Conv2d
        x_residual = (self.residual_conv(x.permute(0, 2, 1, 3))
                      .permute(0, 2, 1, 3))
        relued = F.relu(x_residual + time_conv_output)
        # LayerNorm over the filter axis
        out = self.ln(relued.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
        # out: (B, N, nb_time_filters, T'), with T' = T reduced by the stride
        return out
```
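A shape check for one block (illustrative sizes; torch.eye stands in for the real Chebyshev polynomials, and with time_conv_strides=2 the time axis shrinks from 12 to 6):

```python
backbone0 = {"K": 3, "num_of_chev_filters": 64, "num_of_time_filters": 64,
             "time_conv_strides": 2,
             "cheb_polynomials": [torch.eye(307) for _ in range(3)]}  # placeholder T_k
block = ASTGCN_block("cpu", backbone0, num_of_vertices=307,
                     num_of_features=1, num_of_timesteps=12)
x = torch.randn(32, 307, 1, 12)   # (B, N, F_in, T)
print(block(x).shape)             # torch.Size([32, 307, 64, 6])
```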
2. ASTGCN_submodule
| Variable | Name in the previous post | Type | Example | Purpose |
|---|---|---|---|---|
| num_for_prediction | nb_predict_step = num_for_predict | int | 12 | how many future steps to predict; with 12 steps the model predicts 1 hour ahead |
| backbones | | list[dict] | | block configurations for this submodule |
| num_of_vertices | | int | 307 | number of vertices |
| num_of_features | in_channels | int | | size of the feature axis, F_in |
| num_of_timesteps | len_input | int | 12 | size of the time axis, T |
```python
class ASTGCN_submodule(nn.Module):
    def __init__(self, DEVICE, backbones,
                 num_of_vertices, num_of_features, num_of_timesteps,
                 num_for_prediction):
        super(ASTGCN_submodule, self).__init__()
        nb_time_filter = backbones[0]['num_of_time_filters']
        # Build the blocks, tracking how each block's stride shrinks the time
        # axis and how the feature axis becomes nb_time_filter after block 0.
        self.BlockList = nn.ModuleList()
        in_features, t = num_of_features, num_of_timesteps
        for backbone in backbones:
            self.BlockList.append(ASTGCN_block(DEVICE, backbone,
                                               num_of_vertices, in_features, t))
            in_features = backbone['num_of_time_filters']
            t //= backbone['time_conv_strides']
        # final_conv maps the t remaining time steps to num_for_prediction
        self.final_conv = nn.Conv2d(in_channels=t,
                                    out_channels=num_for_prediction,
                                    kernel_size=(1, nb_time_filter))
        self.DEVICE = DEVICE
        # learnable per-vertex / per-step scaling; nn.Parameter so it trains
        # and moves with .to(DEVICE)
        self.W = nn.Parameter(torch.randn(num_of_vertices, num_for_prediction))
        self.to(DEVICE)

    def forward(self, x):
        for block in self.BlockList:
            # (B, N, F_in, T) -> (B, N, F_out, T'), e.g. (32,307,1,12) -> (32,307,64,6)
            x = block(x)
        # (B, N, F_out, T') -> (B, T', N, F_out), convolve over the filter
        # axis, take the last slice, permute to (B, N, num_for_prediction)
        module_output = (self.final_conv(x.permute(0, 3, 1, 2))
                         [:, :, :, -1].permute(0, 2, 1))
        # element-wise product with W: (32,307,12) * (307,12) -> (32,307,12)
        return module_output * self.W
```
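And the same kind of shape check for a two-block submodule, reusing the illustrative backbone0 from the block example (dict(backbone0, time_conv_strides=1) copies the config with stride 1 for the second block):

```python
backbones = [dict(backbone0), dict(backbone0, time_conv_strides=1)]
sub = ASTGCN_submodule("cpu", backbones, num_of_vertices=307,
                       num_of_features=1, num_of_timesteps=12,
                       num_for_prediction=12)
print(sub(torch.randn(32, 307, 1, 12)).shape)   # torch.Size([32, 307, 12])
```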
3. ASTGCN
| Variable | Name in the previous post | Type | Example | Purpose |
|---|---|---|---|---|
| num_for_prediction | nb_predict_step = num_for_predict | int | 12 | how many future steps to predict; with 12 steps the model predicts 1 hour ahead |
| all_backbones | | list[list[dict]] | | one backbone list per submodule |
| num_of_vertices | | int | 307 | number of vertices |
| num_of_features | in_channels | int | | size of the feature axis, F_in |
| num_of_timesteps | len_input | int | 12 | size of the time axis, T |
| _device | device=DEVICE | str | cpu | compute device |
```python
# backbones for one submodule: two stacked ASTGCN blocks, the first with
# stride 2 on the time axis, the second with stride 1
backbone = [
    {
        "K": K,
        "num_of_chev_filters": 64,
        "num_of_time_filters": 64,
        "time_conv_strides": 2,
        "cheb_polynomials": cheb_polynomials
    },
    {
        "K": K,
        "num_of_chev_filters": 64,
        "num_of_time_filters": 64,
        "time_conv_strides": 1,
        "cheb_polynomials": cheb_polynomials
    }
]
```
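The full model takes one such list per time scale; in the simplest case the same two-block structure is repeated three times (illustrative):

```python
all_backbones = [backbone, backbone, backbone]   # "week", "day", "hour" submodules
```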
```python
class ASTGCN(nn.Module):
    def __init__(self, DEVICE, all_backbones,
                 num_of_vertices, num_of_features, num_of_timesteps,
                 num_for_prediction):
        """
        Parameters
        ----------
        num_for_prediction: int
            how many time steps will be forecast
        all_backbones: list[list[dict]]
            one backbone list per submodule, in "week", "day", "hour" order
        num_of_vertices: int
            the number of vertices in the graph
        num_of_features: int
            the number of features of each measurement
        num_of_timesteps: int
            the number of input time steps fed to each submodule
        """
        super(ASTGCN, self).__init__()
        if debug_on:  # debug_on: module-level flag assumed defined elsewhere
            print("ASTGCN model:")
            print("num for prediction: ", num_for_prediction)
            print("num of vertices: ", num_of_vertices)
            print("num of features: ", num_of_features)
            print("num of timesteps: ", num_of_timesteps)
        # one submodule per time scale (week / day / hour)
        self.submodules = nn.ModuleList(
            [ASTGCN_submodule(DEVICE, backbones,
                              num_of_vertices, num_of_features,
                              num_of_timesteps, num_for_prediction)
             for backbones in all_backbones])

    def forward(self, x_list):
        """
        Parameters
        ----------
        x_list: list[torch.Tensor], one input per time scale;
            each element has shape (batch_size, num_of_vertices,
            num_of_features, num_of_timesteps)

        Returns
        ----------
        Y_hat: torch.Tensor,
            shape is (batch_size, num_of_vertices, num_for_prediction)
        """
        if debug_on:
            for x in x_list:
                print('Shape of input to the model:', x.shape)
        # the input list must match the number of submodules
        if len(x_list) != len(self.submodules):
            raise ValueError("num of submodule not equals to "
                             "length of the input list")
        # every input must agree on the number of vertices
        num_of_vertices_set = {i.shape[1] for i in x_list}
        if len(num_of_vertices_set) != 1:
            raise ValueError("Different num_of_vertices detected! "
                             "Check if your input data have same size "
                             "at axis 1.")
        # every input must agree on the batch size
        batch_size_set = {i.shape[0] for i in x_list}
        if len(batch_size_set) != 1:
            raise ValueError("Input values must have same batch size!")
        submodule_outputs = []
        for idx, submodule in enumerate(self.submodules):
            submodule_result = submodule(x_list[idx])
            # add a trailing axis so the per-scale outputs can be concatenated
            submodule_result = torch.unsqueeze(submodule_result, dim=-1)
            submodule_outputs.append(submodule_result)
        # (B, N, T_pred, n_submodules): fuse the three scales by summing
        submodule_outputs = torch.cat(submodule_outputs, dim=-1)
        out = torch.sum(submodule_outputs, dim=-1)
        return out
```
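Finally, a minimal end-to-end shape check (illustrative sizes; assumes K = 3, cheb_polynomials built as in the sketch above, and all_backbones as just defined):

```python
debug_on = False   # module-level flag referenced by ASTGCN.__init__

model = ASTGCN("cpu", all_backbones, num_of_vertices=307,
               num_of_features=1, num_of_timesteps=12,
               num_for_prediction=12)
# one (B, N, F, T) tensor per time scale: week, day, hour
x_list = [torch.randn(32, 307, 1, 12) for _ in range(3)]
print(model(x_list).shape)   # torch.Size([32, 307, 12])
```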