一个轻量级的ResNet模型封装

import torch
import torch.nn as nn
import torch.nn.functional as F

class myResnet(nn.Module):
    """
    自定义ResNet封装类,接收一个预训练的ResNet模型作为参数,
    并在其基础上进行特征提取,同时返回全局特征和注意力区域特征。
    """

    def __init__(self, resnet):
        """
        初始化方法,将传入的ResNet模型实例赋值给类的成员属性self.resnet。
        """
        super(myResnet, self).__init__()
        self.resnet = resnet

    def forward(self, img, att_size=6):
        """
        前向传播方法,输入img是一个四维张量,代表一批图片(批处理)的数据,
        返回两个输出:全局平均池化特征向量和注意力区域特征。

        参数:
        - img:四维张量,形状为 (batch_size, channels, height, width),代表输入图片数据
        - att_size:整数,默认为6,用于指定自适应平均池化后的特征图尺寸
        
        返回:
        - fc:全局平均池化特征向量,形状为 (batch_size, channels)
        - att:注意力区域特征,形状为 (channels, att_size, att_size)
        """
        # 将输入图片经过ResNet的基础模块处理
        x = img
        x = self.resnet.conv1(x)  # 卷积层
        x = self.resnet.bn1(x)  # 批量归一化层
        x = self.resnet.relu(x)  # 激活层(ReLU)
        x = self.resnet.maxpool(x)  # 最大池化层

        # 通过ResNet的残差模块进行特征提取
        x = self.resnet.layer1(x)  # 第一个残差模块
        x = self.resnet.layer2(x)  # 第二个残差模块
        x = self.resnet.layer3(x)  # 第三个残差模块
        x = self.resnet.layer4(x)  # 第四个残差模块

        # 提取全局平均池化特征向量
        fc = x.mean(dim=3).mean(dim=2).squeeze()  # 计算高度和宽度维度的平均值,得到一维特征向量

        # 提取注意力区域特征
        att = F.adaptive_avg_pool2d(x, output_size=(att_size, att_size)).squeeze().permute(1, 2, 0)
        # 使用自适应平均池化将特征图压缩到指定尺寸,并转换通道至第三维

        # 返回全局特征向量和注意力区域特征
        return fc, att

这段代码实现了一个轻量级的ResNet模型封装类 myResnet。在初始化时,它接收一个预定义的ResNet模型作为参数(如ResNet-18、ResNet-34等),并将该模型挂载为自身的成员属性 self.resnet

在 forward 方法中,输入 img 是一个四维张量,通常代表一批图片(batch)的数据,形状为 (batch_size, channels, height, width)

  1. 首先,将输入图片送入ResNet模型的基本卷积层、批量归一化层、激活层(ReLU)和最大池化层进行初步特征提取。

  2. 然后,通过调用 self.resnet.layer1 至 self.resnet.layer4 分别执行ResNet模型的四个残差模块,逐步抽取更深层次的特征。

  3. 提取全局平均池化特征向量 fc,即将特征图在高度和宽度维度上分别取平均值,得到一个形如 (batch_size, channels) 的向量,可用于分类任务。

  4. 同时,通过自适应平均池化 (F.adaptive_avg_pool2d) 得到一个形如 (batch_size, channels, att_size, att_size) 的特征图 att,该特征图尺寸被压缩为指定的 att_size,可以视为对输入图片的注意力区域特征,通常用于视觉注意力或辅助任务。

最后,myResnet 类的 forward 函数返回两个输出:全局平均池化特征向量 fc 和注意力区域特征 att

  • 6
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值