代码总结||深度学习||计算机视觉||如何在医学影像3DCNN中插入网上已有的注意力机制

前言

在做深度学习在医学影像上应用的时候,数据集往往是3D的,而网上很多公开的trick或者注意力机制都是2D实现的,因此带来了一些困难。

做法

在github(https://github.com/xmu-xiaoma666/External-Attention-pytorch)上可以看到很多现有即插即用的注意力模块。
git clone到本地后,关键就是怎么在已有的代码框架上插入这些注意力模块。
**step0:**实例化一个模型,利用torch.randn函数模拟输入。我这里输入的是一个batchsize为128,channel为1,x为24,y为24,z为20的图像。

if __name__ == '__main__':
    input=torch.randn(128,1, 24, 24, 20)
    net=DiscriNet()
    out=net(input)
    print(out.shape)

**step1:**知道要插的位置,一般在backbone的后面。在def forward(self, x)方法中,利用print(x.size())知道相应的输入尺寸。

    def forward(self, x):
        #print(x.size()) #torch.Size([128, 1, 20, 20, 16]) 128为batchsize 1为channel
        x = self.net(x)
        # print(x.size())
        # print(x.size()) #torch.Size([128, 128, 2, 2, 2])
        if self.is_fc:
            x = x.view(-1, 128 * 2 * 2 * 2)
            x = self.final(x)
        else:
            #print(self.final(x).size()) #torch.Size([128, 2, 1, 1, 1])
            #print(self.final(x).squeeze(4).size()) #torch.Size([128, 2, 1, 1])
            #print(self.final(x).squeeze(4).squeeze(3).squeeze(2).size())#torch.Size([128, 2])
            x = self.final(x).squeeze(4).squeeze(3).squeeze(2)
            #print(x.size()) #torch.Size([128, 2])
        return x

**step2:**根据已经封装好的模块,将它实例化后插入到forward方法中。

    def forward(self, x):
        from SelfAttention import ScaledDotProductAttention
        #print(x.size())
        x = self.net(x)
        #print(x.size()) #torch.Size([128, 128, 4, 4, 4]) 
        b = x.permute(0, 2, 3, 4, 1).reshape(128, -1, 128)
        sa = ScaledDotProductAttention(d_model=128, d_k=128, d_v=128, h=8)
        output = sa(b, b, b)
        #print(output.size())
        x = x.reshape(128, 4, 4, 4, 128).permute(0, 4, 1, 2, 3)
        # print(x.size())
        #print(x.equal(a))
        x = x.contiguous().view(-1, 128 * 4 * 4 * 4)
        #(x.size())
        x = self.final(x)
        return x

利用permute、reshape、view函数,就可以转换了。我这里用的是自注意力机制,因为其本身是从自然语言领域转换而来的,无论是2D还是3D它都会转换成1D的tensor。
BUG1 RuntimeError: Expected all tensors to be on the same device, but found at least two devices
一般是tensor一个在cpu,一个在gpu上报错。

device = torch.device('cuda:0')
emsa = EMSA(d_model=128, d_k=128, d_v=128, h=8, H=8, W=8, ratio=2, apply_transform=True).to(device)

让模型在gpu上运行。
BUG2代码可以运行,但运行到最后一个epoch报错
这是由于最后一个epoch的batchsize不到指定数量所导致的。我batchsize设置为128,最后一个epoch不到128,因此后面的转换就会报错。

for img_batch, label_batch in dataloader:
            if(len(img_batch)==self.batch_size):

加一个判断就好。
或者不要固定batchsize

    def forward(self, x):
        from SelfAttention import ScaledDotProductAttention
        #print(x.size())
        x = self.net(x)
        x_batchsize=x.size()[0]
        device = torch.device('cuda:0')
        b = x.permute(0, 2, 3, 4, 1).reshape(x_batchsize, -1, 128).to(device)
        sa = ScaledDotProductAttention(d_model=128, d_k=128, d_v=128, h=8).to(device)
        output = sa(b, b, b)
        #print(output.size())
        x = output.reshape(x_batchsize, 4, 4, 4, 128).permute(0, 4, 1, 2, 3)
        # print(x.size())
        #print(x.equal(a))
        x = x.contiguous().view(-1, 128 * 4 * 4 * 4)

        #(x.size())
        x = self.final(x)
        return x

------------------------------------------------------------------------更新-------------------------------------------------
后续可以考虑把注意力模块封装一下,然后相应的参数写成动态的,如

 x_batchsize=x.size()[0]
 x_channel = bot.size()[1]
  • 3
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值