高光谱图像分类

本文介绍了一种基于PyTorch的高光谱图像分类方法,详细阐述了数据获取、环境配置、网络结构设计、数据集创建、训练过程以及测试和分类结果展示。在网络修改部分,探讨了训练与测试模式的区分以及引入注意力机制来提升图像分类的准确性,通过自注意力模型优化网络表现。
摘要由CSDN通过智能技术生成

代码测试

本文是基于colab平台的代码测试,以下是测试的关键步骤截图

1.首先获得数据并配置相关环境

! wget http://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat
! wget http://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat
! pip install spectral

2.导入必要的包

import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score
import spectral
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

网络结构图
在这里插入图片描述
3.定义 HybridSN 类

class_num = 16

class HybridSN(nn.Module):
  def __init__(self):
    super(HybridSN, self).__init__()
    self.conv1 = nn.Conv3d(1, 8, kernel_size=(7,3,3), stride=1, padding=0)
    self.conv2 = nn.Conv3d(8, 16, kernel_size=(5,3,3), stride=1, padding=0)
    self.conv3 = nn.Conv3d(16, 32, kernel_size=(3,3,3), stride=1, padding=0)
    self.conv4 = nn.Conv2d(576, 64, kernel_size=(3,3), stride=1, padding=0)
    self.fc1 = nn.Linear(18496, 256)
    self.fc2 = nn.Linear(256, 128)
    self.fc3 = nn.Linear(128, 16)
    self.dropout = nn.Dropout(0.4)

  def forward(self, x):
    x = self.conv1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = F.relu(x)
    x = self.conv3(x)
    x = F.relu(x)
    x = x.reshape(x.shape[0],-1,19,19)
    x = self.conv4(x)
    x = F.relu(x)
    x = x.flatten(start_dim = 1)
    x = self.fc1(x)
    x = self.dropout(x)
    x = F.relu(x)
    x = self.fc2(x)
    x = self.dropout(x)
    x = F.relu(x)
    x = self.fc3(x)
    return x

4.创建数据集

   首先对高光谱数据实施PCA降维;然后创建 keras 方便处理的数据格式;然后随机抽取 10% 数据做为训练集,剩余的做为测试集。
   首先定义基本函数
 对高光谱数据 X 应用 PCA 变换
def applyPCA(X, numComponents):
    newX = np.reshape(X, (-1, X.shape[2]))
    pca = PCA(n_components=numComponents, whiten=True)
    newX = pca.fit_transform(newX)
    newX = np.reshape(newX, (X.shape[0], X.shape[1], numComponents))
    return newX

# 对单个像素周围提取 patch 时,边缘像素就无法取了,因此,给这部分像素进行 padding 操作
def padWithZeros(X, margin=2
  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值