PyTorch Feature Fusion

My first task after joining the lab was CNN image classification and, as a follow-up, feature fusion. Being a self-taught beginner in CV, I spent a great deal of time and effort on it, so there is plenty worth writing down here for later reference.

Task description: the dataset contains 117 samples derived from processed MRI images, in both 2D and 3D form. The 2D data, ordinary images, are slices taken from the three planes of the 3D volume, so three images correspond to the three views of one object. Step one is binary classification of the images from a single view; step two is fusing the deep features of the three planes and classifying with softmax regression.

Single-view classification is not recorded here; this post only covers feature fusion and its various problems.

Update, September 21, 2022: rereading what I wrote a few months ago, I can see it was simplistic and mediocre because my skills were lacking at the time. There is new material to add, but since months have passed I am leaving the original content unchanged and only appending, so the old part may contain mistakes. The additions are roughly two more fusion methods, some training tricks, and problems you may hit during training.

Old Content

Model

Simple concat: take the outputs of the last convolutional layer of each of the three feature extractors (the CNNs), flatten them, and concatenate them along the feature dimension. Since we fuse features of the same object seen from three views, three identical models are used: resnet18.

from torch.nn import functional as F
import torch.nn as nn
import torch
import torchvision

class MyEnsemble(nn.Module):
    def __init__(self, nb_classes=2):
        super(MyEnsemble, self).__init__()
        self.model1 = torchvision.models.resnet18(pretrained=False)
        self.model2 = torchvision.models.resnet18(pretrained=False)
        self.model3 = torchvision.models.resnet18(pretrained=False)
        # Allow all backbone parameters to be updated during training
        # (set requires_grad = False instead to freeze them)
        for param in self.model1.parameters():
            param.requires_grad = True
        for param in self.model2.parameters():
            param.requires_grad = True
        for param in self.model3.parameters():
            param.requires_grad = True

        # Replace each fc with Identity (512-d pooled features) and add a new classifier
        self.model1.fc = nn.Identity()
        self.model2.fc = nn.Identity()
        self.model3.fc = nn.Identity()
        self.classifier = nn.Linear(512*3, nb_classes)

    def forward(self, x1, x2, x3, mode=None):
        x1 = self.model1(x1)
        # x1 = x1[0]
        x1 = torch.flatten(x1, 1)

        x2 = self.model2(x2)
        # x2 = x2[0]
        # x2 = x2.view(x2.size(0), -1)
        x2 = torch.flatten(x2, 1)

        x3 = self.model3(x3)
        # x3 = x3[0]
        # x3 = x3.view(x3.size(0), -1)
        x3 = torch.flatten(x3, 1)

        # Concatenate the three 512-d feature vectors along the feature dimension
        x = torch.cat((x1, x2, x3), dim=1)
        if mode == 'val':
            # Save the fused features for the later t-SNE visualization
            torch.save(x, './result/feature.pt')
        x = self.classifier(torch.sigmoid(x))
        # NOTE: nn.CrossEntropyLoss (used in training) applies log-softmax
        # internally, so this extra softmax is redundant there
        x = F.softmax(x, dim=1)
        return x


    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # kaiming_normal_ expects a tensor, not a module
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

Then, when testing the output, I hit the error AttributeError: 'tuple' object has no attribute 'view'. The apparent cause was that the network's output after the pooling layer was a tuple; the printout is shown below (print-debugging really is great).
(figure: printed output showing the tuple-wrapped tensor)
What to do then is extract the tensor from the tuple; since the tuple holds a single element, a one-liner does it: x1 = x1[0]. Below is a before/after comparison of the extraction.
Looking into it more carefully, it turned out the real culprit was my own sloppiness: I had ended every statement in the forward method with a comma, and a trailing comma turns the assignment into a one-element tuple. No commas needed! I am writing this down as a warning to myself; that said, the tuple-extraction workaround also works.
(figure: before/after comparison)
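
As a quick illustration of the trailing-comma gotcha (a minimal standalone sketch):

import torch

x = torch.randn(2, 3),      # trailing comma: x is now the tuple (tensor,)
print(type(x))              # <class 'tuple'> -> .view()/.flatten() would fail
x = x[0]                    # extract the tensor from the one-element tuple
print(x.view(3, 2).shape)   # works again: torch.Size([3, 2])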
Beyond that, there were a few things I did not understand well at first.

Questions

for param in self.model1.parameters():
    param.requires_grad = True

self.model3.fc = nn.Identity()

(1) param.requires_grad = True allows the parameter to be updated, i.e. backpropagation may change it during training; setting it to False freezes it.
(2) self.model3.fc = nn.Identity(): I looked this up on the PyTorch forums, where the explanation given was that the module simply returns its input without doing anything and can be used, for example, to replace other layers. My understanding: when you need to swap out part of a network, nn.Identity() lets you keep the rest of the structure intact while replacing just the layer you care about, as in my case:
(3) Whether to add softmax after the last layer, and whether to put an activation after the fully connected part: softmax is monotonic, so it never changes the predicted class. With softmax present an activation can also be inserted before the classifier; in my runs ReLU made the model worse, while sigmoid made it perform well.

self.model1.fc = nn.Identity()
self.model2.fc = nn.Identity()
self.model3.fc = nn.Identity()
self.classifier = nn.Linear(512*3, nb_classes)

That is, keep the convolutional part, fuse the three deep features, and replace the original fully connected layer with a new one.
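
A minimal check of what nn.Identity() gives us here (a sketch; it just confirms each backbone now emits 512-d pooled features):

import torch
import torchvision

m = torchvision.models.resnet18(pretrained=False)
m.fc = torch.nn.Identity()                 # keep the backbone, drop the classifier
out = m(torch.randn(1, 3, 224, 224))
print(out.shape)                           # torch.Size([1, 512])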

Next, to check what my model actually consists of and whether it matches what I intended, I used the torchsummary.summary method. Full code:

Inspecting the output and the network structure

import torch
from my_model import MyEnsemble
from torchsummary import summary

def inspect_model():
    device = torch.device('cuda:0')
    # Three dummy views, batch size 10
    x1 = torch.randn(10, 3, 224, 224)
    x2 = torch.randn(10, 3, 224, 224)
    x3 = torch.randn(10, 3, 224, 224)
    model = MyEnsemble()
    outputs = model(x1, x2, x3)
    print(outputs)
    model.to(device)
    summary(model, [(3, 224, 224), (3, 224, 224), (3, 224, 224)])

if __name__ == '__main__':
    inspect_model()

The output is shown below; the complete network structure is visible:

tensor([[ 0.3332, -0.7907],
        [ 0.1526, -0.9038],
        [ 0.2006, -0.9861],
        [ 0.0889, -0.8982],
        [ 0.2047, -0.7573],
        [ 0.0432, -0.7982],
        [ 0.1041, -0.9219],
        [ 0.1470, -0.8201],
        [ 0.2360, -0.9237],
        [ 0.2444, -0.8210]], grad_fn=<AddmmBackward>)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19          [-1, 128, 28, 28]          73,728
      BatchNorm2d-20          [-1, 128, 28, 28]             256
             ReLU-21          [-1, 128, 28, 28]               0
           Conv2d-22          [-1, 128, 28, 28]         147,456
      BatchNorm2d-23          [-1, 128, 28, 28]             256
           Conv2d-24          [-1, 128, 28, 28]           8,192
      BatchNorm2d-25          [-1, 128, 28, 28]             256
             ReLU-26          [-1, 128, 28, 28]               0
       BasicBlock-27          [-1, 128, 28, 28]               0
           Conv2d-28          [-1, 128, 28, 28]         147,456
      BatchNorm2d-29          [-1, 128, 28, 28]             256
             ReLU-30          [-1, 128, 28, 28]               0
           Conv2d-31          [-1, 128, 28, 28]         147,456
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 256, 14, 14]         294,912
      BatchNorm2d-36          [-1, 256, 14, 14]             512
             ReLU-37          [-1, 256, 14, 14]               0
           Conv2d-38          [-1, 256, 14, 14]         589,824
      BatchNorm2d-39          [-1, 256, 14, 14]             512
           Conv2d-40          [-1, 256, 14, 14]          32,768
      BatchNorm2d-41          [-1, 256, 14, 14]             512
             ReLU-42          [-1, 256, 14, 14]               0
       BasicBlock-43          [-1, 256, 14, 14]               0
           Conv2d-44          [-1, 256, 14, 14]         589,824
      BatchNorm2d-45          [-1, 256, 14, 14]             512
             ReLU-46          [-1, 256, 14, 14]               0
           Conv2d-47          [-1, 256, 14, 14]         589,824
      BatchNorm2d-48          [-1, 256, 14, 14]             512
             ReLU-49          [-1, 256, 14, 14]               0
       BasicBlock-50          [-1, 256, 14, 14]               0
           Conv2d-51            [-1, 512, 7, 7]       1,179,648
      BatchNorm2d-52            [-1, 512, 7, 7]           1,024
             ReLU-53            [-1, 512, 7, 7]               0
           Conv2d-54            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-55            [-1, 512, 7, 7]           1,024
           Conv2d-56            [-1, 512, 7, 7]         131,072
      BatchNorm2d-57            [-1, 512, 7, 7]           1,024
             ReLU-58            [-1, 512, 7, 7]               0
       BasicBlock-59            [-1, 512, 7, 7]               0
           Conv2d-60            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-61            [-1, 512, 7, 7]           1,024
             ReLU-62            [-1, 512, 7, 7]               0
           Conv2d-63            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-64            [-1, 512, 7, 7]           1,024
             ReLU-65            [-1, 512, 7, 7]               0
       BasicBlock-66            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
         Identity-68                  [-1, 512]               0
           ResNet-69                  [-1, 512]               0
           Conv2d-70         [-1, 64, 112, 112]           9,408
      BatchNorm2d-71         [-1, 64, 112, 112]             128
             ReLU-72         [-1, 64, 112, 112]               0
        MaxPool2d-73           [-1, 64, 56, 56]               0
           Conv2d-74           [-1, 64, 56, 56]          36,864
      BatchNorm2d-75           [-1, 64, 56, 56]             128
             ReLU-76           [-1, 64, 56, 56]               0
           Conv2d-77           [-1, 64, 56, 56]          36,864
      BatchNorm2d-78           [-1, 64, 56, 56]             128
             ReLU-79           [-1, 64, 56, 56]               0
       BasicBlock-80           [-1, 64, 56, 56]               0
           Conv2d-81           [-1, 64, 56, 56]          36,864
      BatchNorm2d-82           [-1, 64, 56, 56]             128
             ReLU-83           [-1, 64, 56, 56]               0
           Conv2d-84           [-1, 64, 56, 56]          36,864
      BatchNorm2d-85           [-1, 64, 56, 56]             128
             ReLU-86           [-1, 64, 56, 56]               0
       BasicBlock-87           [-1, 64, 56, 56]               0
           Conv2d-88          [-1, 128, 28, 28]          73,728
      BatchNorm2d-89          [-1, 128, 28, 28]             256
             ReLU-90          [-1, 128, 28, 28]               0
           Conv2d-91          [-1, 128, 28, 28]         147,456
      BatchNorm2d-92          [-1, 128, 28, 28]             256
           Conv2d-93          [-1, 128, 28, 28]           8,192
      BatchNorm2d-94          [-1, 128, 28, 28]             256
             ReLU-95          [-1, 128, 28, 28]               0
       BasicBlock-96          [-1, 128, 28, 28]               0
           Conv2d-97          [-1, 128, 28, 28]         147,456
      BatchNorm2d-98          [-1, 128, 28, 28]             256
             ReLU-99          [-1, 128, 28, 28]               0
          Conv2d-100          [-1, 128, 28, 28]         147,456
     BatchNorm2d-101          [-1, 128, 28, 28]             256
            ReLU-102          [-1, 128, 28, 28]               0
      BasicBlock-103          [-1, 128, 28, 28]               0
          Conv2d-104          [-1, 256, 14, 14]         294,912
     BatchNorm2d-105          [-1, 256, 14, 14]             512
            ReLU-106          [-1, 256, 14, 14]               0
          Conv2d-107          [-1, 256, 14, 14]         589,824
     BatchNorm2d-108          [-1, 256, 14, 14]             512
          Conv2d-109          [-1, 256, 14, 14]          32,768
     BatchNorm2d-110          [-1, 256, 14, 14]             512
            ReLU-111          [-1, 256, 14, 14]               0
      BasicBlock-112          [-1, 256, 14, 14]               0
          Conv2d-113          [-1, 256, 14, 14]         589,824
     BatchNorm2d-114          [-1, 256, 14, 14]             512
            ReLU-115          [-1, 256, 14, 14]               0
          Conv2d-116          [-1, 256, 14, 14]         589,824
     BatchNorm2d-117          [-1, 256, 14, 14]             512
            ReLU-118          [-1, 256, 14, 14]               0
      BasicBlock-119          [-1, 256, 14, 14]               0
          Conv2d-120            [-1, 512, 7, 7]       1,179,648
     BatchNorm2d-121            [-1, 512, 7, 7]           1,024
            ReLU-122            [-1, 512, 7, 7]               0
          Conv2d-123            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-124            [-1, 512, 7, 7]           1,024
          Conv2d-125            [-1, 512, 7, 7]         131,072
     BatchNorm2d-126            [-1, 512, 7, 7]           1,024
            ReLU-127            [-1, 512, 7, 7]               0
      BasicBlock-128            [-1, 512, 7, 7]               0
          Conv2d-129            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-130            [-1, 512, 7, 7]           1,024
            ReLU-131            [-1, 512, 7, 7]               0
          Conv2d-132            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-133            [-1, 512, 7, 7]           1,024
            ReLU-134            [-1, 512, 7, 7]               0
      BasicBlock-135            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-136            [-1, 512, 1, 1]               0
        Identity-137                  [-1, 512]               0
          ResNet-138                  [-1, 512]               0
          Conv2d-139         [-1, 64, 112, 112]           9,408
     BatchNorm2d-140         [-1, 64, 112, 112]             128
            ReLU-141         [-1, 64, 112, 112]               0
       MaxPool2d-142           [-1, 64, 56, 56]               0
          Conv2d-143           [-1, 64, 56, 56]          36,864
     BatchNorm2d-144           [-1, 64, 56, 56]             128
            ReLU-145           [-1, 64, 56, 56]               0
          Conv2d-146           [-1, 64, 56, 56]          36,864
     BatchNorm2d-147           [-1, 64, 56, 56]             128
            ReLU-148           [-1, 64, 56, 56]               0
      BasicBlock-149           [-1, 64, 56, 56]               0
          Conv2d-150           [-1, 64, 56, 56]          36,864
     BatchNorm2d-151           [-1, 64, 56, 56]             128
            ReLU-152           [-1, 64, 56, 56]               0
          Conv2d-153           [-1, 64, 56, 56]          36,864
     BatchNorm2d-154           [-1, 64, 56, 56]             128
            ReLU-155           [-1, 64, 56, 56]               0
      BasicBlock-156           [-1, 64, 56, 56]               0
          Conv2d-157          [-1, 128, 28, 28]          73,728
     BatchNorm2d-158          [-1, 128, 28, 28]             256
            ReLU-159          [-1, 128, 28, 28]               0
          Conv2d-160          [-1, 128, 28, 28]         147,456
     BatchNorm2d-161          [-1, 128, 28, 28]             256
          Conv2d-162          [-1, 128, 28, 28]           8,192
     BatchNorm2d-163          [-1, 128, 28, 28]             256
            ReLU-164          [-1, 128, 28, 28]               0
      BasicBlock-165          [-1, 128, 28, 28]               0
          Conv2d-166          [-1, 128, 28, 28]         147,456
     BatchNorm2d-167          [-1, 128, 28, 28]             256
            ReLU-168          [-1, 128, 28, 28]               0
          Conv2d-169          [-1, 128, 28, 28]         147,456
     BatchNorm2d-170          [-1, 128, 28, 28]             256
            ReLU-171          [-1, 128, 28, 28]               0
      BasicBlock-172          [-1, 128, 28, 28]               0
          Conv2d-173          [-1, 256, 14, 14]         294,912
     BatchNorm2d-174          [-1, 256, 14, 14]             512
            ReLU-175          [-1, 256, 14, 14]               0
          Conv2d-176          [-1, 256, 14, 14]         589,824
     BatchNorm2d-177          [-1, 256, 14, 14]             512
          Conv2d-178          [-1, 256, 14, 14]          32,768
     BatchNorm2d-179          [-1, 256, 14, 14]             512
            ReLU-180          [-1, 256, 14, 14]               0
      BasicBlock-181          [-1, 256, 14, 14]               0
          Conv2d-182          [-1, 256, 14, 14]         589,824
     BatchNorm2d-183          [-1, 256, 14, 14]             512
            ReLU-184          [-1, 256, 14, 14]               0
          Conv2d-185          [-1, 256, 14, 14]         589,824
     BatchNorm2d-186          [-1, 256, 14, 14]             512
            ReLU-187          [-1, 256, 14, 14]               0
      BasicBlock-188          [-1, 256, 14, 14]               0
          Conv2d-189            [-1, 512, 7, 7]       1,179,648
     BatchNorm2d-190            [-1, 512, 7, 7]           1,024
            ReLU-191            [-1, 512, 7, 7]               0
          Conv2d-192            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-193            [-1, 512, 7, 7]           1,024
          Conv2d-194            [-1, 512, 7, 7]         131,072
     BatchNorm2d-195            [-1, 512, 7, 7]           1,024
            ReLU-196            [-1, 512, 7, 7]               0
      BasicBlock-197            [-1, 512, 7, 7]               0
          Conv2d-198            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-199            [-1, 512, 7, 7]           1,024
            ReLU-200            [-1, 512, 7, 7]               0
          Conv2d-201            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-202            [-1, 512, 7, 7]           1,024
            ReLU-203            [-1, 512, 7, 7]               0
      BasicBlock-204            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-205            [-1, 512, 1, 1]               0
        Identity-206                  [-1, 512]               0
          ResNet-207                  [-1, 512]               0
          Linear-208                    [-1, 2]           3,074
================================================================
Total params: 33,532,610
Trainable params: 33,532,610
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4096.00
Forward/backward pass size (MB): 188.38
Params size (MB): 127.92
Estimated Total Size (MB): 4412.30
----------------------------------------------------------------

Dataset

Since we fuse the same object across three views, the Dataset is prepared just like ordinary image classification, except that the three data sources must stay in the same order (so that each index retrieves images of the same object). One thing stood out: the image augmentations that work fine for ordinary CNNs were particularly unsuitable here. At first the model behaved badly; on the test set the accuracy would rise but fluctuate wildly, and the final AUC came out below 0.5. My guess is that once the data overfits, the AUC drops. After removing all the augmentations the model returned to normal, with accuracy rising and the AUC no longer abnormal, though late in training, once overfitting set in, the AUC started falling again. Both overfitting and excessive augmentation can hurt the model: PyTorch's augmentations are random, so the three views receive different transforms and become inconsistent with one another, and on this dataset the augmentations may also have deformed the images too severely.
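
If augmentation is wanted despite this, one way to keep the three views consistent is to sample the random parameters once per sample and apply the identical transform to all three images via torchvision.transforms.functional; a sketch under that assumption (not what I actually used):

import random
import torchvision.transforms.functional as TF

def synchronized_augment(img0, img1, img2):
    # Sample the randomness once, then apply identical ops to every view
    if random.random() < 0.5:
        img0, img1, img2 = TF.hflip(img0), TF.hflip(img1), TF.hflip(img2)
    angle = random.uniform(-45, 45)
    return (TF.rotate(img0, angle),
            TF.rotate(img1, angle),
            TF.rotate(img2, angle))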

from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np


class MyDataset(Dataset):
    def __init__(self, root_path, mode='train'):
        """
        Args:
            :param root_path: root directory of the dataset
            :param mode: which split to use, 'train' or 'val'
        """
        assert os.path.exists(root_path), "dataset root: {} does not exist.".format(root_path)
        self.axis_list = ['Axial', 'Coronal', 'Sagittal']
        self.mode = mode
        self.root_path = root_path

        self.dataset0 = os.path.join(self.root_path, self.axis_list[0], self.mode)
        self.dataset1 = os.path.join(self.root_path, self.axis_list[1], self.mode)
        self.dataset2 = os.path.join(self.root_path, self.axis_list[2], self.mode)
        self.img_classes = ['high', 'low']

        if mode == 'train':
            self.train_img0 = np.array([])
            self.train_img1 = np.array([])
            self.train_img2 = np.array([])
            self.train_label0 = np.array([])
            self.train_label1 = np.array([])
            self.train_label2 = np.array([])
            for cla in self.img_classes:
                cla_path0 = os.path.join(self.dataset0, cla)
                cla_path1 = os.path.join(self.dataset1, cla)
                cla_path2 = os.path.join(self.dataset2, cla)
                for img in os.listdir(cla_path0):
                    self.train_img0 = np.append(self.train_img0, os.path.join(cla_path0, img))
                    self.train_label0 = np.append(self.train_label0, cla)

                for img in os.listdir(cla_path1):
                    self.train_img1 = np.append(self.train_img1, os.path.join(cla_path1, img))
                    self.train_label1 = np.append(self.train_label1, cla)

                for img in os.listdir(cla_path2):
                    self.train_img2 = np.append(self.train_img2, os.path.join(cla_path2, img))
                    self.train_label2 = np.append(self.train_label2, cla)

            self.img_arr0 = self.train_img0
            self.img_arr1 = self.train_img1
            self.img_arr2 = self.train_img2
            self.label0 = self.train_label0
            self.label1 = self.train_label1
            self.label2 = self.train_label2


        if mode == 'val':
            self.val_img0 = np.array([])
            self.val_img1 = np.array([])
            self.val_img2 = np.array([])
            self.val_label0 = np.array([])
            self.val_label1 = np.array([])
            self.val_label2 = np.array([])
            for cla in self.img_classes:
                cla_path0 = os.path.join(self.dataset0, cla)
                cla_path1 = os.path.join(self.dataset1, cla)
                cla_path2 = os.path.join(self.dataset2, cla)
                for img in os.listdir(cla_path0):
                    self.val_img0 = np.append(self.val_img0, os.path.join(cla_path0, img))
                    self.val_label0 = np.append(self.val_label0, cla)

                for img in os.listdir(cla_path1):
                    self.val_img1 = np.append(self.val_img1, os.path.join(cla_path1, img))
                    self.val_label1 = np.append(self.val_label1, cla)

                for img in os.listdir(cla_path2):
                    self.val_img2 = np.append(self.val_img2, os.path.join(cla_path2, img))
                    self.val_label2 = np.append(self.val_label2, cla)

            self.img_arr0 = self.val_img0
            self.img_arr1 = self.val_img1
            self.img_arr2 = self.val_img2
            self.label0 = self.val_label0
            self.label1 = self.val_label1
            self.label2 = self.val_label2

    def __getitem__(self, item):
        img0 = Image.open(self.img_arr0[item])
        img1 = Image.open(self.img_arr1[item])
        img2 = Image.open(self.img_arr2[item])

        if self.mode == 'train':
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                # Augmentations left disabled: applied independently to each
                # view they de-synchronize the three planes (see note above)
                # transforms.RandomResizedCrop(224, scale=(0.9, 1.0), ratio=(0.9, 1.0)),
                # transforms.RandomHorizontalFlip(p=0.5),
                # transforms.ColorJitter(brightness=(0.7, 1.3), contrast=(0.7, 1.3)),
                # transforms.RandomRotation(degrees=45),
                transforms.ToTensor(),
                # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])

        if self.mode == 'val':
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
                # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        img0 = transform(img0)
        img1 = transform(img1)
        img2 = transform(img2)
        if self.label0[item] == 'high':
            label = 1
        elif self.label0[item] == 'low':
            label = 0

        return img0, img1, img2, label

    def __len__(self):
        return len(self.img_arr0)

Training

The training part is fairly conventional. Since the model needs a comprehensive evaluation, accuracy alone is not enough, so several other metrics are added. In practice, most of the time in this little project went into tuning and chasing problems, including the data and model issues above, and tuning was genuinely painful. The dataset is tiny and resnet is a deep model, so overfitting came very easily and the initial results were very poor. After much trial and error and searching online, I found that setting both the regularization strength (weight decay) and the learning rate quite large helped a lot: with small weight decay, a largish learning rate would overfit within basically one epoch, while a small learning rate brought essentially no improvement. A large weight decay worked well, but it must not be too large either, or the model stops improving.
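
A terminology caveat: SGD's weight_decay used below is L2 regularization. If an actual L1 penalty is wanted, it has to be added to the loss by hand; a minimal standalone sketch (the tiny Linear model and l1_lambda value are placeholders):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                     # stand-in for the real fusion model
criterion = nn.CrossEntropyLoss()
outputs = model(torch.randn(4, 10))
labels = torch.tensor([0, 1, 1, 0])

l1_lambda = 1e-4                             # hypothetical penalty strength
l1_penalty = sum(p.abs().sum() for p in model.parameters())
loss = criterion(outputs, labels) + l1_lambda * l1_penalty
loss.backward()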

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from sklearn.metrics import roc_curve, auc, confusion_matrix
from Draw_Curve import acc_loss_curve
from my_model import MyEnsemble
from torchsummary import summary
import numpy as np
from torch.utils.data import DataLoader
from mydataset import MyDataset


def train_step():
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    print(f'using {device} .\n')
    train_dataset = MyDataset(root_path=r'G:\WPJ\Pythonproject\HCC_117\dataset\PS', mode='train')
    train_loader = DataLoader(train_dataset,
                              batch_size=64,
                              shuffle=True,
                              num_workers=0)
    val_dataset = MyDataset(root_path=r'G:\WPJ\Pythonproject\HCC_117\dataset\PS', mode='val')
    val_loader = DataLoader(val_dataset,
                            batch_size=64,
                            shuffle=False,
                            num_workers=0)
    loss_function: CrossEntropyLoss = nn.CrossEntropyLoss()
    model = MyEnsemble()
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=3e-4, weight_decay=2e-1)
    # DMQ: lr=3e-4, weight_decay=2e-1; with less augmentation lr=4e-4, weight_decay=2e-1
    # MMQ: lr=3e-4, weight_decay=1e-2 or lr=6e-4, weight_decay=2e-1
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=4,
                                                           threshold=1e-5, threshold_mode='rel', min_lr=2e-6,
                                                           verbose=True)
    # Other (lr, weight_decay) pairs tried: (2e-6, 2e-2), (4e-6, 1e-2), (1e-6, 1e-2), (2e-6, 9e-3)
    # Best result: 2e-6 + 9e-3 with augmentation: horizontal flip p=0.5, brightness and contrast
    epochs = 200
    multi = 0.
    best_acc = 0.
    best_loss = 2.0
    iterition = 0
    acc_list = []
    loss_list = []
    for epoch in range(epochs):
        train_total, train_correct, train_loss = 0, 0., 0.
        train_all_label, train_all_pred, train_all_prescore = np.array([]), np.array([]), np.array([])
        for step, batch in enumerate(train_loader):
            model.train()
            train_img0, train_img1, train_img2, train_label = batch
            train_img0, train_img1, train_img2, train_label = train_img0.to(device), train_img1.to(
                device), train_img2.to(device), train_label.to(device)
            optimizer.zero_grad()
            outputs = model(train_img0, train_img1, train_img2)
            loss = loss_function(outputs, train_label)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_pre = torch.argmax(outputs.data, 1)
            # Use the positive-class probability as the ROC score; taking the
            # max over both classes would distort the AUC
            train_prescore = outputs.data[:, 1]
            train_correct += (train_label == train_pre).sum().item()
            train_total += train_label.size(0)
            train_all_pred = np.append(train_all_pred, np.asarray(train_pre.to('cpu')))
            train_all_prescore = np.append(train_all_prescore, np.asarray(train_prescore.to('cpu')))
            train_all_label = np.append(train_all_label, np.asarray(train_label.to('cpu')))
            del train_img0, train_img1, train_img2, train_label
            torch.cuda.empty_cache()

            # NOTE: with a small dataset and few batches per epoch this condition
            # may never fire, leaving val_loss undefined at scheduler.step below;
            # match the interval to the actual number of batches per epoch
            if (step + 1) % 30 == 0:
                model.eval()
                val_correct = 0.
                val_loss = 0.
                all_pred, all_prescore, all_label = np.array([]), np.array([]), np.array([])
                with torch.no_grad():
                    for val_img0_, val_img1_, val_img2_, val_label_ in val_loader:
                        val_img0, val_img1, val_img2, val_label = val_img0_.to(device), val_img1_.to(device), val_img2_.to(device), val_label_.to(device)
                        output = model(val_img0, val_img1, val_img2)
                        val_loss += loss_function(output, val_label).item()
                        val_pred = output.argmax(dim=1)
                        # Positive-class probability as the ROC score, as above
                        val_prescore = output.data[:, 1]
                        val_correct += (val_pred == val_label).sum().item()
                        all_pred = np.append(all_pred, np.asarray(val_pred.to('cpu')))
                        all_prescore = np.append(all_prescore, np.asarray(val_prescore.to('cpu')))
                        all_label = np.append(all_label, np.asarray(val_label.to('cpu')))
                    fp_list, tp_list, thresholds = roc_curve(all_label, all_prescore, pos_label=1)
                    roc_auc = auc(fp_list, tp_list)
                    total_size = len(val_dataset)
                    val_loss /= len(val_loader)
                    val_acc = val_correct / total_size
                    acc_list.append(val_acc)
                    loss_list.append(val_loss)
                    multi_current = 0.5 * val_acc + 0.5 * roc_auc
                    if multi_current > multi and val_acc > 0.7:
                        multi = multi_current
                        torch.save(model.state_dict(), './result/PS_feature_fusion_multi.pth')
                    if val_acc > best_acc and roc_auc > 0.8:
                        best_acc = val_acc
                        torch.save(model.state_dict(), './result/PS_feature_fusion_acc.pth')
                    if val_loss < best_loss:
                        best_loss = val_loss
                        torch.save(model.state_dict(), './result/PS_feature_fusion_loss.pth')
                    torch.save(model.state_dict(), './result/PS_feature_fusion_last.pth')
                iterition += 30
                print(f'epoch:[{epoch + 1}/{epochs}],val:[loss {val_loss:.3f}, acc {val_acc:.3f}, auc {roc_auc:.3f}]\n')

        scheduler.step(val_loss)
        train_acc = train_correct / train_total
        train_loss = train_loss / len(train_loader)
        print(f'                            epoch:[{epoch + 1}/{epochs}],train:[loss {train_loss:.3f}, acc {train_acc:.3f}]\n'
              )
        train_fp_list, train_tp_list, train_thresholds = roc_curve(train_all_label, train_all_prescore)
        train_roc_auc = auc(train_fp_list, train_tp_list)
        print(f'auc: {train_roc_auc:.3f}')
    acc_loss_curve(loss=np.asarray(loss_list), acc=np.asarray(acc_list), iterition=iterition, epochs=epochs,
                   step_size=30)



if __name__ == '__main__':
    train_step()
    print('Finished train!')


Prediction

This step simply tests the model's performance on the validation data, with the full set of metrics applied, including t-SNE. For t-SNE, the features entering the classifier have to be fed into the t-SNE reducer; one option is to save them first (as the model does in val mode) and load them when t-SNE is run.
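
The tsne_curve helper imported below is not shown in this post; a minimal sketch of what such a function might look like with sklearn (the name and signature here simply mirror the import):

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def tsne_curve(feature, labels):
    # Reduce the fused 512*3-d features to 2-D and scatter-plot them by class
    emb = TSNE(n_components=2, init='pca', random_state=0).fit_transform(
        feature.detach().cpu().numpy())
    plt.scatter(emb[:, 0], emb[:, 1], c=labels.detach().cpu().numpy(), s=12)
    plt.title('t-SNE of fused features')
    plt.show()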

from sklearn.metrics import roc_curve, auc, f1_score, confusion_matrix
from torch.utils.data import DataLoader
import numpy as np
import torch
from Draw_Curve import roc_auc_curve
from mydataset import MyDataset
from my_model import MyEnsemble
from TSNE_curve import tsne_curve


def prediction():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f'using {device}.\n')
    val_dataset = MyDataset(root_path=r'G:\WPJ\Pythonproject\HCC_117\dataset\MMQ', mode='val')
    val_loader = DataLoader(val_dataset,
                            batch_size=64,
                            shuffle=False,
                            num_workers=0)
    model = MyEnsemble()
    model.load_state_dict(torch.load('./result/MMQ_feature_fusion_acc.pth', map_location=device))
    # model.load_state_dict(torch.load('G:\WPJ\Pythonproject\HCC_feature_fusion_axis\\result\\beifen\DMQ_feature_fusion_multi.pth', map_location=device))
    model.to(device)
    model.eval()
    loss_function = torch.nn.CrossEntropyLoss()
    val_loss, val_correct, val_total = 0, 0, 0
    all_pre, all_prescore, all_label = np.array([]), np.array([]), np.array([])
    for batch in val_loader:
        val_img0, val_img1, val_img2, val_label = batch
        val_img0, val_img1, val_img2, val_label = val_img0.to(device), val_img1.to(device), val_img2.to(device), val_label.to(device)
        with torch.no_grad():
            outputs = model(val_img0, val_img1, val_img2, mode='val')
            # print(outputs)
            loss = loss_function(outputs, val_label)
            val_pre = torch.argmax(outputs.data, 1)
            # Positive-class probability as the ROC score, as in training
            val_prescore = outputs.data[:, 1]
            val_correct += (val_label == val_pre).sum().item()
            val_loss += loss.item()
            val_total += val_label.size(0)
        all_pre = np.append(all_pre, np.array(val_pre.to('cpu')))
        all_prescore = np.append(all_prescore, np.array(val_prescore.to('cpu')))
        all_label = np.append(all_label, np.array(val_label.to('cpu')))
    fp_list, tp_list, thresholds = roc_curve(all_label, all_prescore, pos_label=1)
    roc_auc = auc(fp_list, tp_list)
    f1score = f1_score(all_label, all_pre)
    confusionmatrix = confusion_matrix(all_label, all_pre)
    tn, fp, fn, tp = confusionmatrix.ravel()
    print(confusionmatrix)
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    val_loss = val_loss / len(val_loader)
    val_acc = val_correct / val_total
    print(f'val_loss: {val_loss:.3f}, val_acc: {val_acc:.3f}')
    print(f'auc: {roc_auc:.3f}  f1score: {f1score:.3f}')
    print(f'sensitivity:{sensitivity:.3f}   specificity:{specificity:.3f}')
    roc_auc_curve(fp_list, tp_list, val_acc)
    # feature.pt holds the fused features saved during the last val forward pass
    feature = torch.load('./result/feature.pt')
    tsne_curve(feature, val_label)

    # return val_loss, val_acc, roc_auc, f1score, sensitivity, specificity

if __name__ == '__main__':
    prediction()

New Content

Besides the concat method described above, two new fusion methods were added: DSN and ABAF (from a senior labmate's paper). Skimming the old content again, its basic problem is that the fusion model was adapted from resnet, which is far too deep for this dataset; the other problem lies in the training technique.

Paper: Adaptive Multimodal Fusion With Attention Guided Deep Supervision Net for Grading Hepatocellular Carcinoma

Models

Note: the models here are tailored to this particular dataset. The images to be classified are quite small to begin with, so a few convolutions over a small image followed by fusion works; adapt this to your own needs (anything that lets the backbone extract useful features will do).
As for the models themselves, the core difference lies in the fusion method applied after the convolutions, so you can pick any suitable CNN as the backbone and then fuse.

concat

class Model_conca(nn.Module):
    def __init__(self, numclasses=2):
        #######################  DMQ Feature ###########################
        super(Model_conca, self).__init__()
        self.conv_11 = nn.Conv3d(kernel_size=3, in_channels=1, out_channels=32, padding='same')
        self.bn_11 = nn.BatchNorm3d(32)
        self.pool_11 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv_12 = nn.Conv3d(kernel_size=3, in_channels=32, out_channels=64, padding='same')
        self.bn_12 = nn.BatchNorm3d(64)
        self.pool_12 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.fc_11 = nn.Linear(7 * 7 * 7 * 64, 500)
        self.dropout_11 = nn.Dropout(0.5)
        self.fc_12 = nn.Linear(500, 50)

        ######################### MMQ Feature ###############################

        self.conv_21 = nn.Conv3d(kernel_size=3, in_channels=1, out_channels=32, padding='same')
        self.bn_21 = nn.BatchNorm3d(32)
        self.pool_21 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv_22 = nn.Conv3d(kernel_size=3, in_channels=32, out_channels=64, padding='same')
        self.bn_22 = nn.BatchNorm3d(64)
        self.pool_22 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.fc_21 = nn.Linear(7 * 7 * 7 * 64, 500)
        self.dropout_21 = nn.Dropout(0.5)
        self.fc_22 = nn.Linear(500, 50)

        ############################ PS ####################################

        self.conv_31 = nn.Conv3d(kernel_size=3, in_channels=1, out_channels=32, padding='same')
        self.bn_31 = nn.BatchNorm3d(32)
        self.pool_31 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv_32 = nn.Conv3d(kernel_size=3, in_channels=32, out_channels=64, padding='same')
        self.bn_32 = nn.BatchNorm3d(64)
        self.pool_32 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.fc_31 = nn.Linear(7 * 7 * 7 * 64, 500)
        self.dropout_31 = nn.Dropout(0.5)
        self.fc_32 = nn.Linear(500, 50)

        ############################ Fusion ############################
        self.f1 = nn.Linear(150, 50)
        self.f2 = nn.Linear(50, numclasses)

    def forward(self, x1, x2, x3):
        ############# Axial Feature ##################
        conv_out_11 = F.relu(self.bn_11(self.conv_11(x1)))
        pool_out_11 = self.pool_11(conv_out_11)
        conv_out_12 = F.relu(self.bn_12(self.conv_12(pool_out_11)))
        pool_out_12 = self.pool_12(conv_out_12)
        pool_out_12 = pool_out_12.view(-1, 7 * 7 * 7 * 64)
        fc_out_11 = F.relu(self.fc_11(pool_out_12))
        dropout_11 = self.dropout_11(fc_out_11)
        fc_out_12 = F.relu(self.fc_12(dropout_11))

        ################## Coronal Feature #############

        conv_out_21 = F.relu(self.bn_21(self.conv_21(x2)))
        pool_out_21 = self.pool_21(conv_out_21)
        conv_out_22 = F.relu(self.bn_22(self.conv_22(pool_out_21)))
        pool_out_22 = self.pool_22(conv_out_22)
        pool_out_22 = pool_out_22.view(-1, 7 * 7 * 7 * 64)
        fc_out_21 = F.relu(self.fc_21(pool_out_22))
        dropout_21 = self.dropout_21(fc_out_21)
        fc_out_22 = F.relu(self.fc_22(dropout_21))

        ############################## Sagittal ######################

        conv_out_31 = F.relu(self.bn_31(self.conv_31(x3)))
        pool_out_31 = self.pool_31(conv_out_31)
        conv_out_32 = F.relu(self.bn_32(self.conv_32(pool_out_31)))
        pool_out_32 = self.pool_32(conv_out_32)
        pool_out_32 = pool_out_32.view(-1, 7 * 7 * 7 * 64)
        fc_out_31 = F.relu(self.fc_31(pool_out_32))
        dropout_31 = self.dropout_31(fc_out_31)
        fc_out_32 = F.relu(self.fc_32(dropout_31))

        ######################### Fusion ################################
        feature_cat = torch.cat((fc_out_12, fc_out_22, fc_out_32), dim=1)
        fc_out_f1 = F.relu(self.f1(feature_cat))
        mid = self.f2(fc_out_f1)
        prediction = torch.softmax(mid, dim=1)

        return prediction
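
The fully connected layers assume a 7 * 7 * 7 * 64 feature map after two 2x2 poolings, which implies 28x28x28 single-channel input volumes; that input size is my inference from the layer sizes, checked with a quick shape test:

import torch

model = Model_conca()
x = torch.randn(2, 1, 28, 28, 28)   # (batch, channel, D, H, W); two stride-2 pools give 7x7x7
pred = model(x, x, x)
print(pred.shape)                   # torch.Size([2, 2]): per-sample class probabilities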

DSN

In DSN (deep supervision), besides backpropagating the overall loss computed on the concat-fused features, the classification losses of the three individual planes are also summed with weights and backpropagated; here the three planes receive equal weights.

class Model_DSN(nn.Module):
    def __init__(self, numclasses=2):
        #######################  DMQ Feature ###########################
        super(Model_DSN, self).__init__()
        self.conv_11 = nn.Conv3d(kernel_size=3, in_channels=1, out_channels=32, padding='same')
        self.bn_11 = nn.BatchNorm3d(32)
        self.pool_11 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv_12 = nn.Conv3d(kernel_size=3, in_channels=32, out_channels=64, padding='same')
        self.bn_12 = nn.BatchNorm3d(64)
        self.pool_12 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.fc_11 = nn.Linear(7 * 7 * 7 * 64, 500)
        self.dropout_11 = nn.Dropout(0.5)
        self.fc_12 = nn.Linear(500, 50)
        self.fc_13 = nn.Linear(50, 2)

        ######################### MMQ Feature ###############################

        self.conv_21 = nn.Conv3d(kernel_size=3, in_channels=1, out_channels=32, padding='same')
        self.bn_21 = nn.BatchNorm3d(32)
        self.pool_21 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv_22 = nn.Conv3d(kernel_size=3, in_channels=32, out_channels=64, padding='same')
        self.bn_22 = nn.BatchNorm3d(64)
        self.pool_22 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.fc_21 = nn.Linear(7 * 7 * 7 * 64, 500)
        self.dropout_21 = nn.Dropout(0.5)
        self.fc_22 = nn.Linear(500, 50)
        self.fc_23 = nn.Linear(50, 2)

        ############################ PS ####################################

        self.conv_31 = nn.Conv3d(kernel_size=3, in_channels=1, out_channels=32, padding='same')
        self.bn_31 = nn.BatchNorm3d(32)
        self.pool_31 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv_32 = nn.Conv3d(kernel_size=3, in_channels=32, out_channels=64, padding='same')
        self.bn_32 = nn.BatchNorm3d(64)
        self.pool_32 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.fc_31 = nn.Linear(7 * 7 * 7 * 64, 500)
        self.dropout_31 = nn.Dropout(0.5)
        self.fc_32 = nn.Linear(500, 50)
        self.fc_33 = nn.Linear(50, 2)

        ############################ Fusion ############################
        self.f1 = nn.Linear(150, 50)
        self.f2 = nn.Linear(50, numclasses)

    def forward(self, x1, x2, x3, label):
        # NOTE: CrossEntropyLoss applies log-softmax internally and expects raw
        # logits; the softmax outputs fed to it below still train, but weaken
        # the gradients
        criterion = nn.CrossEntropyLoss()
        ############# Axial Feature ##################
        conv_out_11 = F.relu(self.bn_11(self.conv_11(x1)))
        pool_out_11 = self.pool_11(conv_out_11)
        conv_out_12 = F.relu(self.bn_12(self.conv_12(pool_out_11)))
        pool_out_12 = self.pool_12(conv_out_12)
        pool_out_12 = pool_out_12.view(-1, 7 * 7 * 7 * 64)
        fc_out_11 = F.relu(self.fc_11(pool_out_12))
        dropout_11 = self.dropout_11(fc_out_11)
        fc_out_12 = F.relu(self.fc_12(dropout_11))

        mid1 = torch.softmax(self.fc_13(fc_out_12), dim=1)
        L1 = criterion(mid1, label)   # criterion already returns the batch mean

        ################## Coronal Feature #############

        conv_out_21 = F.relu(self.bn_21(self.conv_21(x2)))
        pool_out_21 = self.pool_21(conv_out_21)
        conv_out_22 = F.relu(self.bn_22(self.conv_22(pool_out_21)))
        pool_out_22 = self.pool_22(conv_out_22)
        pool_out_22 = pool_out_22.view(-1, 7 * 7 * 7 * 64)
        fc_out_21 = F.relu(self.fc_21(pool_out_22))
        dropout_21 = self.dropout_21(fc_out_21)
        fc_out_22 = F.relu(self.fc_22(dropout_21))

        mid2 = torch.softmax(self.fc_23(fc_out_22), dim=1)
        L2 = criterion(mid2, label)   # already the batch mean

        ############################## Sagittal ######################

        conv_out_31 = F.relu(self.bn_31(self.conv_31(x3)))
        pool_out_31 = self.pool_31(conv_out_31)
        conv_out_32 = F.relu(self.bn_32(self.conv_32(pool_out_31)))
        pool_out_32 = self.pool_32(conv_out_32)
        pool_out_32 = pool_out_32.view(-1, 7 * 7 * 7 * 64)
        fc_out_31 = F.relu(self.fc_31(pool_out_32))
        dropout_31 = self.dropout_31(fc_out_31)
        fc_out_32 = F.relu(self.fc_32(dropout_31))

        mid3 = torch.softmax(self.fc_33(fc_out_32), dim=1)
        L3 = criterion(mid3, label)   # already the batch mean

        ######################### Fusion ################################
        feature_cat = torch.cat((fc_out_12, fc_out_22, fc_out_32), dim=1)
        fc_out_f1 = F.relu(self.f1(feature_cat))
        mid = self.f2(fc_out_f1)
        prediction = torch.softmax(mid, dim=1)

        return prediction, L1 + L2 + L3
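
A minimal sketch of how I understand the returned pair is meant to be consumed in a training step (Model_ABAF below returns the same pair and is used the same way):

import torch
import torch.nn as nn

model = Model_DSN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(2, 1, 28, 28, 28)
label = torch.tensor([0, 1])

prediction, aux_loss = model(x, x, x, label)     # fused prediction + summed per-plane losses
loss = criterion(prediction, label) + aux_loss   # deep supervision: fused loss plus auxiliary term
optimizer.zero_grad()
loss.backward()
optimizer.step()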

ABAF

The key point of ABAF is that, on top of DSN, the weights of the three single-plane losses are no longer fixed but assigned adaptively via an attention mechanism: a gate for each plane is computed from the concatenated features (sigmoid of a linear projection), each plane's gated features are summed to a scalar, and a softmax over the three scalars gives the loss weights.
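
A compact standalone demo of that weight computation (hypothetical shapes; it mirrors the fusion code inside the class below):

import torch
import torch.nn as nn

# Three per-plane 50-d feature batches (batch size 4, matching the fc_x2 outputs)
f = [torch.randn(4, 50) for _ in range(3)]
concat = torch.cat(f, dim=1)                            # (4, 150)
gates = [torch.sigmoid(nn.Linear(150, 50)(concat)) for _ in range(3)]
scores = torch.stack([(g * fi).sum() for g, fi in zip(gates, f)])
beta = torch.softmax(scores, dim=0)                     # adaptive loss weights, sum to 1
print(beta)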

class Model_ABAF(nn.Module):
    def __init__(self, numclasses=2):
        #######################  DMQ Feature ###########################
        super(Model_ABAF, self).__init__()
        self.conv_11 = nn.Conv3d(kernel_size=3, in_channels=1, out_channels=32, padding='same')
        self.bn_11 = nn.BatchNorm3d(32)
        self.pool_11 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv_12 = nn.Conv3d(kernel_size=3, in_channels=32, out_channels=64, padding='same')
        self.bn_12 = nn.BatchNorm3d(64)
        self.pool_12 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.fc_11 = nn.Linear(7 * 7 * 7 * 64, 500)
        self.dropout_11 = nn.Dropout(0.5)
        self.fc_12 = nn.Linear(500, 50)
        self.fc_13 = nn.Linear(50, 2)

        ######################### MMQ Feature ###############################

        self.conv_21 = nn.Conv3d(kernel_size=3, in_channels=1, out_channels=32, padding='same')
        self.bn_21 = nn.BatchNorm3d(32)
        self.pool_21 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv_22 = nn.Conv3d(kernel_size=3, in_channels=32, out_channels=64, padding='same')
        self.bn_22 = nn.BatchNorm3d(64)
        self.pool_22 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.fc_21 = nn.Linear(7 * 7 * 7 * 64, 500)
        self.dropout_21 = nn.Dropout(0.5)
        self.fc_22 = nn.Linear(500, 50)
        self.fc_23 = nn.Linear(50, 2)

        ############################ PS ####################################

        self.conv_31 = nn.Conv3d(kernel_size=3, in_channels=1, out_channels=32, padding='same')
        self.bn_31 = nn.BatchNorm3d(32)
        self.pool_31 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv_32 = nn.Conv3d(kernel_size=3, in_channels=32, out_channels=64, padding='same')
        self.bn_32 = nn.BatchNorm3d(64)
        self.pool_32 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.fc_31 = nn.Linear(7 * 7 * 7 * 64, 500)
        self.dropout_31 = nn.Dropout(0.5)
        self.fc_32 = nn.Linear(500, 50)
        self.fc_33 = nn.Linear(50, 2)

        ############################ Fusion ############################
        self.f1 = nn.Linear(150, 50)
        self.f2 = nn.Linear(50, numclasses)

        self.fusion_11 = nn.Linear(150, 50)
        self.fusion_21 = nn.Linear(150, 50)
        self.fusion_31 = nn.Linear(150, 50)

    def forward(self, x1, x2, x3, label):
        # Same note as Model_DSN: CrossEntropyLoss expects raw logits
        criterion = nn.CrossEntropyLoss()
        ############# Axial Feature ##################
        conv_out_11 = F.relu(self.bn_11(self.conv_11(x1)))
        pool_out_11 = self.pool_11(conv_out_11)
        conv_out_12 = F.relu(self.bn_12(self.conv_12(pool_out_11)))
        pool_out_12 = self.pool_12(conv_out_12)
        pool_out_12 = pool_out_12.view(-1, 7 * 7 * 7 * 64)
        fc_out_11 = F.relu(self.fc_11(pool_out_12))
        dropout_11 = self.dropout_11(fc_out_11)
        fc_out_12 = F.relu(self.fc_12(dropout_11))

        mid1 = torch.softmax(self.fc_13(fc_out_12), dim=1)
        L1 = criterion(mid1, label)   # already the batch mean

        ################## Coronal Feature #############

        conv_out_21 = F.relu(self.bn_21(self.conv_21(x2)))
        pool_out_21 = self.pool_21(conv_out_21)
        conv_out_22 = F.relu(self.bn_22(self.conv_22(pool_out_21)))
        pool_out_22 = self.pool_22(conv_out_22)
        pool_out_22 = pool_out_22.view(-1, 7 * 7 * 7 * 64)
        fc_out_21 = F.relu(self.fc_21(pool_out_22))
        dropout_21 = self.dropout_21(fc_out_21)
        fc_out_22 = F.relu(self.fc_22(dropout_21))

        mid2 = torch.softmax(self.fc_23(fc_out_22), dim=1)
        L2 = criterion(mid2, label)   # already the batch mean

        ############################## Sagittal ######################

        conv_out_31 = F.relu(self.bn_31(self.conv_31(x3)))
        pool_out_31 = self.pool_31(conv_out_31)
        conv_out_32 = F.relu(self.bn_32(self.conv_32(pool_out_31)))
        pool_out_32 = self.pool_32(conv_out_32)
        pool_out_32 = pool_out_32.view(-1, 7 * 7 * 7 * 64)
        fc_out_31 = F.relu(self.fc_31(pool_out_32))
        dropout_31 = self.dropout_31(fc_out_31)
        fc_out_32 = F.relu(self.fc_32(dropout_31))

        mid3 = torch.softmax(self.fc_33(fc_out_32), dim=1)
        L3 = criterion(mid3, label)   # already the batch mean
        ######################## Attention gates ########################
        concat1 = torch.cat((fc_out_12, fc_out_22, fc_out_32), dim=1)
        f11 = self.fusion_11(concat1)
        f21 = self.fusion_21(concat1)
        f31 = self.fusion_31(concat1)

        alpha1 = torch.sigmoid(f11)
        alpha2 = torch.sigmoid(f21)
        alpha3 = torch.sigmoid(f31)

        f1 = torch.multiply(alpha1, fc_out_12)
        f2 = torch.multiply(alpha2, fc_out_22)
        f3 = torch.multiply(alpha3, fc_out_32)

        # Collapse each gated feature tensor to a scalar (summed over the whole
        # batch and all features), then softmax the three scalars into weights
        f14 = torch.sum(f1)
        f24 = torch.sum(f2)
        f34 = torch.sum(f3)

        f4 = torch.stack([f14, f24, f34])
        beta = torch.softmax(f4, dim=0)

        f13 = beta[0]
        f23 = beta[1]
        f33 = beta[2]
        
        ######################### Fusion ################################
        feature_cat = torch.cat((fc_out_12, fc_out_22, fc_out_32), dim=1)
        fc_out_f1 = F.relu(self.f1(feature_cat))
        mid = self.f2(fc_out_f1)
        prediction = torch.softmax(mid, dim=1)

        return prediction, f13 * L1 + f23 * L2 + f33 * L3

Tips and Pitfalls

As for tricks and pitfalls: when the dataset is very small, use an iteration-based scheme, validating the model once every 10 iterations; and since results on a small dataset fluctuate heavily, apply exponential moving-average smoothing to the metric curve, take its maximum, and average those maxima over several training runs. On the training side, do not chase ever-larger L2 regularization, or the model will stop learning or end up performing very poorly. For the learning-rate policy, a schedule that decays with the number of training steps is recommended: start with a somewhat larger learning rate and let it shrink as training proceeds, which improves efficiency and final performance at the same time.

Exponential moving average

For the explanation and derivation of the exponential moving average, please refer to other articles; only the implementation is given here.

def EMA_smooth(points, beta=0.7, bias=False):
    """Exponentially smooth a curve: v_t = beta * v_{t-1} + (1 - beta) * x_t."""
    smoothed_points = []
    v = 0.0
    for t, point in enumerate(points, start=1):
        if t == 1 and not bias:
            v = point          # without bias correction, seed with the first value
        else:
            v = beta * v + (1 - beta) * point
        # Apply bias correction only to the reported value; feeding the
        # corrected value back into the recursion would compound the correction
        smoothed_points.append(v / (1 - beta ** t) if bias else v)
    return smoothed_points
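
Typical usage on the per-validation accuracy list collected during training (the acc_list values here are made up):

acc_list = [0.55, 0.70, 0.62, 0.78, 0.74, 0.80, 0.71]
smoothed = EMA_smooth(acc_list, beta=0.7)
print(max(smoothed))    # report the maximum of the smoothed curve; average this over several runs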