问题分析
遥感图像的场景分类属于一个多分类问题,毕竟不可能只有两个场景,数据集可以直接获取,pytorch提供了一些图像分类相关的模型如ResNet、VGG、Inception等网络,可以直接获取,当然也可以自己设计,此处我们直接使用ResNet50版本,需要注意的是,需要针对使用的数据集的具体分类数调整ResNet50的num_classes参数,即控制输出通道数以匹配自己使用的数据集的类别数。
日志配置
日志系统是一个底层配置,它需要贯穿于数据处理、模型训练、验证的诸多过程当中,用于记录数据处理操作、模型训练误差变化、验证误差变化以及学习率等指标的变化情况,后续我们需要根据这些数据的变化情况查找问题或进行调优。比较常用的是输出到一个日志文件中(默认日志配置),或使用TensorBoard,相比于TensorBoard,个人更喜欢使用Wandb记录数据的变化情况,Wandb配置以及记录数据非常方便,并且提供网页可以实时监测数据变化情况,本文联合使用默认日志配置以及Wandb日志配置。
默认日志配置
默认日志配置将日志信息保存到一个log或txt文件当中,可以避免日志丢失问题,具体配置代码如下:
import logging
import os
import sys
__all__ = ['setup_logger']
logger_initialized = []
def setup_logger(name='default', output='output.log'):
logger = logging.getLogger(name)
if logger in logger_initialized:
return logger
logger.setLevel(logging.INFO)
logger.propagate = False
formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S")
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
ch.setFormatter(formatter)
logger.addHandler(ch)
if output is not None:
if output.endswith('.txt') or output.endswith('.log'):
filename = output
else:
filename = os.path.join(output, 'log.txt')
if not os.path.exists(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename))
fh = logging.FileHandler(filename, mode='a')
fh.setLevel(logging.DEBUG)
fh.setFormatter(logging.Formatter())
logger.addHandler(fh)
logger_initialized.append(logger)
return logger
Wandb日志配置
Wandb日志的优点是简单、直观,首先需要注册一个wandb账号(也可以匿名登录,但注册一个也不会很麻烦,而且便于管理)。具体配置代码如下:
import wandb
# 初始化Wandb日志
# wandb_key需要根据自己的wandb账号配置
wandb_key = ''
wandb.login(key=wandb_key)
wandb.init(project="landuse-scene-classification",
# config参数可以根据自己的需要进行记录,不会影响训练
config={
"batch_size": batch_size,
"epochs": epochs,
"lr": lr,
"optimizer": "Adam",
"lr_scheduler": "StepLR",
"lr_scheduler_step_size": 10,
"lr_scheduler_gamma": 0.1,
"optimizer_lr": lr,
"optimizer_weight_decay": 0.0,
"optimizer_momentum": 0.9,
"optimizer_nesterov": False,
"optimizer_amsgrad":False,
"loss": "CrossEntropyLoss"
})
数据准备
本文使用Land use数据集,该数据集共包含21个类别,图像的分辨率为256×256,数据集已经划分好,因此可以略过数据划分流程。
- 首先,我们需要创建一个数据集对象,并实现
__init__()
、__len__()
以及__getitem__()
方法 - 创建一个属性或者方法保存数据集的各个类别的类名称,因为在预测时预测的都是类别的id,id并不能直接反应它到底是哪个类别,即你能看懂但别人不知道,即需要返回的是一个类别名称才能够让别人看懂,另外可视化时也需要知道类别名称
import os.path
from glob import glob
import cv2
import pandas as pd
from torch.utils.data import Dataset
class LandUseDataset(Dataset):
"""
LandUse场景分类数据集
数据集链接:https://www.kaggle.com/datasets/apollo2506/landuse-scene-classification
"""
CLASS_NAMES = ['agricultural',
'airplane',
'baseballdiamond',
'beach',
'buildings',
'chaparral',
'denseresidential',
'forest',
'freeway',
'golfcourse',
'harbor',
'intersection',
'mediumresidential',
'mobilehomepark',
'overpass',
'parkinglot',
'river',
'runway',
'sparseresidential',
'storagetanks',
'tenniscourt']
def __init__(self, img_dir, ann_file, transforms=None):
self.img_dir = img_dir
self.img_list = []
# 读取图像列表
self._get_img_list(img_dir)
# 读取标签数据,标签保存在CSV文件中,使用pandas库读取
self.ann_data = pd.read_csv(ann_file)
# 数据增强变换
self.transforms = transforms
def __len__(self):
return len(self.img_list)
def __getitem__(self, index):
img_p = self.img_list[index]
# 获取图像的名称,用于更具名称获取对应的标签
img_name = img_p.split(self.img_dir + '/')[-1]
# 使用opencv读取图像,可以直接使用PIL库
img_arr = cv2.imread(img_p)
# opencv读取图像默认通道顺序为BGR,需要将其转换为RGB
img_arr = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB)
ann = self.ann_data[img_name]
ann_id = ann['Label']
if self.transforms is not None:
img_arr = self.transforms(img_arr)
return img_arr, ann_id
def _get_img_list(self, img_dir):
# 使用递归的方式读取所有的图像
img_path = glob(img_dir + '/*')
for p in img_path:
if os.path.isdir(p):
self._get_img_list(p)
else:
self.img_list.append(p)
@staticmethod
def cid2cname(cid):
# 用于获取类别id所对应的类别名称
return LandUseDataset.CLASS_NAMES[cid]
数据划分
略过
数据增强
训练数据增强
from torchvision.transforms import transforms
train_transform = transforms.Compose([
# 后续的增强策略需要是一个PIL image,因此需要首先将输入图像转化为PIL形式
transforms.ToPILImage(),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(10),
# 一方面控制一个batch的图像为相同的大小,否则会报错,另一方面减少计算代价
transforms.Resize((224, 224)),
# 对于后面这两个基本算是固定的,必须有,且顺序不能改变
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
验证数据增强
# 验证或者测试时不进行翻转或旋转处理,但后两步需要同训练处理一致,且数据参数不能更改,
# 否则相当于改变了数据的分布情况,于是训练数据与验证数据不是处于相同的分布,模型对于验证集来说
# 不会起作用,或者效果差,变化情况可以自己修改数据,观察结果,但图像缩放对结果不会有什么影响,
# 因为卷积操作具备平移不变性
eval_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
创建数据集以及数据加载器
# batch_size见下方配置
train_dataset = LandUseDataset(img_dir='/kaggle/input/landuse-scene-classification/images_train_test_val/train',
ann_file='/kaggle/input/landuse-scene-classification/train.csv',
transforms=train_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
eval_dataset = LandUseDataset(img_dir='/kaggle/input/landuse-scene-classification/images_train_test_val/validation',
ann_file='/kaggle/input/landuse-scene-classification/validation.csv',
transforms=eval_transform)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
训练设置
超参数配置
batch_size = 32
epochs = 100
lr = 0.001
# 用于控制保存的checkpoint的总数
checkpoint_save_num = 5
# 随机数种子
seed = 42
device = 'cuda' if torch.cuda.is_available() else 'cpu'
全局配置
# 设置随机种子
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.random.manual_seed(seed)
torch.cuda.random.manual_seed_all(seed)
# 使用确定性算法
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
模型
from torchvision.models import resnet50
model = resnet50(num_classes=21)
model.to(device)
损失函数
# 分类问题一般使用交叉熵损失函数
loss_fn = torch.nn.CrossEntropyLoss()
loss_fn.to(device)
优化器
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
学习率调度器
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
验证设置
评估指标
from sklearn.metrics import classification_report
# classification_report方法可以获取分类问题的全面指标
# 使用from sklearn import metrics可能会报错,说找不到metrics模块,
# 原因可能是sklearn下有两个metrics模块,如果上面这样导入还是有问题,需要更新scikit-learn版本
checkpoint
# 根据以下信息确定保存最优模型
best_f1 = 0.
best_model_epoch = None
# 当前已保存的checkpoint数
checkpoint_num = 0
# checkpoint保存的路径
checkpoint_root_path = 'checkpoint'
if not os.path.exists(checkpoint_root_path):
os.makedirs(checkpoint_root_path)
模型训练
logger = setup_logger()
for epoch in range(epochs):
# 当前epoch的训练流程
model.train()
total_loss = 0.0
for batch in train_loader:
inputs, labels = batch
inputs = inputs.to(device)
labels = labels.to(device).squeeze(-1).long()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
total_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
avg_loss = total_loss / len(train_loader)
logger.info(f"Epoch {epoch + 1}, Train Loss: {avg_loss}")
# 当前epoch的验证流程
model.eval()
with torch.no_grad():
eval_total_loss = 0.0
predict_list = None
label_list = None
for eval_batch in eval_loader:
eval_inputs, eval_labels = eval_batch
eval_inputs = eval_inputs.to(device)
eval_labels = eval_labels.to(device).squeeze().long()
eval_outputs = model(eval_inputs)
eval_loss = loss_fn(eval_outputs, eval_labels)
eval_total_loss += eval_loss.item()
eval_outputs = torch.argmax(eval_outputs, dim=1)
if predict_list is None:
predict_list = eval_outputs
label_list = eval_labels
else:
predict_list = torch.cat((predict_list, eval_outputs), dim=0)
label_list = torch.cat((label_list, eval_labels), dim=0)
eval_avg_loss = eval_total_loss / len(eval_loader)
wandb.log({'train_loss': avg_loss, 'eval_loss': eval_avg_loss})
logger.info(f"Epoch {epoch + 1}, Eval Loss: {eval_avg_loss}")
# 获取验证指标,具体怎么使用推荐看方法说明
metrics_dict = classification_report(label_list.cpu(), predict_list.cpu(), output_dict=True)
metrics_dict = metrics_dict.get('macro avg')
recall, precision, f1 = metrics_dict.get('recall'), metrics_dict.get('precision'), metrics_dict.get('f1-score')
wandb.log({'recall': recall, 'precision': precision, 'f1': f1})
logger.info('Epoch {}, Recall: {}, Precision: {}, F1: {}'.format(epoch + 1, '%.4f' % recall, '%.4f' % precision,
'%.4f' % f1))
# 判断当前epoch训练后的模型是否最优,如果是,则保存并更新最优指标
best_model_path = checkpoint_root_path + '/best_model.pth'
if f1 > best_f1:
best_f1 = f1
torch.save(model.state_dict(), best_model_path)
best_model_epoch = epoch + 1
logger.info('save best model to {}'.format(os.path.abspath(best_model_path)))
# 保存当前checkpoint,具体保存哪些数据更具自己的需求定制,但一般前四个以及最后一个少不了
checkpoint = dict(
model=model.state_dict(),
optimizer=optimizer.state_dict(),
lr_scheduler=lr_scheduler.state_dict(),
epoch=epoch,
loss=avg_loss,
recall=recall,
precision=precision,
f1=f1
best_f1=best_f1
)
checkpoint_path = checkpoint_root_path + f'/checkpoint_{epoch + 1}.pth'
torch.save(checkpoint, checkpoint_path)
logger.info(f"Epoch {epoch + 1}, Model saved to {os.path.abspath(checkpoint_path)}")
checkpoint_num += 1
# 如果保存的checkpoint数超过了上限,移除最先开始保存的checkpoint
if checkpoint_num > checkpoint_save_num:
remove_path = 'checkpoint/checkpoint_{}.pth'.format(epoch + 1 - checkpoint_save_num)
os.remove(os.path.abspath(remove_path))
checkpoint_num -= 1
logger.info(f"Epoch {epoch + 1}, Model removed from {os.path.abspath(remove_path)}")
if best_model_epoch is not None:
wandb.summary['best_model_epoch'] = best_model_epoch
wandb.summary['best_f1'] = best_f1
logger.info('Best model epoch is {}, best f1 score is {}'.format(best_model_epoch, '%.4f' % best_f1))