Pytorch实现遥感图像场景分类

问题分析

遥感图像的场景分类属于一个多分类问题,毕竟不可能只有两个场景,数据集可以直接获取,pytorch提供了一些图像分类相关的模型如ResNet、VGG、Inception等网络,可以直接获取,当然也可以自己设计,此处我们直接使用ResNet50版本,需要注意的是,需要针对使用的数据集的具体分类数调整ResNet50的num_classes参数,即控制输出通道数以匹配自己使用的数据集的类别数。

日志配置

日志系统是一个底层配置,它需要贯穿于数据处理、模型训练、验证的诸多过程当中,用于记录数据处理操作、模型训练误差变化、验证误差变化以及学习率等指标的变化情况,后续我们需要根据这些数据的变化情况查找问题或进行调优。比较常用的是输出到一个日志文件中(默认日志配置),或使用TensorBoard,相比于TensorBoard,个人更喜欢使用Wandb记录数据的变化情况,Wandb配置以及记录数据非常方便,并且提供网页可以实时监测数据变化情况,本文联合使用默认日志配置以及Wandb日志配置。

默认日志配置

默认日志配置将日志信息保存到一个log或txt文件当中,可以避免日志丢失问题,具体配置代码如下:

import logging
import os
import sys

__all__ = ['setup_logger']

logger_initialized = []


def setup_logger(name='default', output='output.log'):
    logger = logging.getLogger(name)
    if logger in logger_initialized:
        return logger
    logger.setLevel(logging.INFO)
    logger.propagate = False
    formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S")
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    if output is not None:
        if output.endswith('.txt') or output.endswith('.log'):
            filename = output
        else:
            filename = os.path.join(output, 'log.txt')
        if not os.path.exists(os.path.dirname(filename)):
            os.makedirs(os.path.dirname(filename))
        fh = logging.FileHandler(filename, mode='a')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter())
        logger.addHandler(fh)
    logger_initialized.append(logger)
    return logger

Wandb日志配置

Wandb日志的优点是简单、直观,首先需要注册一个wandb账号(也可以匿名登录,但注册一个也不会很麻烦,而且便于管理)。具体配置代码如下:

import wandb

# 初始化Wandb日志
# wandb_key需要根据自己的wandb账号配置
wandb_key = ''
wandb.login(key=wandb_key)
wandb.init(project="landuse-scene-classification",
           # config参数可以根据自己的需要进行记录,不会影响训练
           config={
               "batch_size": batch_size,
               "epochs": epochs,
               "lr": lr,
               "optimizer": "Adam",
               "lr_scheduler": "StepLR",
               "lr_scheduler_step_size": 10,
               "lr_scheduler_gamma": 0.1,
               "optimizer_lr": lr,
               "optimizer_weight_decay": 0.0,
               "optimizer_momentum": 0.9,
               "optimizer_nesterov": False,
               "optimizer_amsgrad":False,
               "loss": "CrossEntropyLoss"
           })

数据准备

本文使用Land use数据集,该数据集共包含21个类别,图像的分辨率为256×256,数据集已经划分好,因此可以略过数据划分流程。

  1. 首先,我们需要创建一个数据集对象,并实现__init__()__len__()以及__getitem__()方法
  2. 创建一个属性或者方法保存数据集的各个类别的类名称,因为在预测时预测的都是类别的id,id并不能直接反应它到底是哪个类别,即你能看懂但别人不知道,即需要返回的是一个类别名称才能够让别人看懂,另外可视化时也需要知道类别名称
import os.path
from glob import glob

import cv2
import pandas as pd
from torch.utils.data import Dataset


class LandUseDataset(Dataset):
    """
    LandUse场景分类数据集
    数据集链接:https://www.kaggle.com/datasets/apollo2506/landuse-scene-classification
    """

    CLASS_NAMES = ['agricultural',
                   'airplane',
                   'baseballdiamond',
                   'beach',
                   'buildings',
                   'chaparral',
                   'denseresidential',
                   'forest',
                   'freeway',
                   'golfcourse',
                   'harbor',
                   'intersection',
                   'mediumresidential',
                   'mobilehomepark',
                   'overpass',
                   'parkinglot',
                   'river',
                   'runway',
                   'sparseresidential',
                   'storagetanks',
                   'tenniscourt']

    def __init__(self, img_dir, ann_file, transforms=None):
        self.img_dir = img_dir
        self.img_list = []
        # 读取图像列表
        self._get_img_list(img_dir)
        # 读取标签数据,标签保存在CSV文件中,使用pandas库读取
        self.ann_data = pd.read_csv(ann_file)
        # 数据增强变换
        self.transforms = transforms

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        img_p = self.img_list[index]
        # 获取图像的名称,用于更具名称获取对应的标签
        img_name = img_p.split(self.img_dir + '/')[-1]
        # 使用opencv读取图像,可以直接使用PIL库
        img_arr = cv2.imread(img_p)
        # opencv读取图像默认通道顺序为BGR,需要将其转换为RGB
        img_arr = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB)
        ann = self.ann_data[img_name]
        ann_id = ann['Label']
        if self.transforms is not None:
            img_arr = self.transforms(img_arr)
        return img_arr, ann_id

    def _get_img_list(self, img_dir):
        # 使用递归的方式读取所有的图像
        img_path = glob(img_dir + '/*')
        for p in img_path:
            if os.path.isdir(p):
                self._get_img_list(p)
            else:
                self.img_list.append(p)

    @staticmethod
    def cid2cname(cid):
        # 用于获取类别id所对应的类别名称
        return LandUseDataset.CLASS_NAMES[cid]

数据划分

略过

数据增强

训练数据增强

from torchvision.transforms import transforms

train_transform = transforms.Compose([
        # 后续的增强策略需要是一个PIL image,因此需要首先将输入图像转化为PIL形式
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(10),
        # 一方面控制一个batch的图像为相同的大小,否则会报错,另一方面减少计算代价
        transforms.Resize((224, 224)),
        # 对于后面这两个基本算是固定的,必须有,且顺序不能改变
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

验证数据增强

# 验证或者测试时不进行翻转或旋转处理,但后两步需要同训练处理一致,且数据参数不能更改,
# 否则相当于改变了数据的分布情况,于是训练数据与验证数据不是处于相同的分布,模型对于验证集来说
# 不会起作用,或者效果差,变化情况可以自己修改数据,观察结果,但图像缩放对结果不会有什么影响,
# 因为卷积操作具备平移不变性
eval_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

创建数据集以及数据加载器

# batch_size见下方配置
train_dataset = LandUseDataset(img_dir='/kaggle/input/landuse-scene-classification/images_train_test_val/train',
                               ann_file='/kaggle/input/landuse-scene-classification/train.csv',
                               transforms=train_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

eval_dataset = LandUseDataset(img_dir='/kaggle/input/landuse-scene-classification/images_train_test_val/validation',
                              ann_file='/kaggle/input/landuse-scene-classification/validation.csv',
                              transforms=eval_transform)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

训练设置

超参数配置

batch_size = 32
epochs = 100
lr = 0.001
# 用于控制保存的checkpoint的总数
checkpoint_save_num = 5
# 随机数种子
seed = 42
device = 'cuda' if torch.cuda.is_available() else 'cpu'

全局配置

# 设置随机种子
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.random.manual_seed(seed)
torch.cuda.random.manual_seed_all(seed)
# 使用确定性算法
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

模型

from torchvision.models import resnet50

model = resnet50(num_classes=21)
model.to(device)

损失函数

# 分类问题一般使用交叉熵损失函数
loss_fn = torch.nn.CrossEntropyLoss()
loss_fn.to(device)

优化器

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

学习率调度器

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

验证设置

评估指标

from sklearn.metrics import classification_report
# classification_report方法可以获取分类问题的全面指标
# 使用from sklearn import metrics可能会报错,说找不到metrics模块,
# 原因可能是sklearn下有两个metrics模块,如果上面这样导入还是有问题,需要更新scikit-learn版本

checkpoint

# 根据以下信息确定保存最优模型
best_f1 = 0.
best_model_epoch = None
# 当前已保存的checkpoint数
checkpoint_num = 0
# checkpoint保存的路径
checkpoint_root_path = 'checkpoint'
if not os.path.exists(checkpoint_root_path):
    os.makedirs(checkpoint_root_path)

模型训练

logger = setup_logger()

for epoch in range(epochs):
    # 当前epoch的训练流程
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device).squeeze(-1).long()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    lr_scheduler.step()
    avg_loss = total_loss / len(train_loader)
    logger.info(f"Epoch {epoch + 1}, Train Loss: {avg_loss}")
    # 当前epoch的验证流程
    model.eval()
    with torch.no_grad():
        eval_total_loss = 0.0
        predict_list = None
        label_list = None
        for eval_batch in eval_loader:
            eval_inputs, eval_labels = eval_batch
            eval_inputs = eval_inputs.to(device)
            eval_labels = eval_labels.to(device).squeeze().long()
            eval_outputs = model(eval_inputs)
            eval_loss = loss_fn(eval_outputs, eval_labels)
            eval_total_loss += eval_loss.item()
            eval_outputs = torch.argmax(eval_outputs, dim=1)
            if predict_list is None:
                predict_list = eval_outputs
                label_list = eval_labels
            else:
                predict_list = torch.cat((predict_list, eval_outputs), dim=0)
                label_list = torch.cat((label_list, eval_labels), dim=0)
    eval_avg_loss = eval_total_loss / len(eval_loader)
    wandb.log({'train_loss': avg_loss, 'eval_loss': eval_avg_loss})
    logger.info(f"Epoch {epoch + 1}, Eval Loss: {eval_avg_loss}")
    # 获取验证指标,具体怎么使用推荐看方法说明
    metrics_dict = classification_report(label_list.cpu(), predict_list.cpu(), output_dict=True)

    metrics_dict = metrics_dict.get('macro avg')
    recall, precision, f1 = metrics_dict.get('recall'), metrics_dict.get('precision'), metrics_dict.get('f1-score')
    wandb.log({'recall': recall, 'precision': precision, 'f1': f1})
    logger.info('Epoch {}, Recall: {}, Precision: {}, F1: {}'.format(epoch + 1, '%.4f' % recall, '%.4f' % precision,
                                                                   '%.4f' % f1))
    
    # 判断当前epoch训练后的模型是否最优,如果是,则保存并更新最优指标
    best_model_path = checkpoint_root_path + '/best_model.pth'
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), best_model_path)
        best_model_epoch = epoch + 1
        logger.info('save best model to {}'.format(os.path.abspath(best_model_path)))
    # 保存当前checkpoint,具体保存哪些数据更具自己的需求定制,但一般前四个以及最后一个少不了
    checkpoint = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        epoch=epoch,
        loss=avg_loss,
        recall=recall,
        precision=precision,
        f1=f1
        best_f1=best_f1
    )
    checkpoint_path = checkpoint_root_path + f'/checkpoint_{epoch + 1}.pth'
    torch.save(checkpoint, checkpoint_path)
    logger.info(f"Epoch {epoch + 1}, Model saved to {os.path.abspath(checkpoint_path)}")
    checkpoint_num += 1
    # 如果保存的checkpoint数超过了上限,移除最先开始保存的checkpoint
    if checkpoint_num > checkpoint_save_num:
        remove_path = 'checkpoint/checkpoint_{}.pth'.format(epoch + 1 - checkpoint_save_num)
        os.remove(os.path.abspath(remove_path))
        checkpoint_num -= 1
        logger.info(f"Epoch {epoch + 1}, Model removed from {os.path.abspath(remove_path)}")

if best_model_epoch is not None:
    wandb.summary['best_model_epoch'] = best_model_epoch
    wandb.summary['best_f1'] = best_f1
    logger.info('Best model epoch is {}, best f1 score is {}'.format(best_model_epoch, '%.4f' % best_f1))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值