STTN (Spatial-Temporal Transformer Network) Model Code

1. GitHub code

# -*- coding: utf-8 -*-
"""
Created on Mon Sep 28 10:28:06 2020

@author: wb
"""

import torch
import torch.nn as nn
from GCN_models import GCN
from One_hot_encoder import One_hot_encoder

class SSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SSelfAttention, self).__init__()
        self.embed_size = embed_size  # 64
        self.heads = heads  # 8
        self.head_dim = embed_size // heads  # 8

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"
            
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N, T, C = query.shape    # C here is embed_size

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, T, self.heads, self.head_dim)  # split embed_size into heads × head_dim
        keys   = keys.reshape(N, T, self.heads, self.head_dim)
        query  = query.reshape(N, T, self.heads, self.head_dim)

        values  = self.values(values)  # (N, T, heads, head_dim)
        keys    = self.keys(keys)      # (N, T, heads, head_dim)
        queries = self.queries(query)  # (N, T, heads, head_dim)

        # einsum computes the query-key dot products for every (query node, key node) pair in one call

        energy = torch.einsum("qthd,kthd->qkth", [queries, keys])   # 空间self-attention
        # queries shape: (N, T, heads, heads_dim),
        # keys shape: (N, T, heads, heads_dim)
        # energy: (N, N, T, heads)

        # Normalize energy values similarly to seq2seq + attention
        # so that they sum to 1. Also divide by scaling factor for
        # better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=1)  # softmax over the key (node) dim, so weights sum to 1
        # attention shape: (N, N, T, heads)

        out = torch.einsum("qkth,kthd->qthd", [attention, values]).reshape(
            N, T, self.heads * self.head_dim
        )        
        # attention shape: (N, N, T, heads)
        # values shape: (N, T, heads, head_dim)
        # out after matrix multiply: (N, T, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, T, embed_size)

        return out
    
class TSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(TSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N, T, C = query.shape

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, T, self.heads, self.head_dim)  # split embed_size into heads × head_dim
        keys   = keys.reshape(N, T, self.heads, self.head_dim)
        query  = query.reshape(N, T, self.heads, self.head_dim)

        values  = self.values(values)  # (N, T, heads, head_dim)
        keys    = self.keys(keys)      # (N, T, heads, head_dim)
        queries = self.queries(query)  # (N, T, heads, head_dim)

        # einsum computes the query-key dot products for every (query step, key step) pair in one call
        energy = torch.einsum("nqhd,nkhd->nqkh", [queries, keys])   # temporal self-attention over the time axis
        # queries shape: (N, T, heads, head_dim)
        # keys shape: (N, T, heads, head_dim)
        # energy: (N, T, T, heads)
        
        
        # Normalize energy values similarly to seq2seq + attention
        # so that they sum to 1. Also divide by scaling factor for
        # better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=2)  # softmax over the key (time) dim, so weights sum to 1
        # attention shape: (N, query_len, key_len, heads)

        out = torch.einsum("nqkh,nkhd->nqhd", [attention, values]).reshape(
                N, T, self.heads * self.head_dim
        )
        # attention shape: (N, T, T, heads)
        # values shape: (N, T, heads, head_dim)
        # out after matrix multiply: (N, T, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, T, embed_size)

        return out
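# Note on the two einsum patterns above:
#   SSelfAttention: "qthd,kthd->qkth" scores every pair of nodes (q, k) separately for each
#                   time step t and head h, so energy has shape (N, N, T, heads) and the
#                   softmax runs over dim=1 (the key-node axis).
#   TSelfAttention: "nqhd,nkhd->nqkh" scores every pair of time steps (q, k) separately for
#                   each node n and head h, so energy has shape (N, T, T, heads) and the
#                   softmax runs over dim=2 (the key-time axis).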
    
    
class STransformer(nn.Module):
    def __init__(self, embed_size, heads, adj, dropout, forward_expansion):
        super(STransformer, self).__init__()
        # Spatial Embedding
        self.adj = adj
        self.D_S = nn.Parameter(adj)
        self.embed_liner = nn.Linear(adj.shape[0], embed_size)
        
        self.attention = SSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        
        # GCN
        # input: embed_size;  hidden: embed_size*2;  output: embed_size
        self.gcn = GCN(embed_size, embed_size*2, embed_size, dropout)
        self.norm_adj = nn.InstanceNorm2d(1)    # normalize the adjacency matrix

        self.dropout = nn.Dropout(dropout)
        self.fs = nn.Linear(embed_size, embed_size)
        self.fg = nn.Linear(embed_size, embed_size)

    def forward(self, value, key, query):
                
        # Spatial embedding
        N, T, C = query.shape
        D_S = self.embed_liner(self.D_S)
        D_S = D_S.expand(T, N, C)
        D_S = D_S.permute(1, 0, 2)

        
        # GCN branch
        X_G = torch.Tensor(query.shape[0], 0, query.shape[2])
        # note: this re-normalizes self.adj in place on every forward call
        self.adj = self.adj.unsqueeze(0).unsqueeze(0)
        self.adj = self.norm_adj(self.adj)
        self.adj = self.adj.squeeze(0).squeeze(0)

        # run the GCN on the spatial features of every time step
        for t in range(query.shape[1]):
            o = self.gcn(query[ : , t,  : ],  self.adj)
            o = o.unsqueeze(1)              # shape [N, 1, C]
            X_G = torch.cat((X_G, o), dim=1)

        # Spatial Transformer: add the spatial embedding to the query (the original paper concatenates instead)
        query = query+D_S
        attention = self.attention(value, key, query)
        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        # residual connection followed by dropout to reduce overfitting
        U_S = self.dropout(self.norm2(forward + x))

        # fuse the STransformer and GCN branches with a learned gate
        g = torch.sigmoid( self.fs(U_S) +  self.fg(X_G) )      # (7)
        out = g*U_S + (1-g)*X_G                                # (8)
        return out
    
class TTransformer(nn.Module):
    def __init__(self, embed_size, heads, time_num, dropout, forward_expansion):
        super(TTransformer, self).__init__()
        
        # Temporal embedding
        self.time_num = time_num
        self.one_hot = One_hot_encoder(embed_size, time_num)          # temporal embedding option 1: one-hot encoding
        self.temporal_embedding = nn.Embedding(time_num, embed_size)  # temporal embedding option 2: nn.Embedding
        
        self.attention = TSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, t):
        N, T, C = query.shape  # e.g. (25, 12, 64)

        D_T = self.one_hot(t, N, T)                           # temporal embedding option 1: one-hot encoding
        # note: the line below overwrites the one-hot result; keep only one of the two options
        D_T = self.temporal_embedding(torch.arange(0, T))     # temporal embedding option 2: nn.Embedding, shape (T, C) = (12, 64)
        D_T = D_T.expand(N, T, C)  # (25, 12, 64)


        # add the temporal embedding to the query (the original paper concatenates instead)
        query = query + D_T  
        
        attention = self.attention(value, key, query)

        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out


class STTransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, adj, time_num, dropout, forward_expansion):
        super(STTransformerBlock, self).__init__()
        self.STransformer = STransformer(embed_size, heads, adj, dropout, forward_expansion)
        # time_num is passed on to TTransformer for its temporal embedding table
        self.TTransformer = TTransformer(embed_size, heads, time_num, dropout, forward_expansion)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, value, key, query, t):
        # x1: (25, 12, 64)
        # Post-LN Transformer arrangement
        x1 = self.norm1(self.STransformer(value, key, query) + query) # (25, 12, 64)
        x2 = self.dropout( self.norm2(self.TTransformer(x1, x1, x1, t) + x1) )
        return x2

class Encoder(nn.Module):
    # stack multiple ST-Transformer blocks
    def __init__(
        self,
        embed_size,
        num_layers,
        heads,
        adj,
        time_num,
        device,
        forward_expansion,
        dropout,
    ):

        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.layers = nn.ModuleList(
            [
                STTransformerBlock(
                    embed_size,
                    heads,
                    adj,
                    time_num,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, t):
        # x: input_transformer, shape [25, 12, 64]
        out = self.dropout(x)     # out = [25, 12, 64]
        # In the Encoder the query, key, value are all the same.

        for layer in self.layers:   # each layer is one STTransformerBlock
            # query = value = key = out
            out = layer(out, out, out, t)
        return out     
    
class Transformer(nn.Module):
    def __init__(
        self,
        adj,
        embed_size=64,
        num_layers=3,
        heads=2,
        time_num=288,
        forward_expansion=4,
        dropout=0,
        device="cpu",
    ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            embed_size,
            num_layers,
            heads,
            adj,
            time_num,
            device,
            forward_expansion,
            dropout,
        )
        self.device = device

    def forward(self, src, t):
        # src: input_transformer
        enc_src = self.encoder(src, t)
        return enc_src


class STTransformer(nn.Module):
    def __init__(
        self, 
        adj,
        in_channels = 1, 
        embed_size = 64, 
        time_num = 288,
        num_layers = 3,
        T_dim = 12,
        output_T_dim = 3,  
        heads = 2,        
    ):        
        super(STTransformer, self).__init__()
        # the first convolution expands the channel dimension to embed_size
        self.conv1 = nn.Conv2d(in_channels, embed_size, 1)  #  kernel_size = 1
        self.Transformer = Transformer(
            adj,
            embed_size, 
            num_layers, 
            heads, 
            time_num
        )
                
        # shrink the time dimension, e.g. from T_dim=12 input steps to output_T_dim=3 output steps
        self.conv2 = nn.Conv2d(T_dim, output_T_dim, 1)
        # shrink the channel dimension back to 1
        self.conv3 = nn.Conv2d(embed_size, 1, 1)
        self.relu = nn.ReLU()
    
    def forward(self, x, t):
        # input x shape [C, N, T] = [1, 25, 12]
        # C: number of channels,  N: number of sensors,  T: number of time steps
        x = x.unsqueeze(0)  # (1, 1, 25, 12)
        input_Transformer = self.conv1(x)  # (1, 64, 25, 12)
        input_Transformer = input_Transformer.squeeze(0) # (64, 25, 12) = (C, N, T)
        input_Transformer = input_Transformer.permute(1, 2, 0) # (25, 12, 64) = [N, T, C]

        # src = (25, 12, 64) = [N, T, C]
        output_Transformer = self.Transformer(input_Transformer, t) # (25, 12, 64)
        output_Transformer = output_Transformer.permute(1, 0, 2) # (12, 25, 64)
        #output_Transformer shape[T, N, C]
        
        output_Transformer = output_Transformer.unsqueeze(0)     # (1, 12, 25, 64)
        out = self.relu(self.conv2(output_Transformer))    # out shape: [1, output_T_dim, N, C] = [1, 3, 25, 64]
        out = out.permute(0, 3, 2, 1)           # out shape: [1, C, N, output_T_dim] = [1, 64, 25, 3]
        out = self.conv3(out)                   # out shape: [1, 1, N, output_T_dim] = [1, 1, 25, 3]
        out = out.squeeze(0).squeeze(0) # (25, 3)

        return out
        # return out shape: [N, output_dim]
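A quick way to sanity-check the shapes in the listing above is a single forward pass on random data. This is a minimal sketch, not part of the original repository: it assumes the listing is saved as STTN.py (a hypothetical filename), that the companion GCN_models.py and One_hot_encoder.py from the repo are importable, that t is an integer time-of-day index, and the sizes (25 sensors, 12 input steps, 3 output steps) simply follow the shape comments in the code.

import torch
from STTN import STTransformer     # assumes the listing above is saved as STTN.py

N, T_in, T_out = 25, 12, 3         # sensors, input steps, output steps (illustrative values)
adj = torch.rand(N, N)             # random adjacency matrix, only for a shape check
model = STTransformer(adj, in_channels=1, embed_size=64, T_dim=T_in, output_T_dim=T_out)

x = torch.rand(1, N, T_in)         # [C, N, T], as expected by forward()
out = model(x, 0)                  # t = 0: an integer time-of-day index (assumption)
print(out.shape)                   # expected: torch.Size([25, 3])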
    


    

    
    
    

2. My code

# -*- coding: utf-8 -*-
"""
Created on Mon Sep 28 10:28:06 2020

@author: wb
"""

import torch
import torch.nn as nn
from GCN_models import GCN
from One_hot_encoder import One_hot_encoder


class SSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SSelfAttention, self).__init__()
        self.embed_size = embed_size  # 64
        self.heads = heads  # 8
        self.head_dim = embed_size // heads  # 8

        assert (
                self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N, T, C = query.shape

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, T, self.heads, self.head_dim)  # split embed_size into heads × head_dim
        keys = keys.reshape(N, T, self.heads, self.head_dim)
        query = query.reshape(N, T, self.heads, self.head_dim)

        values = self.values(values)  # (N, T, heads, head_dim)
        keys = self.keys(keys)  # (N, T, heads, head_dim)
        queries = self.queries(query)  # (N, T, heads, head_dim)

        energy = torch.einsum("qthd,kthd->qkth", [queries, keys])  # 空间self-attention
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=1)  # 在K维做softmax,和为1

        out = torch.einsum("qkth,kthd->qthd", [attention, values]).reshape(
            N, T, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, T, embed_size)

        return out


class TSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(TSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
                self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N, T, C = query.shape

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, T, self.heads, self.head_dim)  # split embed_size into heads × head_dim
        keys = keys.reshape(N, T, self.heads, self.head_dim)
        query = query.reshape(N, T, self.heads, self.head_dim)

        values = self.values(values)  # (N, T, heads, head_dim)
        keys = self.keys(keys)  # (N, T, heads, head_dim)
        queries = self.queries(query)  # (N, T, heads, head_dim)

        # queries shape: (N, T, heads, head_dim)
        # keys shape: (N, T, heads, head_dim)
        # energy: (N, T, T, heads)
        energy = torch.einsum("nqhd,nkhd->nqkh", [queries, keys])  # temporal self-attention over the time axis
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=2)  # softmax over the key (time) dim, so weights sum to 1
        out = torch.einsum("nqkh,nkhd->nqhd", [attention, values]).reshape(
            N, T, self.heads * self.head_dim
        )
        out = self.fc_out(out)
        return out


class STransformer(nn.Module):
    def __init__(self, embed_size, heads, adj, dropout, forward_expansion, time_num):
        super(STransformer, self).__init__()
        # Spatial Embedding
        self.adj = adj
        self.D_S = nn.Parameter(adj)
        self.embed_liner = nn.Linear(adj.shape[0], embed_size)
        self.temporal_embedding = nn.Embedding(time_num, embed_size)  # temporal embedding via nn.Embedding

        self.attention = SSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        # GCN
        # input: embed_size;  hidden: embed_size*2;  output: embed_size
        self.gcn = GCN(embed_size, embed_size * 2, embed_size, dropout)
        self.norm_adj = nn.InstanceNorm2d(1)  # normalize the adjacency matrix

        self.dropout = nn.Dropout(dropout)
        self.fs = nn.Linear(embed_size, embed_size)
        self.fg = nn.Linear(embed_size, embed_size)

    def forward(self, value, key, query):
        X_S = query

        # Spatial embedding
        N, T, C = query.shape
        D_S = self.embed_liner(self.D_S)
        D_S = D_S.expand(T, N, C)
        D_S = D_S.permute(1, 0, 2)
        # Temporal embedding
        D_T = self.temporal_embedding(torch.arange(0, T))
        D_T = D_T.expand(N, T, C)  # (25, 12, 64)

        # GCN branch
        X_G = torch.Tensor(query.shape[0], 0, query.shape[2])
        self.adj = self.adj.unsqueeze(0).unsqueeze(0)
        self.adj = self.norm_adj(self.adj)
        self.adj = self.adj.squeeze(0).squeeze(0)

        # run the GCN on the spatial features of every time step
        for t in range(query.shape[1]):
            o = self.gcn(query[:, t, :], self.adj)
            o = o.unsqueeze(1)  # shape [N, 1, C]
            X_G = torch.cat((X_G, o), dim=1)

        # Spatial Transformer: add the spatial and temporal embeddings to the input (the original paper concatenates instead)
        X_tildeS = X_S + D_S + D_T

        # Dynamical graph conv layer (spatial self-attention)
        query = key = value = X_tildeS
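        # note: the value/key/query arguments passed to forward() are overridden here;
        # attention runs on X_tildeS as query, key and value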
        M_S = self.attention(value, key, query)
        M_S = self.dropout(self.norm1(M_S + query))
        M_tilderS = X_tildeS + M_S
        # Add skip connection, run through normalization and finally dropout
        forward = self.feed_forward(M_tilderS)
        # residual connection followed by dropout to reduce overfitting
        U_S = self.dropout(self.norm2(forward + M_tilderS))

        # fuse the STransformer and GCN branches with a learned gate
        g = torch.sigmoid(self.fs(U_S) + self.fg(X_G))  # (11)
        # element-wise blend of the two branches
        out = g * U_S + (1 - g) * X_G  # (12)

        return out
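# Data flow of the modified STransformer.forward above, using the code's variable names
# (f_s, f_g are the linear layers self.fs and self.fg; Attn is SSelfAttention; FFN is feed_forward):
#   X_tildeS  = X_S + D_S + D_T
#   M_S       = Dropout(LayerNorm(Attn(X_tildeS) + X_tildeS))
#   M_tilderS = X_tildeS + M_S
#   U_S       = Dropout(LayerNorm(FFN(M_tilderS) + M_tilderS))
#   g         = sigmoid(f_s(U_S) + f_g(X_G))       # (11)
#   out       = g * U_S + (1 - g) * X_G            # (12)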


class TTransformer(nn.Module):
    def __init__(self, embed_size, heads, time_num, dropout, forward_expansion):
        super(TTransformer, self).__init__()

        # Temporal embedding
        self.time_num = time_num
        self.temporal_embedding = nn.Embedding(time_num, embed_size)  # temporal embedding via nn.Embedding

        self.attention = TSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, t):
        X_T = query

        N, T, C = query.shape  # 25, 12 ,64
        D_T = self.temporal_embedding(torch.arange(0, T))
        D_T = D_T.expand(N, T, C)  # (25, 12, 64)
        # add the temporal embedding
        X_tildeT = X_T + D_T
        M_T = self.attention(X_tildeT, X_tildeT, X_tildeT)

        # Add skip connection, run through normalization and finally dropout
        M_tildeT = self.dropout(self.norm1(M_T + X_tildeT))
        forward = self.feed_forward(M_tildeT)
        U_T = self.dropout(self.norm2(forward + M_tildeT))
        return U_T


class STTransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, adj, time_num, dropout, forward_expansion):
        super(STTransformerBlock, self).__init__()
        self.STransformer = STransformer(embed_size, heads, adj, dropout, forward_expansion, time_num)
        # time_num is passed to both sub-modules for their temporal embedding tables
        self.TTransformer = TTransformer(embed_size, heads, time_num, dropout, forward_expansion)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, t):
        X_S = query
        # Post-LN Transformer arrangement
        Y_S = self.norm1(self.STransformer(X_S, X_S, X_S) + X_S)
        X_T = Y_S + X_S
        Y_T = self.dropout(self.norm2(self.TTransformer(X_T, X_T, X_T, t) + X_T))
        return Y_T


class Encoder(nn.Module):
    # stack multiple ST-Transformer blocks
    def __init__(
            self,
            embed_size,
            num_layers,
            heads,
            adj,
            time_num,
            device,
            forward_expansion,
            dropout,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.layers = nn.ModuleList(
            [
                STTransformerBlock(
                    embed_size,
                    heads,
                    adj,
                    time_num,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, t):
        # x: input_transformer, shape [25, 12, 64]
        out = self.dropout(x)  # out = [25, 12, 64]
        # In the Encoder the query, key, value are all the same.

        for layer in self.layers:  # each layer is one STTransformerBlock
            # query = value = key = out
            out = layer(out, out, out, t)
        return out


class Transformer(nn.Module):
    def __init__(
            self,
            adj,
            embed_size=64,
            num_layers=3,
            heads=2,
            time_num=288,
            forward_expansion=4,
            dropout=0,
            device="cpu",
    ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            embed_size,
            num_layers,
            heads,
            adj,
            time_num,
            device,
            forward_expansion,
            dropout,
        )
        self.device = device

    def forward(self, src, t):
        # src: input_transformer
        enc_src = self.encoder(src, t)
        return enc_src


class STTransformer(nn.Module):
    def __init__(
            self,
            adj,
            in_channels=1,
            embed_size=64,
            time_num=288,
            num_layers=3,
            T_dim=12,
            output_T_dim=3,
            heads=2,
    ):
        super(STTransformer, self).__init__()
        # the first convolution expands the channel dimension to embed_size
        self.conv1 = nn.Conv2d(in_channels, embed_size, 1)  # kernel_size = 1
        self.Transformer = Transformer(
            adj,
            embed_size,
            num_layers,
            heads,
            time_num
        )

        # shrink the time dimension, e.g. from T_dim=12 input steps to output_T_dim=3 output steps
        self.conv2 = nn.Conv2d(T_dim, output_T_dim, 1)
        # shrink the channel dimension back to 1
        self.conv3 = nn.Conv2d(embed_size, 1, 1)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # input x shape [C, N, T] = [1, 25, 12]
        # C: number of channels,  N: number of sensors,  T: number of time steps
        x = x.unsqueeze(0)  # (1, 1, 25, 12)
        input_Transformer = self.conv1(x)  # (1, 64, 25, 12)
        input_Transformer = input_Transformer.squeeze(0)  # (64, 25, 12) = (C, N, T)
        input_Transformer = input_Transformer.permute(1, 2, 0)  # (25, 12, 64) = [N, T, C]

        # src = (25, 12, 64) = [N, T, C]
        output_Transformer = self.Transformer(input_Transformer, t)  # (25, 12, 64)
        output_Transformer = output_Transformer.permute(1, 0, 2)  # (12, 25, 64)
        # output_Transformer shape[T, N, C]

        output_Transformer = output_Transformer.unsqueeze(0)  # (1, 12, 25, 64)
        out = self.relu(self.conv2(output_Transformer))  # out shape: [1, output_T_dim, N, C] = [1, 3, 25, 64]
        out = out.permute(0, 3, 2, 1)  # out shape: [1, C, N, output_T_dim] = [1, 64, 25, 3]
        out = self.conv3(out)  # out shape: [1, 1, N, output_T_dim] = [1, 1, 25, 3]
        out = out.squeeze(0).squeeze(0)  # (25, 3)

        return out
        # return out shape: [N, output_dim]
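The modified blocks can also be exercised in isolation. Below is a minimal sketch, not part of the original post: it assumes the listing above is saved as STTN_my.py (a hypothetical filename), that GCN_models.py from the repo is importable, and it reuses the illustrative sizes from the shape comments (N=25 nodes, T=12 steps, embed_size=64).

import torch
from STTN_my import STTransformerBlock   # assumes the listing above is saved as STTN_my.py

N, T, C = 25, 12, 64
adj = torch.rand(N, N)                    # random adjacency matrix, only for a shape check
block = STTransformerBlock(embed_size=C, heads=2, adj=adj, time_num=288,
                           dropout=0.1, forward_expansion=4)

x = torch.rand(N, T, C)                   # [N, T, C]
y = block(x, x, x, 0)                     # t is unused by the modified TTransformer but kept in the signature
print(y.shape)                            # expected: torch.Size([25, 12, 64])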