Feature Guided Masked Autoencoder for Self-supervised Learning in Remote Sensing

arXiv:2310.18653

Self-supervised learning guided by masked image modelling, such as Masked AutoEncoder (MAE), has attracted wide attention for pretraining vision transformers in remote sensing. However, MAE tends to excessively focus on pixel details, thereby limiting the model’s capacity for semantic understanding, in particular for noisy SAR images. In this paper, we explore spectral and spatial remote sensing image features as improved MAE-reconstruction targets. We first conduct a study on reconstructing various image features, all performing comparably well or better than raw pixels. Based on such observations, we propose Feature Guided Masked Autoencoder (FG-MAE): reconstructing a combination of Histograms of Oriented Gradients (HOG) and Normalized Difference Indices (NDI) for multispectral images, and reconstructing HOG for SAR images. Experimental results on three downstream tasks illustrate the effectiveness of FG-MAE with a particular boost for SAR imagery. Furthermore, we demonstrate the well-inherited scalability of FG-MAE and release a first series of pretrained vision transformers for medium resolution SAR and multispectral images.
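As a rough illustration of one of the reconstruction targets named above, a normalized difference index is simply a per-pixel ratio of two spectral bands. The sketch below is not taken from the paper's code; the function name, band choice (NIR/Red, i.e., NDVI), and epsilon term are illustrative assumptions.

```python
import numpy as np

def normalized_difference_index(band_a, band_b, eps=1e-6):
    """Per-pixel normalized difference between two bands.

    With band_a = NIR and band_b = Red this is NDVI; the small eps
    (an assumption here) avoids division by zero on dark pixels.
    """
    band_a = band_a.astype(np.float64)
    band_b = band_b.astype(np.float64)
    return (band_a - band_b) / (band_a + band_b + eps)

# Toy 2x2 NIR and Red patches
nir = np.array([[0.8, 0.6], [0.7, 0.5]])
red = np.array([[0.2, 0.3], [0.1, 0.4]])
ndvi = normalized_difference_index(nir, red)  # values in [-1, 1]
```

In FG-MAE, maps like this (stacked for several band pairs, together with HOG features) replace raw pixels as the decoder's regression target.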

### Training on a Custom Remote Sensing Dataset with a UNet Framework

To make effective use of unlabeled data and improve model performance in remote sensing image analysis, a model can be pretrained with self-supervised learning and then fine-tuned on a specific task via transfer learning [^1].

#### Data Preparation
Make sure you have a sufficient set of annotated remote sensing images as the training set. These images should be cut into patches small enough to feed the network, and each patch needs a corresponding label mask giving the class of every pixel.

#### Environment Setup
Install PyTorch and related dependencies, using a virtual environment to isolate package management between projects:

```bash
conda create -n unet python=3.8
conda activate unet
pip install torch torchvision torchaudio
```

#### Loading and Processing the Data
Write a script that collects the paths of the local images and their corresponding mask files, then implement a Dataset class inheriting from `torch.utils.data.Dataset` to build a custom loader.

```python
import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader


class CustomRemoteSensingDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(
            self.mask_dir, self.images[index].replace(".jpg", "_mask.png"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0  # map foreground value 255 to label 1

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask
```

#### Defining the UNet Architecture
Use a mature existing implementation such as `segmentation_models_pytorch`, or hand-code the UNet yourself. The former is recommended here because it offers more flexibility and optimization options.

```python
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",     # backbone encoder
    encoder_weights="imagenet",  # initialize with ImageNet weights
    in_channels=3,               # RGB input
    classes=1,                   # binary segmentation
)
```

#### Configuring the Training Loop
Set up the loss function (e.g., binary cross-entropy) and the optimizer (AdamW), and instantiate DataLoaders to fetch batches of samples. Also consider early stopping to guard against overfitting.

```python
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# train_dataset, val_dataset, and num_epochs are assumed to be defined above
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

for epoch in range(num_epochs):
    model.train()
    running_loss = []
    for images, masks in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        # add a channel dimension so masks match the (N, 1, H, W) logits
        loss = loss_fn(outputs, masks.unsqueeze(1))
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())
    avg_train_loss = sum(running_loss) / len(running_loss)
    print(f"Epoch {epoch}, Loss: {avg_train_loss}")
```
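The early stopping mentioned above is not shown in the training loop. A minimal sketch of how it could be wired in is below; the `EarlyStopping` class is a hypothetical helper written here, not part of PyTorch or `segmentation_models_pytorch`.

```python
class EarlyStopping:
    """Stop training when validation loss fails to improve for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0

    def step(self, val_loss):
        # Returns True when training should stop.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

# Toy usage with a made-up validation-loss curve
stopper = EarlyStopping(patience=3)
val_losses = [0.9, 0.8, 0.81, 0.82, 0.83]
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
```

In the real loop, `val_loss` would come from evaluating the model on `val_loader` with `model.eval()` and `torch.no_grad()` after each epoch, and you would typically save a checkpoint whenever `best_loss` improves.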