1. GitHub Code
The original implementation, as pulled from GitHub:
"""
Created on Mon Sep 28 10:28:06 2020
@author: wb
"""
import torch
import torch.nn as nn
from GCN_models import GCN
from One_hot_encoder import One_hot_encoder
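# Spatial self-attention: input is (N, T, C) with N = number of nodes (there is
# no batch dimension); each head attends across the N nodes independently at
# every time step. Note the scores are scaled by sqrt(embed_size) rather than
# the usual per-head sqrt(head_dim).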
class SSelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SSelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (
self.head_dim * heads == embed_size
), "Embedding size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query):
N, T, C = query.shape
values = values.reshape(N, T, self.heads, self.head_dim)
keys = keys.reshape(N, T, self.heads, self.head_dim)
query = query.reshape(N, T, self.heads, self.head_dim)
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(query)
energy = torch.einsum("qthd,kthd->qkth", [queries, keys])
attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=1)
out = torch.einsum("qkth,kthd->qthd", [attention, values]).reshape(
N, T, self.heads * self.head_dim
)
out = self.fc_out(out)
return out
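# Temporal self-attention: same (N, T, C) layout, but each head attends across
# the T time steps independently for every node (softmax over the key-time axis).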
class TSelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(TSelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (
self.head_dim * heads == embed_size
), "Embedding size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query):
N, T, C = query.shape
values = values.reshape(N, T, self.heads, self.head_dim)
keys = keys.reshape(N, T, self.heads, self.head_dim)
query = query.reshape(N, T, self.heads, self.head_dim)
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(query)
energy = torch.einsum("nqhd,nkhd->nqkh", [queries, keys])
attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=2)
out = torch.einsum("nqkh,nkhd->nqhd", [attention, values]).reshape(
N, T, self.heads * self.head_dim
)
out = self.fc_out(out)
return out
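# Spatial transformer block: adds a learned spatial embedding D_S derived from
# the adjacency matrix, runs spatial self-attention plus a feed-forward
# sublayer, and fuses the result with a per-time-step GCN branch X_G through a
# learned gate g.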
class STransformer(nn.Module):
def __init__(self, embed_size, heads, adj, dropout, forward_expansion):
super(STransformer, self).__init__()
self.adj = adj
self.D_S = nn.Parameter(adj)
self.embed_liner = nn.Linear(adj.shape[0], embed_size)
self.attention = SSelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size),
)
        self.gcn = GCN(embed_size, embed_size * 2, embed_size, dropout)
self.norm_adj = nn.InstanceNorm2d(1)
self.dropout = nn.Dropout(dropout)
self.fs = nn.Linear(embed_size, embed_size)
self.fg = nn.Linear(embed_size, embed_size)
def forward(self, value, key, query):
N, T, C = query.shape
D_S = self.embed_liner(self.D_S)
D_S = D_S.expand(T, N, C)
D_S = D_S.permute(1, 0, 2)
        X_G = torch.Tensor(query.shape[0], 0, query.shape[2])
        # Normalize the adjacency into a local variable; reassigning self.adj in
        # forward (as the original did) re-normalizes it on every call.
        adj = self.norm_adj(self.adj.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
        for t in range(query.shape[1]):
            o = self.gcn(query[:, t, :], adj)
            o = o.unsqueeze(1)
            X_G = torch.cat((X_G, o), dim=1)
        # Only the query receives the spatial embedding D_S here; value and key
        # are passed through unchanged.
        query = query + D_S
        attention = self.attention(value, key, query)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        U_S = self.dropout(self.norm2(forward + x))
        # Gated fusion of the transformer branch U_S and the GCN branch X_G
        g = torch.sigmoid(self.fs(U_S) + self.fg(X_G))
        out = g * U_S + (1 - g) * X_G
return out
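# Temporal transformer block: adds a learned temporal position embedding D_T,
# then applies temporal self-attention and a feed-forward sublayer.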
class TTransformer(nn.Module):
def __init__(self, embed_size, heads, time_num, dropout, forward_expansion):
super(TTransformer, self).__init__()
self.time_num = time_num
self.one_hot = One_hot_encoder(embed_size, time_num)
self.temporal_embedding = nn.Embedding(time_num, embed_size)
self.attention = TSelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size),
)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, t):
N, T, C = query.shape
        # D_T = self.one_hot(t, N, T)  # dead code: overwritten by the learned embedding below
        D_T = self.temporal_embedding(torch.arange(0, T))
D_T = D_T.expand(N, T, C)
query = query + D_T
attention = self.attention(value, key, query)
x = self.dropout(self.norm1(attention + query))
forward = self.feed_forward(x)
out = self.dropout(self.norm2(forward + x))
return out
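# One encoder layer: spatial transformer followed by temporal transformer, each
# wrapped in a residual connection and LayerNorm.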
class STTransformerBlock(nn.Module):
def __init__(self, embed_size, heads, adj, time_num, dropout, forward_expansion):
super(STTransformerBlock, self).__init__()
self.STransformer = STransformer(embed_size, heads, adj, dropout, forward_expansion)
self.TTransformer = TTransformer(embed_size, heads, time_num, dropout, forward_expansion)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, t):
x1 = self.norm1(self.STransformer(value, key, query) + query)
        x2 = self.dropout(self.norm2(self.TTransformer(x1, x1, x1, t) + x1))
return x2
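# Stacks num_layers ST-Transformer blocks; the same tensor is fed as value, key
# and query (pure self-attention).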
class Encoder(nn.Module):
def __init__(
self,
embed_size,
num_layers,
heads,
adj,
time_num,
device,
forward_expansion,
dropout,
):
super(Encoder, self).__init__()
self.embed_size = embed_size
self.device = device
self.layers = nn.ModuleList(
[
STTransformerBlock(
embed_size,
heads,
adj,
time_num,
dropout=dropout,
forward_expansion=forward_expansion
)
for _ in range(num_layers)
]
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, t):
out = self.dropout(x)
for layer in self.layers:
out = layer(out, out, out, t)
return out
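# Thin wrapper around the encoder; this is an encoder-only architecture with no
# decoder.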
class Transformer(nn.Module):
def __init__(
self,
adj,
embed_size=64,
num_layers=3,
heads=2,
time_num=288,
forward_expansion=4,
dropout=0,
device="cpu",
):
super(Transformer, self).__init__()
self.encoder = Encoder(
embed_size,
num_layers,
heads,
adj,
time_num,
device,
forward_expansion,
dropout,
)
self.device = device
def forward(self, src, t):
enc_src = self.encoder(src, t)
return enc_src
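# Full model: conv1 (1x1) lifts the input channels to embed_size, the
# Transformer encodes, conv2 maps the T_dim input steps to output_T_dim
# prediction steps, and conv3 collapses the embedding back to a single channel.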
class STTransformer(nn.Module):
def __init__(
self,
adj,
in_channels = 1,
embed_size = 64,
time_num = 288,
num_layers = 3,
T_dim = 12,
output_T_dim = 3,
heads = 2,
):
super(STTransformer, self).__init__()
self.conv1 = nn.Conv2d(in_channels, embed_size, 1)
self.Transformer = Transformer(
adj,
embed_size,
num_layers,
heads,
time_num
)
self.conv2 = nn.Conv2d(T_dim, output_T_dim, 1)
self.conv3 = nn.Conv2d(embed_size, 1, 1)
self.relu = nn.ReLU()
def forward(self, x, t):
x = x.unsqueeze(0)
input_Transformer = self.conv1(x)
input_Transformer = input_Transformer.squeeze(0)
input_Transformer = input_Transformer.permute(1, 2, 0)
output_Transformer = self.Transformer(input_Transformer, t)
output_Transformer = output_Transformer.permute(1, 0, 2)
output_Transformer = output_Transformer.unsqueeze(0)
out = self.relu(self.conv2(output_Transformer))
out = out.permute(0, 3, 2, 1)
out = self.conv3(out)
out = out.squeeze(0).squeeze(0)
return out
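Before moving on to my changes, a standalone sanity check (nothing assumed beyond PyTorch; all sizes below are made up) confirms that SSelfAttention's "qthd,kthd->qkth" contraction with a softmax over dim=1 really is ordinary scaled dot-product attention across the node axis, applied independently per time step and per head (up to the scaling constant, since the code above divides by sqrt(embed_size) instead of sqrt(head_dim)):

import torch

N, T, H, D = 4, 3, 2, 5                      # toy sizes: nodes, steps, heads, head_dim
q = torch.rand(N, T, H, D)
k = torch.rand(N, T, H, D)
v = torch.rand(N, T, H, D)

# Einsum form, as in SSelfAttention
energy = torch.einsum("qthd,kthd->qkth", q, k)
attention = torch.softmax(energy / D ** 0.5, dim=1)
out = torch.einsum("qkth,kthd->qthd", attention, v)

# Reference: plain matmul attention, looped over time steps and heads
ref = torch.empty_like(out)
for t in range(T):
    for h in range(H):
        a = torch.softmax(q[:, t, h] @ k[:, t, h].T / D ** 0.5, dim=-1)
        ref[:, t, h] = a @ v[:, t, h]

print(torch.allclose(out, ref))              # True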
2. My Code
My modified version. The shared classes are unchanged; the differences are in STransformer, TTransformer and STTransformerBlock, whose forward passes are rewritten in explicit X_S / D_S / D_T / M_S / U_S notation.
"""
Created on Mon Sep 28 10:28:06 2020
@author: wb
"""
import torch
import torch.nn as nn
from GCN_models import GCN
class SSelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SSelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (
self.head_dim * heads == embed_size
), "Embedding size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query):
N, T, C = query.shape
values = values.reshape(N, T, self.heads, self.head_dim)
keys = keys.reshape(N, T, self.heads, self.head_dim)
query = query.reshape(N, T, self.heads, self.head_dim)
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(query)
energy = torch.einsum("qthd,kthd->qkth", [queries, keys])
attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=1)
out = torch.einsum("qkth,kthd->qthd", [attention, values]).reshape(
N, T, self.heads * self.head_dim
)
out = self.fc_out(out)
return out
class TSelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(TSelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (
self.head_dim * heads == embed_size
), "Embedding size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query):
N, T, C = query.shape
values = values.reshape(N, T, self.heads, self.head_dim)
keys = keys.reshape(N, T, self.heads, self.head_dim)
query = query.reshape(N, T, self.heads, self.head_dim)
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(query)
energy = torch.einsum("nqhd,nkhd->nqkh", [queries, keys])
attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=2)
out = torch.einsum("nqkh,nkhd->nqhd", [attention, values]).reshape(
N, T, self.heads * self.head_dim
)
out = self.fc_out(out)
return out
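# Changed vs. the GitHub version: the spatial block now also receives a temporal
# embedding D_T (hence the extra time_num argument), and the same tensor
# X_tildeS = X_S + D_S + D_T is used as query, key and value alike.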
class STransformer(nn.Module):
def __init__(self, embed_size, heads, adj, dropout, forward_expansion, time_num):
super(STransformer, self).__init__()
self.adj = adj
self.D_S = nn.Parameter(adj)
self.embed_liner = nn.Linear(adj.shape[0], embed_size)
self.temporal_embedding = nn.Embedding(time_num, embed_size)
self.attention = SSelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size),
)
self.gcn = GCN(embed_size, embed_size * 2, embed_size, dropout)
self.norm_adj = nn.InstanceNorm2d(1)
self.dropout = nn.Dropout(dropout)
self.fs = nn.Linear(embed_size, embed_size)
self.fg = nn.Linear(embed_size, embed_size)
def forward(self, value, key, query):
X_S = query
N, T, C = query.shape
D_S = self.embed_liner(self.D_S)
D_S = D_S.expand(T, N, C)
D_S = D_S.permute(1, 0, 2)
D_T = self.temporal_embedding(torch.arange(0, T))
D_T = D_T.expand(N, T, C)
X_G = torch.Tensor(query.shape[0], 0, query.shape[2])
        # Normalize the adjacency into a local variable; reassigning self.adj in
        # forward (as the original did) re-normalizes it on every call.
        adj = self.norm_adj(self.adj.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
        for t in range(query.shape[1]):
            o = self.gcn(query[:, t, :], adj)
o = o.unsqueeze(1)
X_G = torch.cat((X_G, o), dim=1)
X_tildeS = X_S + D_S + D_T
query = key = value = X_tildeS
M_S = self.attention(value, key, query)
M_S = self.dropout(self.norm1(M_S + query))
        M_tildeS = X_tildeS + M_S
        forward = self.feed_forward(M_tildeS)
        U_S = self.dropout(self.norm2(forward + M_tildeS))
g = torch.sigmoid(self.fs(U_S) + self.fg(X_G))
out = g * U_S + (1 - g) * X_G
return out
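# Changed vs. the GitHub version: the dead one-hot encoding is gone, and the
# temporal embedding is folded into a single tensor X_tildeT that serves as
# value, key and query alike (the value/key/t arguments are effectively unused).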
class TTransformer(nn.Module):
def __init__(self, embed_size, heads, time_num, dropout, forward_expansion):
super(TTransformer, self).__init__()
self.time_num = time_num
self.temporal_embedding = nn.Embedding(time_num, embed_size)
self.attention = TSelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size),
)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, t):
X_T = query
N, T, C = query.shape
D_T = self.temporal_embedding(torch.arange(0, T))
D_T = D_T.expand(N, T, C)
X_tildeT = X_T + D_T
M_T = self.attention(X_tildeT, X_tildeT, X_tildeT)
M_tildeT = self.dropout(self.norm1(M_T + X_tildeT))
forward = self.feed_forward(M_tildeT)
U_T = self.dropout(self.norm2(forward + M_tildeT))
return U_T
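# Changed vs. the GitHub version: an extra residual X_T = Y_S + X_S feeds the
# temporal transformer, and the block ignores the value/key arguments in favor
# of the query.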
class STTransformerBlock(nn.Module):
def __init__(self, embed_size, heads, adj, time_num, dropout, forward_expansion):
super(STTransformerBlock, self).__init__()
self.STransformer = STransformer(embed_size, heads, adj, dropout, forward_expansion, time_num)
self.TTransformer = TTransformer(embed_size, heads, time_num, dropout, forward_expansion)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, t):
X_S = query
Y_S = self.norm1(self.STransformer(X_S, X_S, X_S) + X_S)
X_T = Y_S + X_S
Y_T = self.dropout(self.norm2(self.TTransformer(X_T, X_T, X_T, t) + X_T))
return Y_T
class Encoder(nn.Module):
def __init__(
self,
embed_size,
num_layers,
heads,
adj,
time_num,
device,
forward_expansion,
dropout,
):
super(Encoder, self).__init__()
self.embed_size = embed_size
self.device = device
self.layers = nn.ModuleList(
[
STTransformerBlock(
embed_size,
heads,
adj,
time_num,
dropout=dropout,
forward_expansion=forward_expansion
)
for _ in range(num_layers)
]
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, t):
out = self.dropout(x)
for layer in self.layers:
out = layer(out, out, out, t)
return out
class Transformer(nn.Module):
def __init__(
self,
adj,
embed_size=64,
num_layers=3,
heads=2,
time_num=288,
forward_expansion=4,
dropout=0,
device="cpu",
):
super(Transformer, self).__init__()
self.encoder = Encoder(
embed_size,
num_layers,
heads,
adj,
time_num,
device,
forward_expansion,
dropout,
)
self.device = device
def forward(self, src, t):
enc_src = self.encoder(src, t)
return enc_src
class STTransformer(nn.Module):
def __init__(
self,
adj,
in_channels=1,
embed_size=64,
time_num=288,
num_layers=3,
T_dim=12,
output_T_dim=3,
heads=2,
):
super(STTransformer, self).__init__()
self.conv1 = nn.Conv2d(in_channels, embed_size, 1)
self.Transformer = Transformer(
adj,
embed_size,
num_layers,
heads,
time_num
)
self.conv2 = nn.Conv2d(T_dim, output_T_dim, 1)
self.conv3 = nn.Conv2d(embed_size, 1, 1)
self.relu = nn.ReLU()
def forward(self, x, t):
x = x.unsqueeze(0)
input_Transformer = self.conv1(x)
input_Transformer = input_Transformer.squeeze(0)
input_Transformer = input_Transformer.permute(1, 2, 0)
output_Transformer = self.Transformer(input_Transformer, t)
output_Transformer = output_Transformer.permute(1, 0, 2)
output_Transformer = output_Transformer.unsqueeze(0)
out = self.relu(self.conv2(output_Transformer))
out = out.permute(0, 3, 2, 1)
out = self.conv3(out)
out = out.squeeze(0).squeeze(0)
return out
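Finally, a quick smoke test. Both files expose the same STTransformer interface, so this sketch works for either. It is a minimal sketch under several assumptions: the repo's GCN module follows the gcn(x, adj) call signature used above, a random matrix is an acceptable placeholder adjacency, and everything runs on CPU (several helper tensors are created on the CPU, so the code is not GPU-ready as-is). All sizes and the t value are made up.

import torch

N, T_in, T_out = 25, 12, 3                   # hypothetical: nodes, input steps, output steps
adj = torch.rand(N, N)                       # placeholder adjacency matrix
model = STTransformer(adj, in_channels=1, embed_size=64,
                      T_dim=T_in, output_T_dim=T_out)

x = torch.rand(1, N, T_in)                   # (in_channels, N, T); forward adds the batch dim
out = model(x, t=0)                          # t is effectively unused after the edits above
print(out.shape)                             # torch.Size([25, 3])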