从零开始做图像分类任务(一)——构建图像分类模型(ResNet, VGG, ViT)

序言

本期博文从零使用pytorch构建图像分类任务,按照我个人习惯,从构建图像分类模型数据集处理和Dataset类封装训练和测试等步骤逐一进行代码构建。由于篇幅较大,本期是主要是构造经典的分类模型(ResNet, VGG, Vision Transformer)。其余部分则可以参考:

从零开始做图像分类任务(二)——数据集处理和Dataset类封装
从零开始做图像分类任务(三)——训练和测试脚本(模型保存,断点恢复,Tensorboard,日志输出)

构建好的分类模型,我通常习惯将模型保存在models文件下:
在这里插入图片描述

1. ResNet

对于resnet这个经典网络,很多博客和视频都对其原理进行很详细的介绍,这里就不再赘述,包括之后的模型。对于ResNet论文的模型结构如下:
在这里插入图片描述
这个模型结构表,可以说对模型结构讲得非常清楚了,所以我们可以根据这个表,构造出ResNet类,代码的主要难点在于残差块的写法:
在这里插入图片描述
不同的depth对应不同的残差结构,例如:resnet-34 对应于左边的残差结构,resnet50及以上对应的是右边的残差结构。
完整的代码如下(下面的残差结构的是右边的那一种):

import torch
import torch.nn as nn

class block(nn.Module):
    def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):
        super(block, self).__init__()
        self.expansion = 4
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample

    def forward(self, x):
        identity = x
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        x += identity
        x = self.relu(x)
        return x

class ResNet(nn.Module):
    def __init__(self, block, layers, image_channels, num_classes):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # ResNet Layers
        self.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)
        self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)
        self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)
        self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512*4, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x


    def _make_layer(self, block, num_residual_blocks, out_channels, stride):
        identity_downsample = None
        layers = []

        if stride != 1 or self.in_channels != out_channels * 4:
            identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1, stride=stride),
                                                nn.BatchNorm2d(out_channels*4))
        layers.append(block(self.in_channels, out_channels, identity_downsample, stride))
        self.in_channels = out_channels * 4

        for i in range(num_residual_blocks-1):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)


def ResNet50(img_channels =  3, num_classes = 1000):
    return ResNet(block, [3,4,6,3], img_channels, num_classes)
def ResNet101(img_channels =  3, num_classes = 1000):
    return ResNet(block, [3,4,23,3], img_channels, num_classes)
def ResNet152(img_channels =  3, num_classes = 1000):
    return ResNet(block, [3,8,36,3], img_channels, num_classes)


def main():
    net = ResNet50()
    device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    x = torch.randn(2, 3, 224, 224).to(device)
    y = net(x)
    print(y.shape)

if __name__ == "__main__":
    main()

2. VGG16

VGG网络也是一个简单且经典的卷积神经网络,虽然简单,但是对于构建更大更深的卷积网络结构的代码,VGG的写法是具有一定启发意义的。paper中VGG网络的结构如下:
在这里插入图片描述
这个结构虽然不难,只有十几层卷积层,但是如果每一层都要手写的话,这样实现起来的,可读性底且不高效。所以下面展示一种相对高效的写法:

import torch
import torch.nn as nn

# 这里定义好VGG的结构,用一个列表的形式装好,数字代表的是卷积层的输出通道数,’M‘代表的是MaxPool层
VGG16 = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

class VGGNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(VGGNet, self).__init__()
        self.in_channels = in_channels
        self.conv_layers = self.create_conv_layers(VGG16)

        self.fcs = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes),
        )

        # self.initialize_weights()

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fcs(x)
        return x

    def initialize_weights(self):
     # 初始化权重
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)


    def create_conv_layers(self, architecture):
        layers = []
        in_channels = self.in_channels

        for x in architecture:
            if type(x) == int:
                out_channels = x
                layers += [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU()]
                in_channels = x
            elif x == 'M':
                layers += [nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))]
        return nn.Sequential(*layers)

if __name__ == "__main__":
    x  = torch.randn(1, 3, 224, 224)
    model = VGGNet(3, 10)
    print(model(x).shape)

3. Vision Transfomer(Vit)

在这里插入图片描述
Vision Transformer 结构包括:图像嵌入(PatchEmbed)多头自注意力Transformer编码器(有16个)ViT总架构。具体实现代码如下:

  1. 图像嵌入(PatchEmbed)
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, dim=768):
        """
        This class is used to split images to patches
        :param img_size: the size of input images (B, C, H, W)
        :param patch_size:  the size of patchs (x, x), such as x = 16
        :param in_channels: the channels of input image
        :param dim: patches embedding dim = 16 * 16 * 3
        """
        super(PatchEmbed, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = int(img_size // patch_size)**2

        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x : (B, C, H, W)
        x = self.proj(x)  # (B, dim, patch_size, patch_size)
        x = x.flatten(2)  # (B, dim, patch_size**2)
        x = x.transpose(1,2)
        return x

if __name__ == "__main__":
    patch_size = 16

    input = torch.randn(1, 3, 224, 224)
    patchembed = PatchEmbed(input.shape[2], patch_size, input.shape[1])
    print(patchembed(input).shape)
    # ==> torch.Size([1, 196, 768])
  1. 多头自注意力
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, n_heads, qkv_bias, attn_p=0., proj_p=0.):
        """
        This is an implement of multi head attention
        :param dim: patches embedding dim = 16 * 16 * 3
        :param n_heads: the number of multi head
        :param qkv_bias: the bias of qkv linear layer
        :param attn_p: attn drop rate
        :param proj_p: projection drop rate
        """
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.dim = dim
        self.scale = (self.head_dim)**-0.5

        self.qkv = nn.Linear(dim, 3*dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop= nn.Dropout(attn_p)
        self.proj_drop= nn.Dropout(proj_p)

    def forward(self, x):
        # x : (B, N+1, D) B: batch_size, N: the number of patches, D: the dim of patch embedding
        b, n_patches, dim = x.shape
        if dim != self.dim:
            raise ValueError

        qkv = self.qkv(x)   # (B, N+1, 3*D)
        qkv = qkv.reshape(b, n_patches, 3, self.n_heads, self.head_dim)     # (B, N+1, 3, n_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, N+1, head_dim)

        q, k ,v = qkv[0], qkv[1], qkv[2]
        k_T = (k.transpose(-2, -1)) * self.scale  # (B, n_heads, head_dim, N+1)
        dp = q @ k_T                # (B, n_heads, N+1, N+1)
        attn = dp.softmax(dim=-1)   # (B, n_heads, N+1, N+1)
        attn = self.attn_drop(attn)

        weighted_avg = attn @ v   # (B, n_heads, N+1, head_dim)
        weighted_avg = weighted_avg.transpose(1, 2) # (B, N+1, n_heads, head_dim)
        weighted_avg = weighted_avg.flatten(2)  # (B, N+1, dim)

        # mlp
        x = self.proj(weighted_avg)
        x = self.proj_drop(x)
        return x

if __name__ == "__main__":
    patches_inputs = torch.randn(1, 197, 768)
    mhsa = MultiHeadAttention(768, 12, qkv_bias=True)
    print(mhsa(patches_inputs).shape)
    # ==> torch.Size([1, 197, 768])
  1. Transformer编码器
class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, p=0.):
        """
        MLP in Transformer Block
        :param dim: patch embedding dimension
        :param mlp_ratio: the hidden layer nodes expansion factor
        :param p: dropout rate
        """
        super(MLP, self).__init__()
        hidden_features = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_features)
        self.fc2 = nn.Linear(hidden_features, dim)
        self.drop = nn.Dropout(p)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.drop(self.gelu(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x


class Block(nn.Module):
    def __init__(self, dim, n_heads, qkv_bias, mlp_ratio=4.0, attn_p=0., proj_p=0., p=0.):
        """
        The Vision Transfomer Block
        :param dim: patches embedding dim = 16 * 16 * 3
        :param n_heads: the number of multi head
        :param qkv_bias: the bias of qkv linear layer
        :param attn_p: attn drop rate
        :param proj_p: projection drop rate
        :param mlp_ratio: the hidden layer nodes expansion factor
        :param p: dropout rate
        """
        super(Block, self).__init__()
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.mhsa = MultiHeadAttention(dim, n_heads, qkv_bias, attn_p, proj_p)
        self.mlp = MLP(dim, mlp_ratio, p)

    def forward(self, x):
        x = x + self.mhsa(self.norm(x))
        x = x + self.mlp(self.norm(x))
        return x
        
if __name__ == "__main__":
    patches_inputs = torch.randn(1, 197, 768)
    block = Block(768, 12, True)
    print(block(patches_inputs).shape)
    # ==> torch.Size([1, 197, 768])
  1. ViT总架构
class ViT(nn.Module):
    def __init__(self,
                 img_size = 224,
                 patch_size = 16,
                 dim = 768,
                 in_channels = 3,
                 n_classes = 1000,
                 n_heads = 12,
                 depth = 12,
                 qkv_bias = True,
                 mlp_ratio = 4.0,
                 attn_p = 0.,
                 proj_p = 0.,
                 p = 0.
                 ):
        super(ViT, self).__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.position_embed = nn.Parameter(torch.zeros(1, 1+self.patch_embed.n_patches, dim))
        self.pos_drop = nn.Dropout(p)

        self.blocks = nn.ModuleList([
            Block(dim=dim, n_heads=n_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_p=attn_p, proj_p=proj_p, p=p)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.classifier = nn.Linear(dim, n_classes)

    def forward(self, x):
        b = x.shape[0]      # batch_size

        # patch embedding
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_token, x), dim=1)   # (b, 197, dim)
        x += self.position_embed
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)
        x = self.norm(x)

        cls_token_output = x[:, 0]
        x = self.classifier(cls_token_output)
        return x

if __name__ == "__main__":
	input = torch.randn(1, 3, 224, 224)
    vit = ViT()
    print(vit(input).shape)
    # ==> torch.Size([1, 1000])

完整代码见:link

  • 1
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值