STANet 基于时空自注意力的遥感图像变化检测模型
检测数据集LEVIR-CD
环境配置
//要求:
windows or Linux
Python 3.6+
CPU or NVIDIA GPU
CUDA 9.0+
PyTorch > 1.0
visdom==0.1.8.1
dominate
- 我们采用
Linux
+cuda11.0
+python=3.8
,相关安装命令如下:
//创建Python虚拟环境
conda create -n STANet_3.8 python==3.8
//激活环境
conda activate STANet_3.8
//安装pytorch环境
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
//查看环境安装情况
pip install ipython
#(STANet_3.8) :~$ ipython
#Python 3.8.0 (default, Nov 6 2019, 21:49:08)
#Type 'copyright', 'credits' or 'license' for more information
#IPython 7.21.0 -- An enhanced Interactive Python. Type '?' for help.
#In [1]: import torch
#In [2]: torch.cuda.is_available()
#Out[2]: True
//出现以上内容表明环境已经安装好啦。
//安装visdom
pip install visdom==0.1.8.1
//安装dominate
pip install dominate
我们的环境应该是配置好啦,接着下载数据集让我们准备训练吧!
代码数据集下载
代码下载链接:https://github.com/justchenhao/STANet
数据集下载:https://justchenhao.github.io/LEVIR/
网络结构图:
baseline运行代码
#一共200个epoch。我们对前100个epoch保持相同的学习速率,并在剩余的100个epoch中将其线性衰减为0。
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDF0 --lr 0.001 --model CDF0 --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
#基础训练 val.py内容
if __name__ == '__main__':
opt = TestOptions().parse() # get training options
opt = make_val_opt(opt)
opt.phase = 'val'
opt.dataroot = './LEVIR-CD/test'
opt.dataset_mode = 'changedetection'
opt.n_class = 2
#opt.SA_mode = 'PAM'
opt.arch = 'mynet3'
opt.model = 'CDF0'
opt.name = 'LEVIR-CDF0'
opt.results_dir = './results/'
opt.epoch = '160_F1_1_0.77180' //训练后的权重文件
opt.num_test = np.inf
BAM运行代码
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDFA0 --lr 0.001 --model CDFA --SA_mode BAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
PAM运行代码
#增加金字塔时空注意力模块PAM
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDFAp0 --lr 0.001 --model CDFA --SA_mode PAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
测试
# 按照下面的例子修改val.py即可
if __name__ == '__main__':
opt = TestOptions().parse() # get training options
opt = make_val_opt(opt)
opt.phase = 'test'
opt.dataroot = 'path-to-LEVIR-CD-test' # data root
opt.dataset_mode = 'changedetection'
opt.n_class = 2
opt.SA_mode = 'PAM' # BAM | PAM
opt.arch = 'mynet3'
opt.model = 'CDFA' # model type
opt.name = 'LEVIR-CDFAp0' # project name
opt.results_dir = './results/' # save predicted images
opt.epoch = 'best-epoch-in-val' # which epoch to test
opt.num_test = np.inf
val(opt)