Improved Model: 2-SE-Inception Stacked LSTM — a New Network Architecture, Explained and Implemented

Improvement 1: add a suitable channel attention mechanism to each Inception-block branch

Within a given module (branch), it weights the different feature channels, dynamically assigning each channel its own importance; this is known as channel-wise attention. The core idea is to reweight the feature channels so that the important ones are exploited more fully, thereby improving model performance.
An attention module is added on top of each branch; at stage 20, for example, these are att_a20, att_b20, att_c20, and att_d20 (and likewise for the other stages). They all share the same structure: two 1x1 convolutions with a ReLU in between, followed by a Sigmoid.


Concretely:

self.att_??? = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),  # squeeze: 64 -> 16 channels
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),  # restore: 16 -> 64 channels
            nn.Sigmoid()  # per-channel weights in [0, 1]
        )

It works in the following steps:

Apply one convolution to the input features (nn.Conv1d(64, 16, kernel_size=1, stride=1)), where 64 is the number of input channels, 16 the number of output channels, kernel_size=1 a kernel of size 1, and stride=1 a step of 1; padding defaults to 0, i.e. no padding.
Apply a ReLU activation to the convolution output.
Apply a second convolution to restore the channel count from 16 back to 64 (nn.Conv1d(16, 64, kernel_size=1, stride=1)).
Finally, use a Sigmoid to bound the output to the range [0, 1], yielding the channel weights.
In short, the **att_???** module squeezes and then restores the channel dimension with two convolutions, bounds the resulting weights to [0, 1] with a Sigmoid, and thereby weights the input features along the channel dimension. The attention-weighted features are then passed on to the subsequent layers to improve the model's accuracy.
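
As a quick sanity check, here is a minimal sketch (the tensor shape is an illustrative assumption, not taken from the model) showing that such a module produces per-channel weights in [0, 1] of the same shape as its input:

import torch
import torch.nn as nn

att = nn.Sequential(
    nn.Conv1d(64, 16, kernel_size=1), nn.ReLU(),
    nn.Conv1d(16, 64, kernel_size=1), nn.Sigmoid()
)
feat = torch.randn(8, 64, 22)  # (batch, channels, length) -- illustrative shape
weights = att(feat)            # same shape as feat, values in [0, 1]
out = weights * feat           # channel-weighted features
print(out.shape)               # torch.Size([8, 64, 22])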

Combining this with one of our original Inception-block branches:

self.branch??? = nn.Sequential(???)

the combined, attention-weighted output is:

out_??? = self.att_???(self.branch_???(x)) * self.branch_???(x)
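
Note that written this way each branch runs twice per call. An optional refactor (a sketch with a hypothetical helper name, not part of the original code) evaluates the branch only once:

def attend(branch, att, x):
    feat = branch(x)         # run the branch only once
    return att(feat) * feat  # apply the channel weights

# usage: out_a10 = attend(self.branch_a10, self.att_a10, x)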

Apply the same change to every branch, and the first improvement is complete.

Improvement 2: stacking the channels, which can be done in either of two ways

(figures illustrating the two stacking schemes omitted)
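
In the scheme implemented below, the input is fed through two parallel attention-Inception towers (the modules with suffix 0 and suffix 1). Their stage-4 outputs are concatenated along the channel dimension and passed to two parallel LSTMs, whose outputs are then stacked (concatenated) along the feature dimension; a second pair of LSTMs is applied, and their two outputs are summed before the dense layer.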

The forward-pass code for the second scheme is as follows:

    def forward(self, x):
        out_a10 = self.att_a10(self.branch_a10(x)) * self.branch_a10(x)
        out_b10 = self.att_b10(self.branch_b10(x)) * self.branch_b10(x)
        out_c10 = self.att_c10(self.branch_c10(x)) * self.branch_c10(x)
        out_d10 = self.att_d10(self.branch_d10(x)) * self.branch_d10(x)
        y10 = torch.cat((out_a10, out_b10, out_c10, out_d10), dim=1)  # concatenate along the channel dimension
        out_a20 = self.att_a20(self.branch_a20(y10)) * self.branch_a20(y10)
        out_b20 = self.att_b20(self.branch_b20(y10)) * self.branch_b20(y10)
        out_c20 = self.att_c20(self.branch_c20(y10)) * self.branch_c20(y10)
        out_d20 = self.att_d20(self.branch_d20(y10)) * self.branch_d20(y10)
        y20 = torch.cat((out_a20, out_b20, out_c20, out_d20), dim=1)  # merge the outputs of the 4 branches
        out_a30 = self.att_a30(self.branch_a30(y20)) * self.branch_a30(y20)
        out_b30 = self.att_b30(self.branch_b30(y20)) * self.branch_b30(y20)
        out_c30 = self.att_c30(self.branch_c30(y20)) * self.branch_c30(y20)
        out_d30 = self.att_d30(self.branch_d30(y20)) * self.branch_d30(y20)
        y30 = torch.cat((out_a30, out_b30, out_c30, out_d30), dim=1)

        out_a40 = self.att_a40(self.branch_a40(y30)) * self.branch_a40(y30)
        out_b40 = self.att_b40(self.branch_b40(y30)) * self.branch_b40(y30)
        out_c40 = self.att_c40(self.branch_c40(y30)) * self.branch_c40(y30)
        out_d40 = self.att_d40(self.branch_d40(y30)) * self.branch_d40(y30)
        y40 = torch.cat((out_a40, out_b40, out_c40, out_d40), dim=1)

        out_a11 = self.att_a11(self.branch_a11(x)) * self.branch_a11(x)
        out_b11 = self.att_b11(self.branch_b11(x)) * self.branch_b11(x)
        out_c11 = self.att_c11(self.branch_c11(x)) * self.branch_c11(x)
        out_d11 = self.att_d11(self.branch_d11(x)) * self.branch_d11(x)
        y11 = torch.cat((out_a11, out_b11, out_c11, out_d11), dim=1)
        out_a21 = self.att_a21(self.branch_a21(y11)) * self.branch_a21(y11)
        out_b21 = self.att_b21(self.branch_b21(y11)) * self.branch_b21(y11)
        out_c21 = self.att_c21(self.branch_c21(y11)) * self.branch_c21(y11)
        out_d21 = self.att_d21(self.branch_d21(y11)) * self.branch_d21(y11)
        y21 = torch.cat((out_a21, out_b21, out_c21, out_d21), dim=1)

        out_a31 = self.att_a31(self.branch_a31(y21)) * self.branch_a31(y21)
        out_b31 = self.att_b31(self.branch_b31(y21)) * self.branch_b31(y21)
        out_c31 = self.att_c31(self.branch_c31(y21)) * self.branch_c31(y21)
        out_d31 = self.att_d31(self.branch_d31(y21)) * self.branch_d31(y21)
        y31 = torch.cat((out_a31, out_b31, out_c31, out_d31), dim=1)

        out_a41 = self.att_a41(self.branch_a41(y31)) * self.branch_a41(y31)
        out_b41 = self.att_b41(self.branch_b41(y31)) * self.branch_b41(y31)
        out_c41 = self.att_c41(self.branch_c41(y31)) * self.branch_c41(y31)
        out_d41 = self.att_d41(self.branch_d41(y31)) * self.branch_d41(y31)
        y41 = torch.cat((out_a41, out_b41, out_c41, out_d41), dim=1)

        y401 = torch.cat((y40, y41), dim=1).permute(0, 2, 1)  # (B, C, L) -> (B, L, C) for the batch_first LSTMs
        # pdb.set_trace()
        y50, _ = self.lstm1(y401)
        y50 = self.dropout1(y50)
        # pdb.set_trace()
        y51, _ = self.lstm11(y401)
        y51 = self.dropout11(y51)
        y50 = y50.permute(0, 2, 1)
        y51 = y51.permute(0, 2, 1)
        # pdb.set_trace()
        y501 = torch.cat((y50, y51), dim=1)  # stack the two LSTM outputs along the feature dimension
        y501 = y501.permute(0, 2, 1)  # back to (B, L, C) for the second LSTM pair
        # pdb.set_trace()
        y60, _ = self.lstm2(y501)
        y60 = self.dropout2(y60)
        y61, _ = self.lstm21(y501)
        y61 = self.dropout21(y61)

        out = y60 + y61
        out = self.dense(out)
        return out

Key notes:

The convolutional stages produce tensors of shape (batch, channels, length), while the LSTMs (built with batch_first=True) expect (batch, length, features), so the tensor has to be permuted before it is first fed to an LSTM. During the stacking of the two LSTM modules, the outputs must be promptly permuted back again before concatenation, otherwise torch.cat would join them along the wrong dimension. Specifically, it comes down to these two lines:
y50 = y50.permute(0, 2, 1)
y51 = y51.permute(0, 2, 1)
You can also call pdb.set_trace() at any point to inspect the tensor shapes from the terminal.
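
For example (an illustrative session; the shapes correspond to the (16, 15, 64) test input used at the bottom of this post), re-enable the commented-out pdb.set_trace() after self.dropout1 and type at the prompt:

(Pdb) y401.shape
torch.Size([16, 2, 512])
(Pdb) y50.shape
torch.Size([16, 2, 128])
(Pdb) c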

The complete model code is as follows:

import torch
import torch.nn as nn
import pdb

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        self.branch_a10 = nn.Sequential(
            nn.Conv1d(15, 64, kernel_size=1, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_b10 = nn.Sequential(
            nn.Conv1d(15, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_c10 = nn.Sequential(
            nn.AvgPool1d(kernel_size=3, stride=3, padding=1),
            nn.Conv1d(15, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.branch_d10 = nn.Sequential(
            nn.Conv1d(15, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )

        self.att_a10 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_b10 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_c10 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_d10 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        # define the four branches of stage 20
        self.branch_a20 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=3, padding=0),
            nn.ReLU()
        )

        self.branch_b20 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_c20 = nn.Sequential(
            nn.AvgPool1d(kernel_size=3, stride=3, padding=1),
            nn.Conv1d(256, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.branch_d20 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )
        self.att_a20 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_b20 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_c20 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_d20 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.branch_a30 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_b30 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_c30 = nn.Sequential(
            nn.AvgPool1d(kernel_size=3, stride=2, padding=1),
            nn.Conv1d(256, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.branch_d30 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )
        self.att_a30 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_b30 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_c30 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_d30 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )


        self.branch_a40 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_b40 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_c40 = nn.Sequential(
            nn.AvgPool1d(kernel_size=3, stride=3, padding=1),
            nn.Conv1d(256, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.branch_d40 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )
        self.att_a40 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_b40 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_c40 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_d40 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )


        # y50
        self.lstm1 = nn.LSTM(input_size=512, hidden_size=128, num_layers=1, batch_first=True)
        self.dropout1 = nn.Dropout(p=0.5)
        self.lstm2 = nn.LSTM(input_size=256, hidden_size=128, num_layers=1, batch_first=True)
        self.dropout2 = nn.Dropout(p=0.5)

        self.branch_a11 = nn.Sequential(
            nn.Conv1d(15, 64, kernel_size=1, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_b11 = nn.Sequential(
            nn.Conv1d(15, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_c11 = nn.Sequential(
            nn.AvgPool1d(kernel_size=3, stride=3, padding=1),
            nn.Conv1d(15, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.branch_d11 = nn.Sequential(
            nn.Conv1d(15, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )
        self.att_a11 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_b11 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_c11 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_d11 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )
        # define the four branches of stage 21
        self.branch_a21 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=3, padding=0),
            nn.ReLU()
        )

        self.branch_b21 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_c21 = nn.Sequential(
            nn.AvgPool1d(kernel_size=3, stride=3, padding=1),
            nn.Conv1d(256, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.branch_d21 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )
        self.att_a21 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_b21 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_c21 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_d21 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.branch_a31 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_b31 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_c31 = nn.Sequential(
            nn.AvgPool1d(kernel_size=3, stride=2, padding=1),
            nn.Conv1d(256, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.branch_d31 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )
        self.att_a31 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_b31 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_c31 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_d31 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )


        self.branch_a41 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_b41 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )

        self.branch_c41 = nn.Sequential(
            nn.AvgPool1d(kernel_size=3, stride=3, padding=1),
            nn.Conv1d(256, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.branch_d41 = nn.Sequential(
            nn.Conv1d(256, 64, kernel_size=1, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU()
        )
        self.att_a41 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_b41 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_c41 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

        self.att_d41 = nn.Sequential(
            nn.Conv1d(64, 16, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv1d(16, 64, kernel_size=1, stride=1),
            nn.Sigmoid()
        )
        # y51
        self.lstm11 = nn.LSTM(input_size=512, hidden_size=128, num_layers=1, batch_first=True)
        self.dropout11 = nn.Dropout(p=0.5)
        self.lstm21 = nn.LSTM(input_size=256, hidden_size=128, num_layers=1, batch_first=True)
        self.dropout21 = nn.Dropout(p=0.5)

        self.dense = nn.Linear(in_features=128, out_features=64)

    def forward(self, x):
        out_a10 = self.att_a10(self.branch_a10(x)) * self.branch_a10(x)
        out_b10 = self.att_b10(self.branch_b10(x)) * self.branch_b10(x)
        out_c10 = self.att_c10(self.branch_c10(x)) * self.branch_c10(x)
        out_d10 = self.att_d10(self.branch_d10(x)) * self.branch_d10(x)
        y10 = torch.cat((out_a10, out_b10, out_c10, out_d10), dim=1)  # concatenate along the channel dimension
        out_a20 = self.att_a20(self.branch_a20(y10)) * self.branch_a20(y10)
        out_b20 = self.att_b20(self.branch_b20(y10)) * self.branch_b20(y10)
        out_c20 = self.att_c20(self.branch_c20(y10)) * self.branch_c20(y10)
        out_d20 = self.att_d20(self.branch_d20(y10)) * self.branch_d20(y10)
        y20 = torch.cat((out_a20, out_b20, out_c20, out_d20), dim=1)  # merge the outputs of the 4 branches
        out_a30 = self.att_a30(self.branch_a30(y20)) * self.branch_a30(y20)
        out_b30 = self.att_b30(self.branch_b30(y20)) * self.branch_b30(y20)
        out_c30 = self.att_c30(self.branch_c30(y20)) * self.branch_c30(y20)
        out_d30 = self.att_d30(self.branch_d30(y20)) * self.branch_d30(y20)
        y30 = torch.cat((out_a30, out_b30, out_c30, out_d30), dim=1)

        out_a40 = self.att_a40(self.branch_a40(y30)) * self.branch_a40(y30)
        out_b40 = self.att_b40(self.branch_b40(y30)) * self.branch_b40(y30)
        out_c40 = self.att_c40(self.branch_c40(y30)) * self.branch_c40(y30)
        out_d40 = self.att_d40(self.branch_d40(y30)) * self.branch_d40(y30)
        y40 = torch.cat((out_a40, out_b40, out_c40, out_d40), dim=1)

        out_a11 = self.att_a11(self.branch_a11(x)) * self.branch_a11(x)
        out_b11 = self.att_b11(self.branch_b11(x)) * self.branch_b11(x)
        out_c11 = self.att_c11(self.branch_c11(x)) * self.branch_c11(x)
        out_d11 = self.att_d11(self.branch_d11(x)) * self.branch_d11(x)
        y11 = torch.cat((out_a11, out_b11, out_c11, out_d11), dim=1)
        out_a21 = self.att_a21(self.branch_a21(y11)) * self.branch_a21(y11)
        out_b21 = self.att_b21(self.branch_b21(y11)) * self.branch_b21(y11)
        out_c21 = self.att_c21(self.branch_c21(y11)) * self.branch_c21(y11)
        out_d21 = self.att_d21(self.branch_d21(y11)) * self.branch_d21(y11)
        y21 = torch.cat((out_a21, out_b21, out_c21, out_d21), dim=1)

        out_a31 = self.att_a31(self.branch_a31(y21)) * self.branch_a31(y21)
        out_b31 = self.att_b31(self.branch_b31(y21)) * self.branch_b31(y21)
        out_c31 = self.att_c31(self.branch_c31(y21)) * self.branch_c31(y21)
        out_d31 = self.att_d31(self.branch_d31(y21)) * self.branch_d31(y21)
        y31 = torch.cat((out_a31, out_b31, out_c31, out_d31), dim=1)

        out_a41 = self.att_a41(self.branch_a41(y31)) * self.branch_a41(y31)
        out_b41 = self.att_b41(self.branch_b41(y31)) * self.branch_b41(y31)
        out_c41 = self.att_c41(self.branch_c41(y31)) * self.branch_c41(y31)
        out_d41 = self.att_d41(self.branch_d41(y31)) * self.branch_d41(y31)
        y41 = torch.cat((out_a41, out_b41, out_c41, out_d41), dim=1)

        y401 = torch.cat((y40, y41), dim=1).permute(0, 2, 1)  # (B, C, L) -> (B, L, C) for the batch_first LSTMs
        # pdb.set_trace()
        y50, _ = self.lstm1(y401)
        y50 = self.dropout1(y50)
        # pdb.set_trace()
        y51, _ = self.lstm11(y401)
        y51 = self.dropout11(y51)
        y50 = y50.permute(0, 2, 1)
        y51 = y51.permute(0, 2, 1)
        # pdb.set_trace()
        y501 = torch.cat((y50, y51), dim=1)  # stack the two LSTM outputs along the feature dimension
        y501 = y501.permute(0, 2, 1)  # back to (B, L, C) for the second LSTM pair
        # pdb.set_trace()
        y60, _ = self.lstm2(y501)
        y60 = self.dropout2(y60)
        y61, _ = self.lstm21(y501)
        y61 = self.dropout21(y61)

        out = y60 + y61
        out = self.dense(out)
        return out


if __name__ == '__main__':
    # define the input: (batch, channels, length)
    x = torch.randn((16, 15, 64))
    # create the model
    model = MyModel()
    # run a forward pass
    y = model(x)
    # print the output shape
    print(y.shape)
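
With this (16, 15, 64) input, the script should print torch.Size([16, 2, 64]): the four striding stages shrink the temporal length from 64 down to 2, and the final Linear layer maps the 128 LSTM features at each step to 64 outputs. If you also want a quick look at the model size, one extra line is enough:

print(sum(p.numel() for p in model.parameters()))  # total parameter count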


