使用环境
Python 3.10.16
torch 2.1.1+cu118
torchvision 0.16.1+cu118
其它按照官方提供代码的requirements.txt安装
GitHub - WangXueyang-uestc/ARConv: Official repo for Adaptive Rectangular Convolution
数据准备
从官方主页下载pancollection数据集PanCollection for Survey Paper
以WV3 Dataset为例,我们下载训练集和测试集
[1] Training Dataset(训练数据集, 5.76GB): [Baidu Cloud]
[2] Testing Dataset(测试数据集, 20 Examples/per class): [ReducedData(H5 Format)] [FullData(H5 Format)]
训练
在这里我没有使用官方推荐的运行.sh文件,而是直接去调用trainer.py执行,那么我修改了两个文件以找到模型,主要是相对导入和绝对导入的问题。
ARConv/models/models.py
from ARConv import ARConv -> from .ARConv import ARConv
ARConv/trainer.py
from .models import ARNet -> from models import ARNet
运行trainer.py进行训练,下面给出仅使用GPU 0进行训练的代码
CUDA_VISIBLE_DEVICES="0" python trainer.py --batch_size 16 --epochs 600 --lr 0.0006 --ckpt 20 --train_set_path ./pansharpening/training_data/train_wv3.h5 --checkpoint_save_path ./workdir/wv3 --hw_range 1 18 --task 'wv3'
测试
训练完毕后,模型权重pth文件被存入设定的文件目录中,经过作者的回复,和自己的补充,我写了两个python脚本getFullmat.py和getReducedmat.py分别用于生成模型输出的文件,在Matlab中进行测试。将其中的checkpoint_path 改为自己pth存放的文件路径即可。
getFullmat.py
import torch
import torch.nn as nn
import os
import scipy.io as sio
from einops import rearrange
from models import ARNet
import h5py
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def load_set(file_path):
data = h5py.File(file_path)
lms = torch.from_numpy(np.array(data['lms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute(
[1, 0, 2, 3, 4]).float()
ms = torch.from_numpy(np.array(data['ms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute(
[1, 0, 2, 3, 4]).float()
pan = torch.from_numpy(np.array(data['pan'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute(
[1, 0, 2, 3, 4]).float()
return lms, ms, pan
# 路径设置(请根据实际路径修改)
checkpoint_path = r'workdir/wv3/checkpoint_160_2025-05-02-16-06-33.pth'
test_data_path = r'pansharpening/test_data/WV3/test_wv3_OrigScale_multiExm1.h5'
save_dir = r'2_DL_Result/PanCollection/WV3_Full/RRNet/results/'
# 创建保存目录
os.makedirs(save_dir, exist_ok=True)
# 加载模型
model = ARNet().cuda()
model = nn.DataParallel(model)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])
model.eval()
# 加载测试数据
lms, ms, pan = load_set(test_data_path)
# 推理所有图像
with torch.no_grad():
print('Running model inference...')
for i in range(pan.shape[0]):
output = model(pan[i], lms[i], 1000, [1, 18])
output = rearrange(output, 'b c h w -> b h w c') * 2047
output_np = output[0].cpu().numpy()
save_mat_path = os.path.join(save_dir, f'output_mulExm_{i}.mat')
sio.savemat(save_mat_path, {'sr': output_np})
print(f"Saved .mat to {save_mat_path}")
getReducedmat.py
import torch
import torch.nn as nn
import os
import scipy.io as sio
from einops import rearrange
from models import ARNet
import h5py
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def load_set(file_path):
data = h5py.File(file_path)
lms = torch.from_numpy(np.array(data['lms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute(
[1, 0, 2, 3, 4]).float()
ms = torch.from_numpy(np.array(data['ms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute(
[1, 0, 2, 3, 4]).float()
pan = torch.from_numpy(np.array(data['pan'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute(
[1, 0, 2, 3, 4]).float()
return lms, ms, pan
# 路径设置(请根据实际路径修改)
checkpoint_path = r'workdir/wv3/checkpoint_160_2025-05-02-16-06-33.pth'
test_data_path = r'pansharpening/test_data/WV3/test_wv3_multiExm1.h5'
save_dir = r'2_DL_Result/PanCollection/WV3_Reduced/RRNet/results/'
# 创建保存目录
os.makedirs(save_dir, exist_ok=True)
# 加载模型
model = ARNet().cuda()
model = nn.DataParallel(model)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])
model.eval()
# 加载测试数据
lms, ms, pan = load_set(test_data_path)
# 推理所有图像
with torch.no_grad():
print('Running model inference...')
for i in range(pan.shape[0]):
output = model(pan[i], lms[i], 1000, [1, 18])
output = rearrange(output, 'b c h w -> b h w c') * 2047
output_np = output[0].cpu().numpy()
save_mat_path = os.path.join(save_dir, f'output_mulExm_{i}.mat')
sio.savemat(save_mat_path, {'sr': output_np})
print(f"Saved .mat to {save_mat_path}")
之后将要2_DL_Result放入ARConv\MetricCode中
修改 Demo1_Reduced_Resolution_MultiExm_wv3.m 和Demo2_Full_Resolution_multi_wv3.m中的file_test路径,改为存放测试集的文件即可。
我分别修改为了
Demo1_Reduced_Resolution_MultiExm_wv3.m :
opts.file = 'test_wv3_multiExm1';
file_test = strcat('pansharpening/test_data/WV3/', opts.file,'.h5');
Demo2_Full_Resolution_multi_wv3.m:
opts.file = 'test_wv3_OrigScale_multiExm1';
file_test = strcat('pansharpening/test_data/WV3/', opts.file,'.h5');
并在路径中放入了两个测试集文件
test_wv3_multiExm1.h5和test_wv3_OrigScale_multiExm1.h5
之后运行测试即可成功完成测试
600

被折叠的 条评论
为什么被折叠?



