改进系列(5):在ResNet网络添加SelfAttention自注意力层实现的遥感卫星下的土地利用情况图像分类

目录

1. ResNet介绍

2. SelfAttention 层

3. ResNet34 + SelfAttention

4. 遥感卫星下的土地使用情况分类

4.1 土地使用情况数据集 

4.2 训练

4.3 训练结果

4.4 推理


1. ResNet介绍

ResNet(残差网络)是一种深度卷积神经网络模型,由Kaiming He等人于2015年提出。它的提出解决了深度神经网络的梯度消失和梯度爆炸问题,使得深层网络的训练变得更加容易和有效。

在深度神经网络中,随着网络层数的增加,梯度在反向传播过程中逐渐变小,导致网络的训练变得困难。这是因为在传统的网络结构中,每个网络层都是通过直接逐层堆叠来进行信息的传递。当网络层数增加时,信息的传递路径变得更长,导致梯度逐渐消失。为了解决这个问题,ResNet提出了“残差学习”的概念。

ResNet引入了“残差块”(residual block)的概念,其中每个残差块包含一个跳跃连接(skip connection),将输入直接添加到输出中。这个跳跃连接允许梯度直接通过残差块传递,避免了梯度的消失问题。通过残差块的堆叠,ResNet可以构建非常深的网络,如ResNet-50、ResNet-101等。

ResNet的提出极大地促进了深度神经网络的发展。它在多个视觉任务上取得了非常好的性能,成为了目标检测、图像分类、图像分割等领域的重要基准模型。同时,ResNet的思想也影响了后续的深度神经网络架构设计,被广泛应用于各种深度学习任务中。

2. SelfAttention 层

自注意机制基于Vaswani等人在2017年提出的变压器架构。它计算所有输入单词的嵌入加权和,其中权重由每个单词与序列中其他单词的相关性决定。这些权重是通过嵌入之间的一系列点积运算计算的,然后是一个softmax函数来归一化权重。

与传统的序列模型相比,自注意机制有几个优点。它允许模型更有效地捕获长距离依赖关系,因为序列中任何单词的信息都可以直接影响任何其他单词的表示。它还支持并行计算,因为可以为每个单词独立计算注意力权重。这使得自我关注模型高效且可扩展。

python 实现的代码如下:

# 定义自注意力层
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        query = self.query_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, height * width)
        energy = torch.bmm(query, key)
        attention = torch.softmax(energy, dim=-1)
        value = self.value_conv(x).view(batch_size, -1, height * width)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)
        out = self.gamma * out + x
        return out

3. ResNet34 + SelfAttention

这里只对resnet34做了添加,事实上其他版本的resnet网络添加自注意力机制是一样的,只需要把resnet34换成52、101之类的即可

关键代码如下:

添加后的效果如下:

4. 遥感卫星下的土地使用情况分类

下载链接在下面:

Resnet网络改进实战(添加SelfAttention自注意力机制):遥感卫星下的土地利用图像分类资源-CSDN文库

解压后的完整目录如下,data是数据集,runs是训练好的结果

4.1 土地使用情况数据集 

总共有21类别,分别放在不同的目录下,训练集有1470张图片,验证集有630张数据

标签类别如下:

{
    "0": "agricultural",
    "1": "airplane",
    "2": "baseballdiamond",
    "3": "beach",
    "4": "buildings",
    "5": "chaparral",
    "6": "denseresidential",
    "7": "forest",
    "8": "freeway",
    "9": "golfcourse",
    "10": "harbor",
    "11": "intersection",
    "12": "mediumresidential",
    "13": "mobilehomepark",
    "14": "overpass",
    "15": "parkinglot",
    "16": "river",
    "17": "runway",
    "18": "sparseresidential",
    "19": "storagetanks",
    "20": "tenniscourt"
}

可视化结果:

4.2 训练

这里训练了30个epoch,参数如下:

    "train parameters": {
        "model": "resnet34",
        "pretrained": true,
        "freeze_layers": true,
        "batch_size": 8,
        "epochs": 30,
        "optim": "SGD",
        "lr": 0.001,
        "lrf": 0.0001
    },
    "Datasets": {
        "trainSets number": 1470,
        "validSets number": 630
    },
    "model": {
        "total parameters": 21731845.0,
        "train parameters": 621001,
        "flops": 3742463488.0
    },

想要更改训练超参数的可以在train脚本更改

4.3 训练结果

这里最后一轮的指标如下:

    "epoch:29": {
        "train info": {
            "accuracy": 0.9836734693810635,
            "agricultural": {
                "Precision": 1.0,
                "Recall": 0.9857,
                "Specificity": 1.0,
                "F1 score": 0.9928
            },
            "airplane": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "baseballdiamond": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "beach": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "buildings": {
                "Precision": 0.9857,
                "Recall": 0.9857,
                "Specificity": 0.9993,
                "F1 score": 0.9857
            },
            "chaparral": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "denseresidential": {
                "Precision": 0.9286,
                "Recall": 0.9286,
                "Specificity": 0.9964,
                "F1 score": 0.9286
            },
            "forest": {
                "Precision": 0.9722,
                "Recall": 1.0,
                "Specificity": 0.9986,
                "F1 score": 0.9859
            },
            "freeway": {
                "Precision": 0.971,
                "Recall": 0.9571,
                "Specificity": 0.9986,
                "F1 score": 0.964
            },
            "golfcourse": {
                "Precision": 0.9853,
                "Recall": 0.9571,
                "Specificity": 0.9993,
                "F1 score": 0.971
            },
            "harbor": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "intersection": {
                "Precision": 1.0,
                "Recall": 0.9857,
                "Specificity": 1.0,
                "F1 score": 0.9928
            },
            "mediumresidential": {
                "Precision": 0.9559,
                "Recall": 0.9286,
                "Specificity": 0.9979,
                "F1 score": 0.9421
            },
            "mobilehomepark": {
                "Precision": 0.9718,
                "Recall": 0.9857,
                "Specificity": 0.9986,
                "F1 score": 0.9787
            },
            "overpass": {
                "Precision": 0.9577,
                "Recall": 0.9714,
                "Specificity": 0.9979,
                "F1 score": 0.9645
            },
            "parkinglot": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "river": {
                "Precision": 0.9718,
                "Recall": 0.9857,
                "Specificity": 0.9986,
                "F1 score": 0.9787
            },
            "runway": {
                "Precision": 0.9722,
                "Recall": 1.0,
                "Specificity": 0.9986,
                "F1 score": 0.9859
            },
            "sparseresidential": {
                "Precision": 0.9859,
                "Recall": 1.0,
                "Specificity": 0.9993,
                "F1 score": 0.9929
            },
            "storagetanks": {
                "Precision": 1.0,
                "Recall": 0.9857,
                "Specificity": 1.0,
                "F1 score": 0.9928
            },
            "tenniscourt": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "mean precision": 0.9837190476190478,
            "mean recall": 0.9836666666666668,
            "mean specificity": 0.9991952380952381,
            "mean f1 score": 0.9836380952380953
        },
        "valid info": {
            "accuracy": 0.8571428571292516,
            "agricultural": {
                "Precision": 0.8437,
                "Recall": 0.9,
                "Specificity": 0.9917,
                "F1 score": 0.8709
            },
            "airplane": {
                "Precision": 1.0,
                "Recall": 0.9667,
                "Specificity": 1.0,
                "F1 score": 0.9831
            },
            "baseballdiamond": {
                "Precision": 0.8529,
                "Recall": 0.9667,
                "Specificity": 0.9917,
                "F1 score": 0.9062
            },
            "beach": {
                "Precision": 0.7692,
                "Recall": 1.0,
                "Specificity": 0.985,
                "F1 score": 0.8695
            },
            "buildings": {
                "Precision": 0.7714,
                "Recall": 0.9,
                "Specificity": 0.9867,
                "F1 score": 0.8308
            },
            "chaparral": {
                "Precision": 0.9062,
                "Recall": 0.9667,
                "Specificity": 0.995,
                "F1 score": 0.9355
            },
            "denseresidential": {
                "Precision": 0.72,
                "Recall": 0.6,
                "Specificity": 0.9883,
                "F1 score": 0.6545
            },
            "forest": {
                "Precision": 0.8788,
                "Recall": 0.9667,
                "Specificity": 0.9933,
                "F1 score": 0.9207
            },
            "freeway": {
                "Precision": 0.7241,
                "Recall": 0.7,
                "Specificity": 0.9867,
                "F1 score": 0.7118
            },
            "golfcourse": {
                "Precision": 0.8387,
                "Recall": 0.8667,
                "Specificity": 0.9917,
                "F1 score": 0.8525
            },
            "harbor": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "intersection": {
                "Precision": 0.8889,
                "Recall": 0.8,
                "Specificity": 0.995,
                "F1 score": 0.8421
            },
            "mediumresidential": {
                "Precision": 0.8077,
                "Recall": 0.7,
                "Specificity": 0.9917,
                "F1 score": 0.75
            },
            "mobilehomepark": {
                "Precision": 0.8437,
                "Recall": 0.9,
                "Specificity": 0.9917,
                "F1 score": 0.8709
            },
            "overpass": {
                "Precision": 0.6897,
                "Recall": 0.6667,
                "Specificity": 0.985,
                "F1 score": 0.678
            },
            "parkinglot": {
                "Precision": 0.9355,
                "Recall": 0.9667,
                "Specificity": 0.9967,
                "F1 score": 0.9508
            },
            "river": {
                "Precision": 0.9,
                "Recall": 0.6,
                "Specificity": 0.9967,
                "F1 score": 0.72
            },
            "runway": {
                "Precision": 0.8571,
                "Recall": 1.0,
                "Specificity": 0.9917,
                "F1 score": 0.9231
            },
            "sparseresidential": {
                "Precision": 0.9,
                "Recall": 0.9,
                "Specificity": 0.995,
                "F1 score": 0.9
            },
            "storagetanks": {
                "Precision": 0.92,
                "Recall": 0.7667,
                "Specificity": 0.9967,
                "F1 score": 0.8364
            },
            "tenniscourt": {
                "Precision": 1.0,
                "Recall": 0.8667,
                "Specificity": 1.0,
                "F1 score": 0.9286
            },
            "mean precision": 0.8594095238095237,
            "mean recall": 0.857157142857143,
            "mean specificity": 0.9928714285714286,
            "mean f1 score": 0.8540666666666668
        }
    }

曲线图:

混淆矩阵:

 

4.4 推理

推理结果如下:

想要更换数据集训练的话,参考readme文件即可

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

听风吹等浪起

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值