Paper Reading & STS-HGCN-AL Code Walkthrough (Part 1)

Paper: Spatio-Temporal-Spectral Hierarchical Graph Convolutional Network With Semisupervised Active Learning for Patient-Specific Seizure Prediction

Project repository: https://github.com/YangLibuaa/STSHGCN-AL

Paper page: Spatio-Temporal-Spectral Hierarchical Graph Convolutional Network With Semisupervised Active Learning for Patient-Specific Seizure Prediction | IEEE Journals & Magazine | IEEE Xplore

The GitHub project already includes the original paper, with figures at higher resolution than the IEEE version, so downloading the GitHub project directly is recommended.

I. Introduction

This paper presents a substantial body of work: a framework combining deep learning with active learning for seizure prediction, called STS-HGCN-AL. Its main contributions are threefold:

  1. A novel STS-HGCN-AL scheme for automatic seizure prediction, which handles inter-patient heterogeneity by inferring representative EEG sources and exploring their irregular spatio-temporal responses.
  2. Two graph-convolution variants: a) residual graph convolution (ResGCN) and b) rhythm attention units (RhythmAtt units).
  3. A semisupervised active learning strategy that adaptively infers the patient-specific optimal preictal interval.

Setting the active learning part aside, STS-HGCN itself consists of three components: ST-SENet, GATENet, and HGCN. Figure 1 shows the overall framework:

Figure 1: Overall framework of the paper

This post focuses on the first component, ST-SENet, working through the code to understand how it is built.

II. Code Walkthrough

Figure 2: ST-SENet

As shown in Figure 2, ST-SENet consists of four modules: Temporal embedding, Multi-level spectral analysis, Multi-scale temporal analysis, and Group convolution squeeze and excitation.

Note the top-left corner of Figure 2: the input to ST-SENet is a set of independent components (ICs), not the raw EEG signal. According to the paper, the EEG signal at each channel is produced by several mutually interfering sources, and independent component analysis (ICA) maps the EEG signals onto mutually independent ICs, each originating from local field activity in a specific cortical region, which removes the interference between sources. For example, for an EEG signal with N channels and T sample points, ICA decomposes it into N independent source signals of length T each, removing the coupling between sources without changing the signal size. The paper uses the fastICA method to obtain the ICs.
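To make this preprocessing step concrete, here is a minimal sketch using scikit-learn's FastICA. This is my own illustration; the repository does not ship the preprocessing code, so the exact settings here are assumptions:

import numpy as np
from sklearn.decomposition import FastICA

eeg = np.random.randn(19, 1280)              # N = 19 electrodes, T = 1280 samples (5 s at 256 Hz)
ica = FastICA(n_components=19, random_state=0)
ics = ica.fit_transform(eeg.T).T             # FastICA expects (samples, features)
print(ics.shape)                             # (19, 1280): same size as the input signal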

Figure 3: Temporal embedding

A. Temporal embedding

The temporal embedding module uses a set of temporal convolution layers to embed the raw ICs in the time domain.

Open ST_SENet.py in the project and let's dig in layer by layer.

First, the program entry point:

import torch
from torch.autograd import Variable  # legacy wrapper; a plain tensor also works in recent PyTorch

if __name__ == "__main__":
    x = Variable(torch.randn([128, 1, 19, 1280]))
    model = ST_SENet(1, 19, 256)
    output = model(x)
    print(output.size())

Here x is simulated data with dimensions (batch * channels * EEG_electrodes * sample_points). ST_SENet is the class that implements the paper's ST-SENet. Since we are currently discussing Temporal embedding, we only need the first few lines of ST_SENet:

class ST_SENet(nn.Module):
    def __init__(self, inc, chan_num, si, outc = 64, num_of_layer = 1):
        super(ST_SENet, self).__init__()  
        self.fi = math.floor(math.log2(si))
        self.embedding = Embedding_Block(Input_Layer, 
                                         Residual_Block, 
                                         num_of_layer = num_of_layer, 
                                         inc = inc, 
                                         outc = 4) 

Since the dataset used in the paper is sampled at 256 Hz, the 256 in ST_SENet(1, 19, 256) is presumably the sampling frequency, corresponding to the parameter si; chan_num should be the number of EEG channels, and since EEG is a two-dimensional signal, inc is simply set to 1.

self.fi corresponds to the parameter f in the paper, defined by 2^f = sampling frequency, so self.fi = 8.

The focus thus shifts to Embedding_Block. Here is its definition:

def Embedding_Block(input_block, Residual_Block, num_of_layer, inc, outc):
    layers = []
    layers.append(input_block(inc = inc))
    for i in range(0, num_of_layer):
        layers.append(Residual_Block(inc = int(math.pow(2, i)*outc), 
                                     outc = int(math.pow(2, i+1)*outc)))
    return nn.Sequential(*layers) 

It takes five arguments. The first two are layer classes; num_of_layer is the number of Residual_Blocks, set to 1 here. With inc = 1 and outc = 4, input_block receives inc = 1, and the Residual_Block receives inc = 4 and outc = 8.

The layers list ends up holding one Input_Layer object and one Residual_Block object:

[Input_Layer(
  (conv_input): Conv2d(1, 4, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
  (bn_input): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), Residual_Block(
  (conv_expand): Conv2d(4, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (conv1): Conv2d(4, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(8, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
  (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)]

Now to the implementations of Input_Layer and Residual_Block:

class Input_Layer(nn.Module):
    def __init__(self, inc):
        super(Input_Layer, self).__init__()
        self.conv_input = nn.Conv2d(in_channels = 1, 
                                    out_channels = 4, 
                                    kernel_size = (1, 3), 
                                    stride = 1, 
                                    padding = (0, 1), 
                                    bias = False)
        self.bn_input = nn.BatchNorm2d(4)
        self.initialize()

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain = 1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        output = self.bn_input(self.conv_input(x))
        return output

class Residual_Block(nn.Module): 
    def __init__(self, inc, outc):
        super(Residual_Block, self).__init__()
        
        if inc != outc:  # note: the original used `is not`, which tests identity rather than equality
          self.conv_expand = nn.Conv2d(in_channels = inc, 
                                       out_channels = outc, 
                                       kernel_size = 1, 
                                       stride = 1, 
                                       padding = 0,
                                       bias = False)
        else:
          self.conv_expand = None          
        self.conv1 = nn.Conv2d(in_channels = inc, 
                               out_channels = outc, 
                               kernel_size = (1, 3), 
                               stride = 1, 
                               padding = (0, 1),
                               bias = False)
        self.bn1 = nn.BatchNorm2d(outc)
        self.conv2 = nn.Conv2d(in_channels = outc, 
                               out_channels = outc, 
                               kernel_size = (1, 3), 
                               stride = 1, 
                               padding = (0, 1),
                               bias = False)
        self.bn2 = nn.BatchNorm2d(outc)
        self.initialize()

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain = 1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        
    def forward(self, x): 
        if self.conv_expand is not None:
          identity_data = self.conv_expand(x)
        else:
          identity_data = x
        output = self.bn1(self.conv1(x))
        output = self.conv2(output)
        output = self.bn2(torch.add(output,identity_data))
        return output 

Figure 4: Correspondence between the code and the paper

After all that unpacking, we can finally match the source code to the paper! Comparing the layers list with Figure 4:

Input_Layer consists of one convolution layer (conv_input) and one BN layer (bn_input); its kernels are 1*3, four in total, corresponding to ① in the figure;

conv_expand in Residual_Block has 1*1 kernels, eight in total, corresponding to ④ in the figure;

conv1 in Residual_Block has 1*3 kernels, eight in total, corresponding to ② in the figure;

conv2 in Residual_Block has 1*3 kernels, eight in total, corresponding to ③ in the figure;

finally, the forward method of Residual_Block adds ③ and ④ together.

Compare this with the paper's own wording: "Because convolution operators essentially equate to a lowpass filter [23], the temporal embedding block, that is, successive temporal convolution and batch normalization (BN) operations, is first adopted to infer a patient-specific optimal filter-band for the subsequent analysis." Temporal embedding is indeed built from temporal convolutions and BN layers.

Back in the forward part of ST_SENet, self.embedding carries out the steps above, and torch.cat concatenates the temporally embedded ICs with the raw ICs:

def forward(self, x):    
    # Temporal embedding
    embedding_x = self.embedding(x)
    # concat raw ICs and Temporal embedding ICs
    cat_x = torch.cat((embedding_x, x), 1)

Inspecting the tensors, cat_x has size torch.Size([128, 9, 19, 1280]) and x has size torch.Size([128, 1, 19, 1280]); the temporal embedding therefore produces 8 embedded ICs (determined by the number of kernels in Temporal embedding) of size 19*1280 each.
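As a quick sanity check, the same numbers can be reproduced with a few lines (my own snippet, assuming the Embedding_Block, Input_Layer, and Residual_Block definitions above are in scope):

import torch

emb = Embedding_Block(Input_Layer, Residual_Block, num_of_layer = 1, inc = 1, outc = 4)
x = torch.randn(128, 1, 19, 1280)
print(emb(x).size())                         # torch.Size([128, 8, 19, 1280])
print(torch.cat((emb(x), x), 1).size())      # torch.Size([128, 9, 19, 1280])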

B. Multi-level spectral analysis

Multi-level spectral analysis uses wavelet convolutions to decompose the signal level by level, extracting its different frequency-band components. I have not studied wavelet convolution in great depth, so the discussion here stays at the application level.

Figure 5: Multi-level spectral analysis

Multi-level spectral analysis takes the ICs obtained in part A as input. Its definition is in the init part of ST_SENet:

self.MultiLevel_Spectral = MultiLevel_Spectral(inc = 4*int(math.pow(2, num_of_layer))+inc)

Working out the expression, the inc passed to MultiLevel_Spectral is 9, matching the channel dimension of cat_x from the previous step.

Here is the implementation of MultiLevel_Spectral:

class MultiLevel_Spectral(nn.Module): 
    def __init__(self, inc, params_path='./scaling_filter.mat'):
        super(MultiLevel_Spectral, self).__init__()
        self.filter_length = io.loadmat(params_path)['Lo_D'].shape[1]
        self.conv = nn.Conv2d(in_channels = inc, 
                              out_channels = inc*2, 
                              kernel_size = (1, self.filter_length), 
                              stride = (1, 2), padding = 0, 
                              groups = inc, 
                              bias = False)        
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                f = io.loadmat(params_path)
                Lo_D, Hi_D = np.flip(f['Lo_D'], axis = 1).astype('float32'), np.flip(f['Hi_D'], axis = 1).astype('float32')
                m.weight.data = torch.from_numpy(np.concatenate((Lo_D, Hi_D), axis = 0)).unsqueeze(1).unsqueeze(1).repeat(inc, 1, 1, 1)            
                m.weight.requires_grad = False 
    
    def self_padding(self, x):
        # Circular padding along the time axis. Note that the last argument of
        # torch.cat is the dimension to concatenate along; filter_length//2-1
        # evaluates to 3 (the time axis) only because the db4 filter has length 8.
        return torch.cat((x[:, :, :, -(self.filter_length//2-1):], x, x[:, :, :, 0:(self.filter_length//2-1)]), (self.filter_length//2-1))
                           
    def forward(self, x): 
        out = self.conv(self.self_padding(x)) 
        return out[:, 0::2,:, :], out[:, 1::2, :, :]

This code implements the wavelet convolution. The file scaling_filter.mat stores the Lo_D and Hi_D coefficients of the db4 wavelet, which can be generated with the following MATLAB code:

[Lo_D,Hi_D,Lo_R,Hi_R]=wfilters('db4');

Lo_D and Hi_D are each 1*8 arrays, so self.filter_length = 8. The 18 convolution kernels of size 1*8 are assigned the Lo_D and Hi_D values in pairs, and the layer's weight.requires_grad is set to False, so the kernel weights are fixed and not trainable. This essentially reuses Conv2d to perform the wavelet transform, since the discrete wavelet transform is, at its core, a convolution of the input with the filter coefficients.
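For readers without MATLAB, the same decomposition filters can be obtained in Python with PyWavelets. This is an alternative I am suggesting, not something the repository uses; note the repository additionally flips the filters with np.flip before loading them into the Conv2d weights:

import pywt

w = pywt.Wavelet('db4')
Lo_D, Hi_D = w.dec_lo, w.dec_hi              # decomposition low-/high-pass filters
print(len(Lo_D), len(Hi_D))                  # 8 8 (hence self.filter_length = 8)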

self_padding pads the data (circularly, along the time axis) so that the output's time dimension is exactly half of the input's. This can be verified in debug mode.

Figure 6: Dimensions before and after the wavelet convolution

Because of the grouped convolution and the way the weights are laid out (Lo_D and Hi_D alternating within each group), the low- and high-frequency responses are interleaved along the channel dimension: the even-indexed channels (out[:, 0::2]) are the low-pass part and the odd-indexed channels (out[:, 1::2]) are the high-pass part. forward therefore returns two tensors, the low-frequency and high-frequency parts of the signal.
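A shape check of one decomposition level (again my own snippet, assuming MultiLevel_Spectral above is in scope and scaling_filter.mat is present):

import torch

wt = MultiLevel_Spectral(inc = 9)
low, high = wt(torch.randn(128, 9, 19, 1280))
print(low.size(), high.size())               # both torch.Size([128, 9, 19, 640])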

Back in the forward part of ST_SENet, we can confirm that Multi-level spectral analysis indeed takes the part-A ICs (cat_x) as input. Here is the decomposition step by step:

# Multi-level spectral analysis
for i in range(1, self.fi-2):
    if i <= self.fi-7:
        if i == 1:
            out, _ = self.MultiLevel_Spectral(cat_x)
        else:
            out, _ = self.MultiLevel_Spectral(out)
    elif i == self.fi-6:
        if self.fi >= 8:
            out, gamma = self.MultiLevel_Spectral(out)
        else:
            out, gamma = self.MultiLevel_Spectral(cat_x)
    elif i == self.fi-5:
        out, beta = self.MultiLevel_Spectral(out)
    elif i == self.fi-4:
        out, alpha = self.MultiLevel_Spectral(out)
    elif i == self.fi-3:
        delta, theta = self.MultiLevel_Spectral(out)

Figure 7: Dimensions of each frequency band after decomposition

Verified in debug mode, as shown in Figure 7: each band's channel dimension is 9, again matching the dimension of the ICs.
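As a side note, my own back-of-the-envelope reading of the loop at si = 256 Hz is that each level halves the remaining band, so the retained high-pass outputs roughly line up with the classical EEG rhythms (the paper defines the exact assignment):

si = 256
bands = ['discarded', 'gamma', 'beta', 'alpha', 'theta (low half: delta)']
for level, band in enumerate(bands, 1):
    hi = si / 2 ** (level + 1)
    print(f"level {level}: high-pass {hi:g}-{2*hi:g} Hz -> {band}")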

C. Multi-scale temporal analysis

Figure 8: Multi-scale temporal analysis

Multi-scale temporal analysis likewise takes the part-A ICs as input. Its definition is in the init part of ST_SENet:

self.MultiScale_Temporal_gamma = MultiScale_Temporal(pow(2, self.fi-3)//8, 4*int(math.pow(2, num_of_layer))+inc)
self.MultiScale_Temporal_beta = MultiScale_Temporal(pow(2, self.fi-3)//4, 4*int(math.pow(2, num_of_layer))+inc)
self.MultiScale_Temporal_alpha = MultiScale_Temporal(pow(2, self.fi-3)//2, 4*int(math.pow(2, num_of_layer))+inc)
self.MultiScale_Temporal_theta = MultiScale_Temporal(pow(2, self.fi-3), 4*int(math.pow(2, num_of_layer))+inc)
self.MultiScale_Temporal_delta = MultiScale_Temporal(pow(2, self.fi-3), 4*int(math.pow(2, num_of_layer))+inc)  

Here is the implementation of MultiScale_Temporal:

class MultiScale_Temporal(nn.Module):
    def __init__(self, kernel_size, inc):
        super(MultiScale_Temporal, self).__init__()
        self.conv = nn.Conv2d(in_channels = inc, 
                              out_channels = inc, 
                              kernel_size = (1, kernel_size), 
                              stride = (1, kernel_size), 
                              padding = (0, 0), 
                              bias = False)
        self.bn = nn.BatchNorm2d(inc) 
        self.elu = nn.ELU(inplace = True)
        self.initialize()

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain = 1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        output = self.elu(self.bn(self.conv(x)))
        return output 

MultiScale_Temporal consists of one convolution layer, one BN layer, and one ELU activation, consistent with the paper: "Thus, multiscale temporal analysis, that is, temporal convolution with trainable kernel parameters, BN and, exponential linear unit (ELU) operations, captures temporal embeddings of the dynamic ICs at different scales in a data-driven way."

Working out the constructor arguments:

MultiScale_Temporal_gamma: kernel_size = 4, inc = 9;
MultiScale_Temporal_beta: kernel_size = 8, inc = 9;
MultiScale_Temporal_alpha: kernel_size = 16, inc = 9;
MultiScale_Temporal_theta: kernel_size = 32, inc = 9;
MultiScale_Temporal_delta: kernel_size = 32, inc = 9.

The reason each band's MultiScale_Temporal layer gets a different kernel_size is to make its output size match the corresponding MultiLevel_Spectral output. Referring to Figure 6, the gamma output of the multi-level spectral analysis is 128*9*19*320, while cat_x is 128*9*19*1280. For the multi-scale temporal analysis to also output 128*9*19*320 in the gamma band, the kernel size must be 1*4 with stride 1*4. The same reasoning applies to the other bands.

Debug-mode inspection (Figure 9) confirms this reasoning; were the sizes mismatched, the subsequent cat operation would raise a dimension-mismatch error.

Figure 9: Verifying the output sizes of the multi-scale temporal analysis
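The same check can also be scripted instead of using the debugger (my snippet, assuming MultiScale_Temporal above is in scope):

import torch

mst_gamma = MultiScale_Temporal(kernel_size = 4, inc = 9)
print(mst_gamma(torch.randn(128, 9, 19, 1280)).size())   # torch.Size([128, 9, 19, 320])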

Now back in the forward part of ST_SENet: a cat operation concatenates the multi-level spectral analysis results (gamma, beta, alpha, theta, delta) with the multi-scale temporal analysis results along dimension 1, so (x1, x2, x3, x4, x5) come out as 128*18*19*320, 128*18*19*160, 128*18*19*80, 128*18*19*40, and 128*18*19*40 respectively.

x1 = torch.cat((self.MultiScale_Temporal_gamma(cat_x), gamma), 1)
x2 = torch.cat((self.MultiScale_Temporal_beta(cat_x), beta), 1)
x3 = torch.cat((self.MultiScale_Temporal_alpha(cat_x), alpha), 1)
x4 = torch.cat((self.MultiScale_Temporal_theta(cat_x), theta), 1)
x5 = torch.cat((self.MultiScale_Temporal_delta(cat_x), delta), 1)

Figure 10: Verifying the concatenation results

D. Group convolution squeeze and excitation

Figure 11: The gcSE module

We now reach the final component of ST-SENet, the group convolution squeeze-and-excitation (gcSE) module. Continuing with the code, the SENet definitions are found in the init part of ST_SENet:

self.gamma_x = SENet(inc = (4*int(math.pow(2, num_of_layer))+inc)*2, outc = outc//2, kernel_size = 7)
self.beta_x = SENet(inc = (4*int(math.pow(2, num_of_layer))+inc)*2, outc = outc//2, kernel_size = 7)
self.alpha_x = SENet(inc = (4*int(math.pow(2, num_of_layer))+inc)*2, outc = outc//2, kernel_size = 3)
self.theta_x = SENet(inc = (4*int(math.pow(2, num_of_layer))+inc)*2, outc = outc//2, kernel_size = 3)
self.delta_x = SENet(inc = (4*int(math.pow(2, num_of_layer))+inc)*2, outc = outc//2, kernel_size = 3)   

Again, a separate SENet is applied to each of the five frequency bands.

gamma_x receives inc = 18, outc = 32;
beta_x receives inc = 18, outc = 32;
alpha_x receives inc = 18, outc = 32;
theta_x receives inc = 18, outc = 32;
delta_x receives inc = 18, outc = 32.

Here inc = 18 matches the concatenated outputs from part C (Figure 10), outc = 32 matches the number of kernels in Figure 11, and the per-band kernel_size values also match Figure 11.

Here is the implementation of SENet:

class SENet(nn.Module):
    def __init__(self, inc, outc, kernel_size, reduction = 8):
        super(SENet, self).__init__()
        self.conv0 = nn.Conv2d(in_channels = inc, 
                               out_channels = outc, 
                               kernel_size = (1, kernel_size), 
                               stride = (1, 1), 
                               padding = (0, kernel_size//2), 
                               groups = 2, 
                               bias = False)
        self.bn0 = nn.BatchNorm2d(outc)        
        self.se0 = SELayer(outc, reduction)
        self.conv1 = nn.Conv2d(in_channels = outc, 
                               out_channels = outc, 
                               kernel_size = (1, kernel_size), 
                               stride = (1, 1), 
                               padding = (0, kernel_size//2), 
                               groups = 2, 
                               bias = False)
        self.bn1 = nn.BatchNorm2d(outc)        
        self.se1 = SELayer(outc, reduction)
        self.conv2 = nn.Conv2d(in_channels = outc, 
                               out_channels = outc, 
                               kernel_size = (1, kernel_size), 
                               stride = (1, 1), 
                               padding = (0, kernel_size//2), 
                               groups = 2, 
                               bias = False)
        self.bn2 = nn.BatchNorm2d(outc)        
        self.se2 = SELayer(outc, reduction)
        self.elu = nn.ELU(inplace = False)    
        self.initialize()

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain = 1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):            
        out = self.elu(self.se0(self.bn0(self.conv0(x))))
        out = self.elu(self.se1(self.bn1(self.conv1(out))))
        out = self.elu(self.se2(self.bn2(self.conv2(out))))
        return out

Comparing with Figure 11, there are three identical gcSE stages, each consisting in turn of a convolution layer, a BN layer, an SELayer, and an ELU activation. SELayer uses the original squeeze-and-excitation structure directly:

class SELayer(nn.Module):
    '''
    Original SE block, details refer to "Jie Hu et al.: Squeeze-and-Excitation Networks"
    '''
    def __init__(self, channel, reduction = 16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias = False),
                                nn.ELU(inplace  = True),
                                nn.Linear(channel // reduction, channel, bias = False),
                                nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

Its inner workings are not covered here. Back in the forward part of ST_SENet, the (x1, x2, x3, x4, x5) from part C, i.e., the concatenated multi-level spectral and multi-scale temporal results, each pass through the three gcSE stages, matching Figure 11:

x1 = self.gamma_x(x1)
x2 = self.beta_x(x2)
x3 = self.alpha_x(x3)
x4 = self.theta_x(x4)
x5 = self.delta_x(x5)

Figure 12: SENet output dimensions

Checking the SENet outputs, only the channel count changes, from 18 to 32; the other dimensions stay the same.
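A quick shape check of one branch (my snippet, assuming the SENet and SELayer definitions above are in scope):

import torch

se = SENet(inc = 18, outc = 32, kernel_size = 7)
print(se(torch.randn(128, 18, 19, 320)).size())          # torch.Size([128, 32, 19, 320])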

The last stage in Figure 11 is Temporal pooling, whose implementation is quite simple.

The init part:

self.average_pooling = nn.AdaptiveAvgPool2d((chan_num, 1))

The forward part:

x1 = self.average_pooling(x1)
x2 = self.average_pooling(x2)
x3 = self.average_pooling(x3)
x4 = self.average_pooling(x4)
x5 = self.average_pooling(x5) 

Adaptive average pooling reduces the input to (128, 32, 19, 1); note that the argument of nn.AdaptiveAvgPool2d is the output size, not a kernel size. In other words, it averages over the time dimension, as shown in Figure 13:

Figure 13: Output dimensions of the adaptive average pooling
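A one-liner illustrates the point that the argument is the output size (my own example):

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d((19, 1))                     # target output (H, W), not a kernel size
print(pool(torch.randn(128, 32, 19, 320)).size())        # torch.Size([128, 32, 19, 1])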

Comparing the code with Figure 11, (x1, x2, x3, x4, x5) still have to pass through one more convolution layer with 64 kernels of size 1*1, corresponding to the TConv+BN+ELU to the left of Temporal pooling in Figure 11. The code:

The init part:

self.reshapeA = nn.Sequential(nn.Conv2d(in_channels = outc//2, 
                                   out_channels = outc, 
                                   kernel_size = (1, 1), 
                                   stride = (1, 1), 
                                   padding = (0, 0), 
                                   groups = 1, 
                                   bias = False),
                              nn.BatchNorm2d(outc),
                              nn.ELU(inplace=False))
self.reshapeB = nn.Sequential(nn.Conv2d(in_channels = outc//2, 
                                   out_channels = outc, 
                                   kernel_size = (1, 1), 
                                   stride = (1, 1), 
                                   padding = (0, 0), 
                                   groups = 1, 
                                   bias = False),
                              nn.BatchNorm2d(outc),
                              nn.ELU(inplace = False))
self.reshapeD = nn.Sequential(nn.Conv2d(in_channels = outc//2, 
                                   out_channels = outc, 
                                   kernel_size = (1, 1), 
                                   stride = (1, 1), 
                                   padding = (0, 0), 
                                   groups = 1, 
                                   bias = False),
                              nn.BatchNorm2d(outc),
                              nn.ELU(inplace = False))
self.reshapeT = nn.Sequential(nn.Conv2d(in_channels = outc//2, 
                                   out_channels = outc, 
                                   kernel_size = (1, 1), 
                                   stride = (1, 1), 
                                   padding = (0, 0), 
                                   groups = 1, 
                                   bias = False),
                              nn.BatchNorm2d(outc),
                              nn.ELU(inplace = False))
self.reshapeG = nn.Sequential(nn.Conv2d(in_channels = outc//2, 
                                   out_channels = outc, 
                                   kernel_size = (1, 1), 
                                   stride = (1, 1), 
                                   padding = (0, 0), 
                                   groups = 1, 
                                   bias = False),
                              nn.BatchNorm2d(outc),
                              nn.ELU(inplace = False))

The forward part:

x1 = self.reshapeG(x1)
x2 = self.reshapeB(x2)
x3 = self.reshapeA(x3)
x4 = self.reshapeT(x4)
x5 = self.reshapeD(x5)

Verifying the output dimensions:

Figure 14: Verifying the final output dimensions

At this point (x1, x2, x3, x4, x5) are the U in Figure 11, i.e., the per-band ICs after ST-SENet. Each has size 19*64, corresponding to the paper's E*64, where E is the number of channels.

One more thing to note: the forward pass of ST_SENet finally concatenates (x1, x2, x3, x4, x5) and returns the result, so the return value has size torch.Size([128, 320, 1, 19]):

return torch.cat((x1, x2, x3, x4, x5), 1).permute(0, 1, 3, 2).contiguous()

A note on the contiguous() call: torch.cat already copies its inputs, so the return value is independent of (x1, x2, x3, x4, x5) in any case. What contiguous() actually provides is a contiguous memory layout: permute only rearranges the tensor's strides and in general returns a non-contiguous view, which operations such as view() cannot work on, so contiguous() materializes a contiguous copy when needed (and is a no-op if the tensor is already contiguous).
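A small example of the general behaviour (my own illustration; for the particular shapes above, which contain a size-1 dimension, the permuted tensor may already count as contiguous, in which case contiguous() does nothing):

import torch

t = torch.randn(2, 3, 4, 5).permute(0, 1, 3, 2)          # a view with rearranged strides
print(t.is_contiguous())                                  # False
print(t.contiguous().is_contiguous())                     # True: data copied into contiguous memory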

That concludes the tour of ST-SENet. The next post will explore the structures of the remaining modules (GATENet and HGCN).
