57、通过EEG数据的SHAPE变化,揭开EEG-TCNet的黑匣子[看好了小子,我只教这一次]

之前在第18篇博客中对于EEG-TCNet这个处理EEG信号的sota模型进行了介绍,也给出了模型,目前也是全网对于EEG-TCNet浏览度最高的文章了,我觉得讲的已经很细致了,没想到还是有不少同学疑问,这也是全网缺少该模型pytorch代码的原因,因为pytorch中没有封装TCN模块,无法直接调用,而在Tensorflow中可直接调用,废话不多少,上菜:

EEG-TCNet模型图:

原论文EEG-TCNet结构图

模型结构分析:

1、BCI IV2a数据以4维数据输入,shape=(288,1,22,1000)

2、数据先经过一个完整的EEGNet结构(时间卷积+深度卷积+深度可分离),来处理这个4维数组

3、数据从EEGNet出来,进入到TCN块之前进行降维处理(TCN只能处理1维数组)

下面我们来看2a数据以(batch_size,1,22,1000)输入到EEG-TCNet中是如何改变shape的

我自己写的EEG-TCNet代码模型-结构图:(我自己画的,别盗图)

TCN块(膨胀因果卷积)分析:

代码编写:

Chomp1d(nn.Module):裁剪类

TemporalBlock(nn.Module):TCN主体类,调用Chomp1d(),在这个类使用的卷积是Conv1d

TemporalConvNet(nn.Module):调用TemporalBlock()TCN完全体类

TCNNet(nn.Module):调用TemporalConvNet(),降维,使得TCN完全体跑的通


讲上面这4个类,我要倒着讲,费点劲要:(为啥倒着讲?同学想想 0。0 )

input_data = batch_size,1,22,1000经过一个前3个block后,此时控制台输出shape = 32,8,1,31断点如下:

数据此时还是4维的,所以我们在这使用if来判断维度,给他降维度

1、Data = torch.rand(x.shape):生成一个空的和x的维度一致的张量数据,用来存储for循环TCN块裁剪的数据

2、空的张量数据也要送到GPU中,否则报错,因为此时X的数据都在GPU上

3、在x的第二维度channel = 1,进行for循环,通过self来调用类内的tcn_block对应的TCN方法,对x数据进行裁剪并提取数据,把这些数据(此时还是4维)送给张量data

4、x = data(乾坤大魔移!


tcn_block对应着咱们定义的TemporalConvNet() 完全体这个类,如下:

类里面调用了上面定义好的Chomp1d()这个裁剪的类

此时代码跑到了Chomp1d()里面,如下所示:

TCN之前的数据= 32,8,1,31

此时数据维度 = 32,8,40,这里代码自动的去掉了通道=1的维度,并+res这个对x下采样的数据

因为这里是for i in range(x.shape[2])的循环,此时i=0,x.shape[3] = 40,我们再进入下一个循环i=1看看

此时x.shape[3] = 49,所以就这样,在送到Chomp1d()进行裁剪时,x加上了res这个下采样特征数据,导致了x的数据量增加,我们规定了Chomp1d()中的chomp_size这个数值,只保留与原始数据总量相同的前chomp_size的这个数目,来最后送给Fc层做最后结果的输出

此时我们送给Fc的shape :

又变回原来的31个数据了,这事裁剪类的功劳!但此时前后的这个31数据是不同的,多了下采样的特征,所以TCNNet这个类实现了先降维再生维的神奇操作,使得代码流通,完事。

全部代码如下:

class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
       
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
       
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1,self.relu1, self.dropout1,
                                 self.conv2, self.chomp2,self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
    #   self.init_weights()

    # def init_weights(self):
    #     self.conv1.weight.data.normal_(0, 0.01)
    #     self.conv2.weight.data.normal_(0, 0.01)
    #     if self.downsample is not None:
    #         self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, padding=(kernel_size-1) * dilation_size, 
                                     dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

import numpy as np

class TCNNet(nn.Module):
     def __init__ (self,*args) -> None:
        super(TCNNet,self).__init__()
        if len(args) < 2:
            print('error')
            exit()
        else:
            num_inputs = args[0]
            num_channels = args[1]
            kernel_size = int(args[2][0])
            
        self.tcn_block =  TemporalConvNet(num_inputs,num_channels,kernel_size) 
        #self.tcn_block =  TemporalConvNet(num_inputs=self.F2,num_channels=[tcn_filters,tcn_filters],kernel_size=tcn_kernelSize) 
     def forward(self,x) :
        if len(x.shape) == 4:
            data = torch.rand(x.shape)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
            data = data.to(device)
 
            for i in range(x.shape[2]):
              
                data[:,:,i,:] = self.tcn_block(x[:,:,i,:])
            x = data
        else:
            x = torch.squeeze(x,dim=2) 
            x = self.tcn_block(x)
        
        return x 
  • 21
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

是馒头阿

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值