基于MFT的遥感图像分类任务(代码解读)

前前言:自己写的一个遥感图像分类任务训练,协助师兄师姐进行对比算法的编译,还蛮有成就感的。

前言:

  1. Data_Process.py:包含了用于处理和准备训练数据的类。MyData 类和 MyData1 类用于创建数据集对象,这些对象能够根据提供的标签和位置信息从原始图像中裁剪出小块图像,并将其转换为模型训练所需的格式。

  2. MyMFT.py:定义了模型的架构,包括特征提取、注意力机制和分类头。模型设计用于处理两种类型的输入数据:多光谱图像和PAN图像,并输出分类结果。

  3. Train.py:包含了训练循环、模型测试和评估的代码。它首先加载和预处理MS4和PAN图像数据,然后创建训练和测试数据加载器。接着,它实例化 MFT 模型,定义损失函数和优化器,并执行训练循环。训练完成后,它还进行模型测试,保存模型权重,并使用模型对所有数据进行上色,最后计算混淆矩阵和分类性能指标。

代码解读:

1.Data_Process.py

        代码定义了两个 PyTorch Dataset 类,它们用于封装数据准备和获取机制,以便在模型训练和评估过程中使用。以下是对这两个类的具体解读:

MyData 类是为了处理带有标签的训练数据集。

class MyData(Dataset):
    def __init__(self, MS4, Pan, Label, xy, cut_size):
        self.train_data1 = MS4 #MS4: 多光谱图像数据,一个包含多光谱通道的三维数组。
        self.train_data2 = Pan #Pan: PAN 图像数据,通常 PAN 图像的空间分辨率是多光谱图像的四倍。
        self.train_labels = Label #Label: 对应的标签数据,一个包含分类标签的数组。
        self.gt_xy = xy #得到一个列表,包含图像小块的中心坐标(x, y)。
        self.cut_ms_size = cut_size  #定义裁剪图像块的大小。
        self.cut_pan_ms_size = cut_size * 4  #该方法还计算了 PAN 图像的裁剪尺寸,它是多光谱图像裁剪尺寸的四倍。

    def __getitem__(self, index):
        # 同样用index裁剪I与HS,计算 PAN 图像的左上角坐标(x_pan,y_pan),这是通过将多光谱图像的坐标乘以 4 得到的。
        x_ms, y_ms = self.gt_xy[index]
        x_pan = int(4 * x_ms)  # 计算不可以在切片过程中进行
        y_pan = int(4 * y_ms)
        #从MS4和Pan中裁剪出对应于索引index的图像块。
        image_ms = self.train_data1[:, x_ms:x_ms + self.cut_ms_size, y_ms:y_ms + self.cut_ms_size]
        image_pan = self.train_data2[:, x_pan:x_pan + self.cut_pan_ms_size, y_pan:y_pan + self.cut_pan_ms_size]
        locate_xy = self.gt_xy[index]
        target = self.train_labels[index]
        #返回裁剪后的多光谱图像块、PAN 图像块、对应的标签和坐标信息。
        return image_ms, image_pan, target, locate_xy

    def __len__(self):
        #回数据集中样本的总数,等于 gt_xy 列表的长度。
        return len(self.gt_xy)

MyData1 类与 MyData 类似,但它不包含标签信息,用于无监督学习或模型评估阶段。

class MyData1(Dataset):
    def __init__(self, MS4, Pan, xy, cut_size):
        self.train_data1 = MS4 # 多光谱图像数据
        self.train_data2 = Pan # PAN 图像数据
        self.gt_xy = xy # 存储图像小块的中心坐标 (x, y)
        self.cut_ms_size = cut_size # 多光谱图像裁剪块的大小
        self.cut_pan_size = cut_size * 4 # PAN 图像裁剪块的大小,是多光谱图像裁剪块大小的四倍

    def __getitem__(self, index):
        x_ms, y_ms = self.gt_xy[index] # 获取索引对应的坐标点
        # 计算 PAN 图像裁剪的起始 x,y 坐标,是多光谱图像的四倍
        x_pan = int(4 * x_ms)  # 计算不可以在切片过程中进行
        y_pan = int(4 * y_ms)
        # 根据计算的坐标和裁剪大小,从图像数据中裁剪图像块
        image_ms = self.train_data1[:, x_ms:x_ms + self.cut_ms_size,
                   y_ms:y_ms + self.cut_ms_size]
        image_pan = self.train_data2[:, x_pan:x_pan + self.cut_pan_size,
                    y_pan:y_pan + self.cut_pan_size]
        locate_xy = self.gt_xy[index]  # 存储当前样本的中心坐标
        # 返回裁剪后的多光谱图像块、PAN 图像块和中心坐标
        return image_ms, image_pan, locate_xy

    def __len__(self):
        # 数据集的样本总数等于存储的中心坐标列表的长度
        return len(self.gt_xy)

2.MyMFT.py

        该部分定义了一个名为 MFT 的深度学习模型,该模型利用多头注意力机制和Transformer编码器来处理多光谱图像和光达(LiDAR)数据。

首先导入了PyTorch库中的一些基础模块,包括神经网络模块(torch.nn)、函数库(torch.nn.functional)、层归一化(LayerNorm)、线性层(Linear)、Dropout(Dropout)和Softmax(Softmax)。此外,还导入了copy模块用于复制对象,以及einops库中的rearrangerepeat函数,这些函数用于张量的维度变换。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm, Linear, Dropout, Softmax
import copy
from einops import rearrange, repeat
patchsize = 16

INF 函数创建了一个无穷大的矩阵,用于在注意力机制中作为掩码,以屏蔽不需要关注的部分。

def INF(B, H, W):
    return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)

HetConv 是一个自定义的卷积模块,它结合了分组卷积(nn.Conv2dgroups 参数)和逐点卷积(nn.Conv2dkernel_size=1)。逐点卷积通常用于调整通道数,而分组卷积有助于提取特征。

class HetConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, bias=None, p=64, g=64):
        super(HetConv, self).__init__()
        # 计算分组数,确保输入通道数能够被分组数整除
        groups_g = min(g, in_channels)  # 分组数不超过输入通道数
        groups_p = min(p, in_channels)  # 分组数不超过输入通道数
        self.gwc = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, groups=4, padding=kernel_size // 3,
                             stride=stride)
        # Pointwise Convolution
        self.pwc = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=4, stride=stride)
        self.out_channels = out_channels

    def forward(self, x):
        return self.gwc(x) + self.pwc(x)

MCrossAttention 是一个多头自注意力机制的实现。它使用线性层生成查询(Q)、键(K)和值(V),然后计算注意力权重,并将这些权重应用到值上以获取输出。

class MCrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.1, proj_drop=0.1):
  • 18
    点赞
  • 51
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值