Reproducing ARConv: A Walkthrough

Environment

Python 3.10.16

torch 2.1.1+cu118

torchvision 0.16.1+cu118

Install the remaining dependencies from the requirements.txt provided with the official code:

GitHub - WangXueyang-uestc/ARConv: Official repo for Adaptive Rectangular Convolution

Data Preparation

Download the PanCollection dataset from the official project page: PanCollection for Survey Paper.

Taking the WV3 dataset as an example, download both the training and testing sets:
[1] Training Dataset (5.76 GB): [Baidu Cloud]
[2] Testing Dataset (20 examples per class): [ReducedData (H5 Format)] [FullData (H5 Format)]
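Before training, it is worth opening one of the downloaded H5 files and checking its layout, since the loader code later in this post expects 'lms', 'ms', and 'pan' datasets. A minimal inspection sketch; the dummy file and its dataset names/shapes are illustrative assumptions, not the real dataset:

```python
import os
import tempfile

import h5py
import numpy as np

# Build a tiny dummy file mimicking the PanCollection layout, then list its
# contents. Dataset names ('gt', 'lms', 'ms', 'pan') and shapes are
# assumptions for illustration -- verify them against your actual download.
path = os.path.join(tempfile.mkdtemp(), "dummy_wv3.h5")
with h5py.File(path, "w") as f:
    f.create_dataset("gt", data=np.zeros((2, 8, 256, 256), dtype=np.float32))
    f.create_dataset("lms", data=np.zeros((2, 8, 256, 256), dtype=np.float32))
    f.create_dataset("ms", data=np.zeros((2, 8, 64, 64), dtype=np.float32))
    f.create_dataset("pan", data=np.zeros((2, 1, 256, 256), dtype=np.float32))

# Inspect what the file actually contains before feeding it to the trainer.
with h5py.File(path, "r") as f:
    for name in f:
        print(name, f[name].shape, f[name].dtype)
```

Running the same loop over the real train_wv3.h5 shows whether the expected datasets are present before you start a long training run.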
 

Training

Instead of launching training through the official .sh script, I call trainer.py directly. To make the imports resolve in that setup, I modified two files; the issue boils down to relative vs. absolute imports.

ARConv/models/models.py

from ARConv import ARConv -> from .ARConv import ARConv

ARConv/trainer.py

from .models import ARNet -> from models import ARNet
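Why both changes are needed: a leading dot imports relative to the containing package, which only works when the file is itself loaded as part of that package; running trainer.py directly makes its directory the script root, so the sibling models package must be imported absolutely. A toy reconstruction (directory and class names mirror the repo but are generated here purely for illustration):

```python
import os
import sys
import tempfile
import textwrap

# Recreate the situation in miniature: a 'models' package whose modules use
# relative imports internally, plus an absolute import from "script level".
root = tempfile.mkdtemp()
pkg = os.path.join(root, "models")
os.makedirs(pkg)
with open(os.path.join(pkg, "__init__.py"), "w") as f:
    f.write("from .models import ARNet\n")
with open(os.path.join(pkg, "ARConv.py"), "w") as f:
    f.write("class ARConv:\n    pass\n")
with open(os.path.join(pkg, "models.py"), "w") as f:
    f.write(textwrap.dedent("""\
        # Relative import: resolves because models.py is loaded as part
        # of the 'models' package.
        from .ARConv import ARConv

        class ARNet:
            pass
        """))

# Absolute import: resolves because the package's parent directory is on
# sys.path, just as it is when trainer.py sits next to the package.
sys.path.insert(0, root)
from models import ARNet
print(ARNet.__name__)  # ARNet
```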

Run trainer.py to start training. The command below uses GPU 0 only:

CUDA_VISIBLE_DEVICES="0" python trainer.py --batch_size 16 --epochs 600 --lr 0.0006 --ckpt 20 --train_set_path ./pansharpening/training_data/train_wv3.h5 --checkpoint_save_path ./workdir/wv3 --hw_range 1 18 --task 'wv3'

Testing

After training, the model weights (.pth files) are saved to the directory specified above. Based on the author's reply plus some additions of my own, I wrote two Python scripts, getFullmat.py and getReducedmat.py, which export the model outputs as .mat files for evaluation in MATLAB. Just change checkpoint_path to the path of your own .pth file.
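These scripts hand results to MATLAB via scipy.io.savemat, storing each output under the variable name 'sr' (the name the metric code reads in this setup). A quick round-trip sanity check, using a random array as a stand-in for a model output:

```python
import os
import tempfile

import numpy as np
import scipy.io as sio

# Save an (H, W, C) array under the variable name 'sr', as the export
# scripts below do, then read it back to verify the round trip.
arr = np.random.rand(64, 64, 8).astype(np.float32)
path = os.path.join(tempfile.mkdtemp(), "output_mulExm_0.mat")
sio.savemat(path, {"sr": arr})

loaded = sio.loadmat(path)["sr"]
print(loaded.shape)  # (64, 64, 8)
assert np.allclose(arr, loaded)
```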

getFullmat.py

import torch
import torch.nn as nn
import os
import scipy.io as sio
from einops import rearrange
from models import ARNet
import h5py
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def load_set(file_path):
    # WV3 data are 11-bit, hence the normalization by 2047 (= 2^11 - 1).
    data = h5py.File(file_path, 'r')
    lms = torch.from_numpy(np.array(data['lms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute(
        [1, 0, 2, 3, 4]).float()
    ms = torch.from_numpy(np.array(data['ms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute(
        [1, 0, 2, 3, 4]).float()
    pan = torch.from_numpy(np.array(data['pan'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute(
        [1, 0, 2, 3, 4]).float()
    return lms, ms, pan

# Paths (edit to match your own setup)
checkpoint_path = r'workdir/wv3/checkpoint_160_2025-05-02-16-06-33.pth'
test_data_path = r'pansharpening/test_data/WV3/test_wv3_OrigScale_multiExm1.h5'
save_dir = r'2_DL_Result/PanCollection/WV3_Full/RRNet/results/'

# Create the output directory
os.makedirs(save_dir, exist_ok=True)

# Load the trained model; the DataParallel wrapper matches the checkpoint's key names
model = ARNet().cuda()
model = nn.DataParallel(model)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])
model.eval()

# Load the test data
lms, ms, pan = load_set(test_data_path)

# Run inference on every test example
with torch.no_grad():
    print('Running model inference...')
    for i in range(pan.shape[0]):
        # The extra arguments (1000, [1, 18]) follow the author's inference
        # settings; [1, 18] matches the --hw_range used for training.
        output = model(pan[i], lms[i], 1000, [1, 18])
        output = rearrange(output, 'b c h w -> b h w c') * 2047
        output_np = output[0].cpu().numpy()
        
        save_mat_path = os.path.join(save_dir, f'output_mulExm_{i}.mat')
        sio.savemat(save_mat_path, {'sr': output_np})
        print(f"Saved .mat to {save_mat_path}")

getReducedmat.py

getReducedmat.py is identical to getFullmat.py except that it reads the reduced-resolution test set and writes to a different directory:

checkpoint_path = r'workdir/wv3/checkpoint_160_2025-05-02-16-06-33.pth'
test_data_path = r'pansharpening/test_data/WV3/test_wv3_multiExm1.h5'
save_dir = r'2_DL_Result/PanCollection/WV3_Reduced/RRNet/results/'

Then place the 2_DL_Result directory into ARConv\MetricCode.

In Demo1_Reduced_Resolution_MultiExm_wv3.m and Demo2_Full_Resolution_multi_wv3.m, change the file_test path to point at your test-set files.

I changed them as follows:

Demo1_Reduced_Resolution_MultiExm_wv3.m :

opts.file = 'test_wv3_multiExm1';
file_test = strcat('pansharpening/test_data/WV3/', opts.file,'.h5');

Demo2_Full_Resolution_multi_wv3.m:

opts.file = 'test_wv3_OrigScale_multiExm1';
file_test = strcat('pansharpening/test_data/WV3/', opts.file,'.h5');

and placed the two test-set files, test_wv3_multiExm1.h5 and test_wv3_OrigScale_multiExm1.h5, under that path.

Finally, run the two demo scripts in MATLAB to complete the evaluation.
