深度学习 Day27——J6ResNeXt-50实战解析


前言

关键字: pytorch实现ResNeXt50详解算法,tensorflow实现ResNeXt50详解算法,ResNeXt50详解

1 我的环境

  • 电脑系统:Windows 11
  • 语言环境:python 3.8.6
  • 编译器:pycharm2020.2.3
  • 深度学习环境:
    torch == 1.9.1+cu111
    torchvision == 0.10.1+cu111
    TensorFlow 2.10.1
  • 显卡:NVIDIA GeForce RTX 4070

2 pytorch实现ResNeXt50算法

2.1 前期准备

2.1.1 引入库


import torch
import torch.nn as nn
import time
import copy
from torchvision import transforms, datasets
from pathlib import Path
from PIL import Image
import torchsummary as summary
import torch.nn.functional as F
from collections import OrderedDict
import re
import torch.utils.model_zoo as model_zoo
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # use the SimHei font so Chinese (CJK) labels render correctly
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly when a CJK font is active
plt.rcParams['figure.dpi'] = 100  # figure resolution
import warnings

warnings.filterwarnings('ignore')  # suppress warning output; nothing here needs to be printed

2.1.2 设置GPU(如果设备上支持GPU就使用GPU,否则使用CPU)

"""前期准备-设置GPU"""
# 如果设备上支持GPU就使用GPU,否则使用CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print("Using {} device".format(device))

输出

Using cuda device

2.1.3 导入数据

'''前期工作-导入数据'''
data_dir = r"D:\DeepLearning\data\monkeypox_recognition"
data_dir = Path(data_dir)

# Class names are the immediate subdirectory names of the data root.
# Path.name is cross-platform; the original split on "\\" only worked on Windows.
data_paths = list(data_dir.glob('*'))
classeNames = [path.name for path in data_paths]
print(classeNames)

输出

['Monkeypox', 'Others']

2.1.4 可视化数据

'''前期工作-可视化数据'''
subfolder = Path(data_dir) / "Monkeypox"
# Collect image files; compare suffixes case-insensitively so ".JPG"/".PNG"
# etc. are matched as well (the original exact-match missed them).
image_files = [p.resolve() for p in subfolder.glob('*')
               if p.suffix.lower() in {".jpg", ".png", ".jpeg"}]
plt.figure(figsize=(10, 6))
# Show at most the first 12 images in a 3x4 grid (enumerate instead of
# the range(len(...)) anti-pattern).
for i, image_file in enumerate(image_files[:12]):
    ax = plt.subplot(3, 4, i + 1)
    img = Image.open(str(image_file))
    plt.imshow(img)
    plt.axis("off")
# 显示图片
plt.tight_layout()
plt.show()

在这里插入图片描述

2.1.4 图像数据变换

'''前期工作-图像数据变换'''
total_datadir = data_dir

# More on transforms.Compose: https://blog.csdn.net/qq_38251616/article/details/124878863
# Pipeline: Resize -> ToTensor -> Normalize. The mean/std values are the
# standard ImageNet statistics (estimated by random sampling of that dataset);
# standardizing the channels helps the model converge.
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # unify every input image to 224x224
    transforms.ToTensor(),          # PIL Image / ndarray -> float tensor in [0, 1]
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
])

total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)
print(total_data)
print(total_data.class_to_idx)

输出

Dataset ImageFolder
    Number of datapoints: 2142
    Root location: D:\DeepLearning\data\monkeypox_recognition
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
{'Monkeypox': 0, 'Others': 1}

2.1.4 划分数据集

'''前期工作-划分数据集'''
# 80/20 split: the training set is 80% of the data (floored to an int),
# the test set is whatever remains.
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
# torch.utils.data.random_split shuffles the indices and partitions
# total_data into two Subset objects of the requested sizes.
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print("train_dataset={}\ntest_dataset={}".format(train_dataset, test_dataset))
print("train_size={}\ntest_size={}".format(train_size, test_size))

输出

train_dataset=<torch.utils.data.dataset.Subset object at 0x000002A96E08E0D0>
test_dataset=<torch.utils.data.dataset.Subset object at 0x000002A96E04E640>
train_size=1713
test_size=429

2.1.4 加载数据

'''前期工作-加载数据'''
batch_size = 32

# Wrap both subsets in DataLoaders: shuffled mini-batches of 32,
# one worker process each.
train_dl = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
test_dl = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True, num_workers=1)

2.1.4 查看数据

'''前期工作-查看数据'''
# Peek at a single batch to confirm the tensor shapes and the label dtype.
for images, labels in test_dl:
    print("Shape of X [N, C, H, W]: ", images.shape)
    print("Shape of y: ", labels.shape, labels.dtype)
    break

输出

Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

2.2 搭建ResNeXt50模型

"""构建ResNeXt50网络"""


class BN_Conv2d(nn.Module):
    """Conv2d followed by BatchNorm, with ReLU applied in forward (CONV -> BN -> RELU)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bias=False):
        super(BN_Conv2d, self).__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups, bias=bias)
        bn = nn.BatchNorm2d(out_channels)
        self.seq = nn.Sequential(conv, bn)

    def forward(self, x):
        # Run the conv + batch-norm stack, then the ReLU non-linearity.
        return F.relu(self.seq(x))


class ResNeXt_Block(nn.Module):
    """
    ResNeXt bottleneck block: 1x1 reduce -> 3x3 grouped conv -> 1x1 expand (x2),
    summed with a projection shortcut and ReLU-activated.
    """

    def __init__(self, in_chnls, cardinality, group_depth, stride):
        super(ResNeXt_Block, self).__init__()
        width = cardinality * group_depth  # channel width inside the bottleneck
        self.group_chnls = width
        self.conv1 = BN_Conv2d(in_chnls, width, 1, stride=1, padding=0)
        self.conv2 = BN_Conv2d(width, width, 3, stride=stride, padding=1, groups=cardinality)
        self.conv3 = nn.Conv2d(width, width * 2, 1, stride=1, padding=0)
        self.bn = nn.BatchNorm2d(width * 2)
        # Projection shortcut: matches the main path's channel count (2x width)
        # and stride so the residual addition is shape-compatible.
        self.short_cut = nn.Sequential(
            nn.Conv2d(in_chnls, width * 2, 1, stride, 0, bias=False),
            nn.BatchNorm2d(width * 2),
        )

    def forward(self, x):
        residual = self.short_cut(x)
        out = self.bn(self.conv3(self.conv2(self.conv1(x))))
        return F.relu(out + residual)


class ResNeXt(nn.Module):
    """
    ResNeXt builder.

    Args:
        layers: number of blocks in each of the four stages
            (e.g. [3, 4, 6, 3] for ResNeXt-50).
        cardinality: number of groups in the grouped 3x3 convolutions.
        group_depth: per-group channel depth of stage 1 (doubles each stage).
        num_classes: size of the final classification layer.
    """

    def __init__(self, layers, cardinality, group_depth, num_classes):
        super(ResNeXt, self).__init__()
        self.cardinality = cardinality
        self.channels = 64
        # Stem: 7x7/2 conv (a 3x3/2 max-pool follows in forward()).
        self.conv1 = BN_Conv2d(3, self.channels, 7, stride=2, padding=3)
        d1 = group_depth
        self.conv2 = self._make_layers(d1, layers[0], stride=1)
        d2 = d1 * 2
        self.conv3 = self._make_layers(d2, layers[1], stride=2)
        d3 = d2 * 2
        self.conv4 = self._make_layers(d3, layers[2], stride=2)
        d4 = d3 * 2
        self.conv5 = self._make_layers(d4, layers[3], stride=2)
        # self.channels has been updated by _make_layers and now equals the
        # output width of conv5.
        self.fc = nn.Linear(self.channels, num_classes)

    def _make_layers(self, d, blocks, stride):
        """Build one stage of `blocks` ResNeXt blocks; only the first may downsample.

        Renamed from the original ``___make_layers``: a 3-underscore prefix
        triggers Python name mangling, which was clearly unintended.
        """
        strides = [stride] + [1] * (blocks - 1)
        layers = []
        for s in strides:
            layers.append(ResNeXt_Block(self.channels, self.cardinality, d, s))
            # Each block expands its output to cardinality * d * 2 channels.
            self.channels = self.cardinality * d * 2
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = F.max_pool2d(out, 3, 2, 1)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        # Adaptive pooling generalizes the original F.avg_pool2d(out, 7),
        # which silently assumed a 224x224 input; identical result for 224
        # inputs, and other input sizes now work too.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        # NOTE(review): softmax here means the model emits probabilities; if it
        # is trained with nn.CrossEntropyLoss this double-applies softmax —
        # confirm against the training loop.
        out = F.softmax(self.fc(out), dim=1)
        return out

(注:下面这段关于“在最后一个 denseblock 后增加 SE_layer”的说明及其代码片段,疑为 SE-DenseNet 文章的残留内容,与本文的 ResNeXt50 模型无关;再往下是 torchsummary 打印的本文模型结构。)

# SE_layer
self.features.add_module('SE-module', Squeeze_excitation_layer(num_features))

输出

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
         BN_Conv2d-3         [-1, 64, 112, 112]               0
            Conv2d-4          [-1, 128, 56, 56]           8,192
       BatchNorm2d-5          [-1, 128, 56, 56]             256
         BN_Conv2d-6          [-1, 128, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           4,608
       BatchNorm2d-8          [-1, 128, 56, 56]             256
         BN_Conv2d-9          [-1, 128, 56, 56]               0
           Conv2d-10          [-1, 256, 56, 56]          33,024
      BatchNorm2d-11          [-1, 256, 56, 56]             512
           Conv2d-12          [-1, 256, 56, 56]          16,384
      BatchNorm2d-13          [-1, 256, 56, 56]             512
    ResNeXt_Block-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 128, 56, 56]          32,768
      BatchNorm2d-16          [-1, 128, 56, 56]             256
        BN_Conv2d-17          [-1, 128, 56, 56]               0
           Conv2d-18          [-1, 128, 56, 56]           4,608
      BatchNorm2d-19          [-1, 128, 56, 56]             256
        BN_Conv2d-20          [-1, 128, 56, 56]               0
           Conv2d-21          [-1, 256, 56, 56]          33,024
      BatchNorm2d-22          [-1, 256, 56, 56]             512
           Conv2d-23          [-1, 256, 56, 56]          65,536
      BatchNorm2d-24          [-1, 256, 56, 56]             512
    ResNeXt_Block-25          [-1, 256, 56, 56]               0
           Conv2d-26          [-1, 128, 56, 56]          32,768
      BatchNorm2d-27          [-1, 128, 56, 56]             256
        BN_Conv2d-28          [-1, 128, 56, 56]               0
           Conv2d-29          [-1, 128, 56, 56]           4,608
      BatchNorm2d-30          [-1, 128, 56, 56]             256
        BN_Conv2d-31          [-1, 128, 56, 56]               0
           Conv2d-32          [-1, 256, 56, 56]          33,024
      BatchNorm2d-33          [-1, 256, 56, 56]             512
           Conv2d-34          [-1, 256, 56, 56]          65,536
      BatchNorm2d-35          [-1, 256, 56, 56]             512
    ResNeXt_Block-36          [-1, 256, 56, 56]               0
           Conv2d-37          [-1, 256, 56, 56]          65,536
      BatchNorm2d-38          [-1, 256, 56, 56]             512
        BN_Conv2d-39          [-1, 256, 56, 56]               0
           Conv2d-40          [-1, 256, 28, 28]          18,432
      BatchNorm2d-41          [-1, 256, 28, 28]             512
        BN_Conv2d-42          [-1, 256, 28, 28]               0
           Conv2d-43          [-1, 512, 28, 28]         131,584
      BatchNorm2d-44          [-1, 512, 28, 28]           1,024
           Conv2d-45          [-1, 512, 28, 28]         131,072
      BatchNorm2d-46          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-47          [-1, 512, 28, 28]               0
           Conv2d-48          [-1, 256, 28, 28]         131,072
      BatchNorm2d-49          [-1, 256, 28, 28]             512
        BN_Conv2d-50          [-1, 256, 28, 28]               0
           Conv2d-51          [-1, 256, 28, 28]          18,432
      BatchNorm2d-52          [-1, 256, 28, 28]             512
        BN_Conv2d-53          [-1, 256, 28, 28]               0
           Conv2d-54          [-1, 512, 28, 28]         131,584
      BatchNorm2d-55          [-1, 512, 28, 28]           1,024
           Conv2d-56          [-1, 512, 28, 28]         262,144
      BatchNorm2d-57          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-58          [-1, 512, 28, 28]               0
           Conv2d-59          [-1, 256, 28, 28]         131,072
      BatchNorm2d-60          [-1, 256, 28, 28]             512
        BN_Conv2d-61          [-1, 256, 28, 28]               0
           Conv2d-62          [-1, 256, 28, 28]          18,432
      BatchNorm2d-63          [-1, 256, 28, 28]             512
        BN_Conv2d-64          [-1, 256, 28, 28]               0
           Conv2d-65          [-1, 512, 28, 28]         131,584
      BatchNorm2d-66          [-1, 512, 28, 28]           1,024
           Conv2d-67          [-1, 512, 28, 28]         262,144
      BatchNorm2d-68          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-69          [-1, 512, 28, 28]               0
           Conv2d-70          [-1, 256, 28, 28]         131,072
      BatchNorm2d-71          [-1, 256, 28, 28]             512
        BN_Conv2d-72          [-1, 256, 28, 28]               0
           Conv2d-73          [-1, 256, 28, 28]          18,432
      BatchNorm2d-74          [-1, 256, 28, 28]             512
        BN_Conv2d-75          [-1, 256, 28, 28]               0
           Conv2d-76          [-1, 512, 28, 28]         131,584
      BatchNorm2d-77          [-1, 512, 28, 28]           1,024
           Conv2d-78          [-1, 512, 28, 28]         262,144
      BatchNorm2d-79          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-80          [-1, 512, 28, 28]               0
           Conv2d-81          [-1, 512, 28, 28]         262,144
      BatchNorm2d-82          [-1, 512, 28, 28]           1,024
        BN_Conv2d-83          [-1, 512, 28, 28]               0
           Conv2d-84          [-1, 512, 14, 14]          73,728
      BatchNorm2d-85          [-1, 512, 14, 14]           1,024
        BN_Conv2d-86          [-1, 512, 14, 14]               0
           Conv2d-87         [-1, 1024, 14, 14]         525,312
      BatchNorm2d-88         [-1, 1024, 14, 14]           2,048
           Conv2d-89         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-90         [-1, 1024, 14, 14]           2,048
    ResNeXt_Block-91         [-1, 1024, 14, 14]               0
           Conv2d-92          [-1, 512, 14, 14]         524,288
      BatchNorm2d-93          [-1, 512, 14, 14]           1,024
        BN_Conv2d-94          [-1, 512, 14, 14]               0
           Conv2d-95          [-1, 512, 14, 14]          73,728
      BatchNorm2d-96          [-1, 512, 14, 14]           1,024
        BN_Conv2d-97          [-1, 512, 14, 14]               0
           Conv2d-98         [-1, 1024, 14, 14]         525,312
      BatchNorm2d-99         [-1, 1024, 14, 14]           2,048
          Conv2d-100         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-101         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-102         [-1, 1024, 14, 14]               0
          Conv2d-103          [-1, 512, 14, 14]         524,288
     BatchNorm2d-104          [-1, 512, 14, 14]           1,024
       BN_Conv2d-105          [-1, 512, 14, 14]               0
          Conv2d-106          [-1, 512, 14, 14]          73,728
     BatchNorm2d-107          [-1, 512, 14, 14]           1,024
       BN_Conv2d-108          [-1, 512, 14, 14]               0
          Conv2d-109         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-110         [-1, 1024, 14, 14]           2,048
          Conv2d-111         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-112         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-113         [-1, 1024, 14, 14]               0
          Conv2d-114          [-1, 512, 14, 14]         524,288
     BatchNorm2d-115          [-1, 512, 14, 14]           1,024
       BN_Conv2d-116          [-1, 512, 14, 14]               0
          Conv2d-117          [-1, 512, 14, 14]          73,728
     BatchNorm2d-118          [-1, 512, 14, 14]           1,024
       BN_Conv2d-119          [-1, 512, 14, 14]               0
          Conv2d-120         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-121         [-1, 1024, 14, 14]           2,048
          Conv2d-122         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-123         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-124         [-1, 1024, 14, 14]               0
          Conv2d-125          [-1, 512, 14, 14]         524,288
     BatchNorm2d-126          [-1, 512, 14, 14]           1,024
       BN_Conv2d-127          [-1, 512, 14, 14]               0
          Conv2d-128          [-1, 512, 14, 14]          73,728
     BatchNorm2d-129          [-1, 512, 14, 14]           1,024
       BN_Conv2d-130          [-1, 512, 14, 14]               0
          Conv2d-131         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-132         [-1, 1024, 14, 14]           2,048
          Conv2d-133         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-134         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-135         [-1, 1024, 14, 14]               0
          Conv2d-136          [-1, 512, 14, 14]         524,288
     BatchNorm2d-137          [-1, 512, 14, 14]           1,024
       BN_Conv2d-138          [-1, 512, 14, 14]               0
          Conv2d-139          [-1, 512, 14, 14]          73,728
     BatchNorm2d-140          [-1, 512, 14, 14]           1,024
       BN_Conv2d-141          [-1, 512, 14, 14]               0
          Conv2d-142         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-143         [-1, 1024, 14, 14]           2,048
          Conv2d-144         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-145         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-146         [-1, 1024, 14, 14]               0
          Conv2d-147         [-1, 1024, 14, 14]       1,048,576
      BatchNorm2d-148         [-1, …](模型结构输出在此处被截断,后续层与参数统计缺失)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值