【即插即用】STN注意力机制(附源码)

论文地址:

https://arxiv.org/pdf/1506.02025.pdf

简要介绍:

STN,也就是空间变换网络,是一种很酷的机器学习技术,它能自动地调整图像,让网络更好地处理图片的各种变化,比如扭曲或旋转。这个过程就像给图片做一个微整,让它们在变美的同时,也让计算机更容易识别。

STN由三个主要部分组成:定位网络、网格生成器和采样器。定位网络负责分析图像,找出需要调整的地方,比如哪里需要拉伸,哪里需要旋转。想象一下,这就像有一支神奇的笔,能够精确地标记出图片的每个像素需要怎么移动。

网格生成器根据定位网络提供的信息,制作出一个“模板”,这个模板指导采样器如何从原始图像中取样。采样器就像一个巧手的工匠,按照这个模板,用细腻的手法从原始图片中选取正确的像素,然后把它们重新放置到新的位置,让图片达到最佳的视觉效果。

STN的好处是,它能够自我学习如何处理各种复杂的图像变换,这让它在识别图片时更加精准。而且,STN可以像搭乐高一样,和其他深度学习模型一起使用,让整个系统变得更加强大。

结构图:
Pytorch源码:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


class STN(nn.Module):
    def __init__(self):
        super(STN, self).__init__()
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )
        # 移除fc_loc中第一个nn.Linear的输入尺寸定义
        self.fc_loc = None  # 将在forward中动态创建

        # 用于空间变换网络的权重和偏置初始化参数
        self.fc_loc_output_params = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)

    def forward(self, x):
        xs = self.localization(x)
        xs_size = xs.size()
        # 动态计算全连接层的输入尺寸
        fc_input_size = xs_size[1] * xs_size[2] * xs_size[3]

        # 根据动态计算的输入尺寸创建fc_loc
        if self.fc_loc is None:
            self.fc_loc = nn.Sequential(
                nn.Linear(fc_input_size, 32),
                nn.ReLU(True),
                nn.Linear(32, 3 * 2)
            )
            # 初始化空间变换网络的权重和偏置
            self.fc_loc[2].weight.data.zero_()
            self.fc_loc[2].bias.data.copy_(self.fc_loc_output_params)

        xs = xs.view(-1, fc_input_size)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=True)
        x = F.grid_sample(x, grid, align_corners=True)
        return x


def test_stn(input_size=(1, 1, 32, 32)):
    stn = STN()
    input_tensor = torch.rand(input_size)
    transformed_tensor = stn(input_tensor)
    print(transformed_tensor.shape)

    input_image = input_tensor.numpy()[0][0]
    transformed_image = transformed_tensor.detach().numpy()[0][0]

    plt.figure()
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(input_image, cmap='gray')
    plt.subplot(1, 2, 2)
    plt.title("Transformed Image")
    plt.imshow(transformed_image, cmap='gray')
    plt.show()


# 测试不同尺寸的输入
# test_stn(input_size=(1, 1, 32, 32))
test_stn(input_size=(1, 1, 64, 64))

  • 14
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值