深度学习模型训练河流遥感图像分割数据集_进行像素级别分类,以识别图像中的河流区域

使用深度学习模型训练河流遥感图像分割数据集_进行像素级别分类,以识别图像中的河流区域

河流遥感图像分割数据集,8975张400*400数据集,训练集5385,验证测试1795
在这里插入图片描述

对于河流遥感图像分割任务,使用深度学习模型进行像素级别的分类,以识别图像中的河流区域。在这里插入图片描述

如何准备和处理这些数据,选择合适的模型,以及如何训练和评估这个模型?

以下文字及代码仅供参考。

数据准备

在这里插入图片描述

假设你的数据集结构如下:

river_remote_sensing/
├── images/
│   ├── train/
│   │   ├── img1.png
│   │   ├── img2.png
│   │   └── ...
│   ├── val/
│   │   ├── img1.png
│   │   ├── img2.png
│   │   └── ...
│   └── test/
│       ├── img1.png
│       ├── img2.png
│       └── ...
└── masks/
    ├── train/
    │   ├── mask1.png
    │   ├── mask2.png
    │   └── ...
    ├── val/
    │   ├── mask1.png
    │   ├── mask2.png
    │   └── ...
    └── test/
        ├── mask1.png
        ├── mask2.png
        └── ...

每个mask文件应与对应的image文件名相同,但内容为标签信息,指示每个像素属于河流还是背景。

创建一个配置文件 data_river.yaml 来描述数据集路径和类别信息:

train: ./river_remote_sensing/images/train/
val: ./river_remote_sensing/images/val/
test: ./river_remote_sensing/images/test/

nc: 2  # 类别数量:河流和其他
names: ['background', 'river']

seg: 
  mode: 'custom'
  images_dir: './river_remote_sensing/images/'
  masks_dir: './river_remote_sensing/masks/'

模型选择与训练

考虑到语义分割任务,U-Net、DeepLab系列或Segmentation Models库中提供的其他先进模型都是不错的选择。这里我们以Segmentation Models库为例说明。

安装依赖库

确保安装了必要的库:

pip install segmentation-models-pytorch albumentations opencv-python-headless torch torchvision
训练脚本

编写训练脚本:

import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
import torch
from torch import nn
from dataset import RiverDataset  # 假设你已实现了一个自定义的数据集类

# 自定义数据集类示例
class RiverDataset:
    def __init__(self, images_dir, masks_dir):
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in os.listdir(images_dir)]
        self.masks_fps = [os.path.join(masks_dir, mask_id) for mask_id in os.listdir(masks_dir)]
    
    def __getitem__(self, i):
        # 读取图片和掩码
        image = cv2.imread(self.images_fps[i])
        mask = cv2.imread(self.masks_fps[i], 0)
        return image, mask
    
    def __len__(self):
        return len(self.images_fps)

ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['river']
ACTIVATION = 'sigmoid'  # 可能根据需要调整

model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

train_dataset = RiverDataset('./river_remote_sensing/images/train/', './river_remote_sensing/masks/train/')
valid_dataset = RiverDataset('./river_remote_sensing/images/val/', './river_remote_sensing/masks/val/')

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False, num_workers=4)

loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0001),
])

train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device='cuda',
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device='cuda',
    verbose=True,
)

max_score = 0

for i in range(0, 40):  # 训练周期数
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model.pth')
        print('Model saved!')

    if i == 25:
        optimizer.param_groups[0]['lr'] /= 10
        print('Decrease decoder learning rate to 1e-5!')

模型评估与预测

训练完成后,可以加载最佳模型对新图像进行预测并评估模型性能:

best_model = torch.load('./best_model.pth')

test_dataset = RiverDataset('./river_remote_sensing/images/test/', './river_remote_sensing/masks/test/')
test_dataloader = DataLoader(test_dataset)

# 预测逻辑
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值