STANet结果复现及踩坑指南

正文

SATNet网络原文:点我
SATNet网络源码:https://github.com/justchenhao/STANet

坑1:报错/警告坑汇

1、内存爆炸

解决办法:
1、将原始数据集进行裁剪为256,甚至128也可以。
2、将--ds设置为2、4甚至8。
3、将batch size设置小一点

2、警告如下:

C:\ProgramData\anaconda3\envs\pytorch\lib\site-packages\torchvision\transforms\functional.py:132: UserWarning: The 
given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the 
underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or 
make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. 
(Triggered internally at  ..\torch\csrc\utils\tensor_numpy.cpp:143.)
  img = torch.from_numpy(np.array(pic, np.float32, copy=False))

原因:torchvision和pillow不兼容导致的,pillow版本太高。
解决办法:
第一种是先导入InterpolationMode,然后将list_dataset.py、changedetection.py、base_dataset.py中下列行中的method=Image.NEAREST变为InterpolationMode.NEAREST

from torchvision.transforms import InterpolationMode
transform_L = get_transform(self.opt, transform_params, method=Image.NEAREST, 
							normalize=False,test=self.istest)

第二种降低pillow版本,升高torchvision版本。

3、警告如下:

C:\ProgramData\anaconda3\envs\pytorch\lib\site-packages\torchvision\transforms\functional.py:132: UserWarning: The given NumPy array is 
not writeable, and PyTorch does not support non-writeable tensors. This means you can 
write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may 
want to copy the array to protect its data or make it writeable before converting it to a
 tensor. This type of warning will be suppressed for the rest of this program. (Triggered 
 internally at  ..\torch\csrc\utils\tensor_numpy.cpp:143.)
  img = torch.from_numpy(np.array(pic, np.float32, copy=False))

原因:未知,有人猜测是Torchvision中pil_to_tensor时试图创建一个从PIL图像继承底层numpy存储的张量,但该图像被标记为只读,并且由于张量不能是只读,因此必须打印警告(https://discuss.pytorch.org/t/userwarning-the-given-numpy-array-is-not-writeable/78748)。该警告在不同版本中均有出现,尤其是之前读取MNIST时会发出警告(https://github.com/pytorch/vision/pull/4184)。
解决办法:
打开\torchvision\transforms\functional.py:找到第132行,将img = torch.from_numpy(np.array(pic, np.float32, copy=False))中的False改为True

4、报错如下:

dataset [listDataset] was created
OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. 
That is dangerous, since it can degrade performance or cause incorrect results. The best thing to 
do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static 
linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you
 can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, 
 but that may cause crashes or silently produce incorrect results. For more information, please see 
 http://www.intel.com/software/products/support/.

这是在测试list_dataset.py时报的错。前面加上下面两行即可。

import os
os.environ["KMP_DUPLICATE_LIB_OK"]  =  "TRUE"

5、报错如下:

Traceback (most recent call last):
  File "./train.py", line 171, in <module>
    miou_current = val(opt, model)
  File "./train.py", line 88, in val
    score = model.test(val=True)           # run inference
  File "D:\LYH\ChangeDetection\STANet\models\CDF0_model.py", line 79, in test
    metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy())
  File "D:\LYH\ChangeDetection\STANet\util\metrics.py", line 126, in update
    self.confusion_matrix += self.__fast_hist(lt.flatten(), lp.flatten())
  File "D:\LYH\ChangeDetection\STANet\util\metrics.py", line 112, in __fast_hist
    print(len(label_pred[mask]))
IndexError: boolean index did not match indexed array along dimension 0; dimension is 4194304 but corresponding boolean dimension is 65536

我查了很久,也详细看了源GitHub上的issues。有人说是numpy版本太高,降到1.18.0可以使用,实际我降低后并没有卵用,可能还得降,但鉴于配置环境的复杂性,我选择扒拉源码。然后一行一行代码扒,终于发现问题所在。
问题分析:
简要说就是torch版本更新后,有些函数输出的B*C*W*H变成了B*W*H*C
运行报错后,一行一行溯源,发现问题出在predshapeLabelshape不匹配。
CDF0和CDFA中,forward是对backbone的计算的特征图进行相似度计算,然后这个相似度通过阈值1选择后作为pred的结果的。
以下为猜测,没有找到实料。我猜测老版本torch中F.pairwise_distance生成的结果是B*C*W*H,因此可以直接拿来插值然后和label做比较。但新版本应该是变成了B*W*H*C。用默认的Resnet18(即netF)生成的特征层应该为B*64(C)*64(W)*64(H),F.pairwise_distance生成的结果为B*64(W)*64(H)*1(C),插值后就变为B*64*256*256,所以导致报错的dimension前面数值总是后面的64倍。
解决办法:
在CDFA和CDF0中,找到forward函数,将F.pairwise_distance生成的结果进行通道和行列变换。

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.feat_A = self.netF(self.A)  # f(A)
        self.feat_B = self.netF(self.B)   # f(B)
        self.feat_A, self.feat_B = self.netA(self.feat_A,self.feat_B)
        self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True)  # 特征距离 B*W*H*C
        /////////////////////////////////
        # 在此新增以下代码行
        self.dist = self.dist.permute(0, 3, 1, 2)  # 需要变换成B*C*W*H
        /////////////////////////////////
        
        self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear',align_corners=True)
        self.pred_L = (self.dist > 1).float()
        # self.pred_L = F.interpolate(self.pred_L, size=self.A.shape[2:], mode='nearest')
        self.pred_L_show = self.pred_L.long()

        return self.pred_L

6、AssertionError: X and Y should be the same shape

pip install visdom==0.1.8.8 pip install dominate

※坑2:论文精度无法复现

  • 提升思路一:
    跑了很多遍STANet算法,进行了各种调参,但PAM最终F1_score基本都在0.73左右,baseMode在0.64左右。即使数据裁剪成256也不行。后来突然灵光一现,是不是因为样本漂移的问题,因为裁剪后,出现了大量的标签为0的样本,导致负样本过大,裁剪后的正样本与总样本比为3107/7120,因此剔除负样本进行训练,PAM F1_score提升到了0.8。
  • 其他思路:
    查看网上有人说ds设置为2效果最好,其他参数和样本不变的情况下,我改为2后精度只提升了0.01左右,此外还有的说裁剪为128,精度也会进一步提升,尝试用128跑了一遍,并没有出现精度提升。

关于精度复现目前还未能实现,希望有复现成功的大佬指导指导,提出建议。

以下是裁剪代码,剔除了全为0的负样本。

# -*- coding: utf-8 -*-
"""
Created on Mon Sep 26 10:51:27 2022

@author: LYH
"""

import os
from skimage import data_dir, io, transform
import numpy as np
from tqdm import tqdm
import time
from numba import njit


def ind2sub(siz, ind):
    '''


    Parameters
    ----------
    siz : List
        DESCRIPTION: 总分块裁剪矩阵
    Ind : Int
        DESCRIPTION:索引

    Returns
    -------
    r, Int, row 索引在分块个数中行号
    c,Int, col 索引在分块个数中列号

    '''
    siz = siz.tolist()
    row = (ind - 1) % siz[0] + 1
    col = int((ind - row) / siz[0] + 1)
    return (row, col)

# @njit
def crop_(in_path, out_path, blockLength=1024, stride=12):
    
    
    
    A_coll = io.ImageCollection(in_path[0])
    B_coll = io.ImageCollection(in_path[1])
    L_coll = io.ImageCollection(in_path[2])
    for i in range(len(A_coll)):
        img_A = A_coll[i]
        img_B = B_coll[i]
        img_L = L_coll[i]
        print(img_A.shape)
        if len(img_A.shape) > 2:
            rw, rh, rc = img_A.shape
        else:
            rw, rh = img_A.shape
            rc = None
        srcSize = [rh, rw]
        block_size = np.ceil(np.divide(srcSize, blockLength)).astype('int')
        blockNum = np.prod(block_size)
        
        try:
            with tqdm(total=blockNum, desc='处理中') as pbar:
                for blockI in range(1, blockNum + 1):
                    [r, c] = ind2sub(block_size, blockI)
                    rmin = max((r - 1) * (blockLength - stride), 0)
                    cmin = max((c - 1) * (blockLength - stride), 0)
                    rmax = min(rmin + blockLength, rh)
                    cmax = min(cmin + blockLength, rw)

                    crop_A = img_A[rmin:rmax, cmin:cmax, :]
                    crop_B = img_B[rmin:rmax, cmin:cmax, :]
                    crop_L = img_L[rmin:rmax, cmin:cmax]
                        
                    if (50 * 255) < np.sum(crop_L):
                        imgname_A = A_coll.files[i].split("\\")[-1][:-4]+"_"+str(blockI)+".png"
                        imgname_B = B_coll.files[i].split("\\")[-1][:-4]+"_"+str(blockI)+".png"
                        imgname_L = L_coll.files[i].split("\\")[-1][:-4]+"_"+str(blockI)+".png"

                        io.imsave(os.path.join(out_path[0], imgname_A), crop_A)
                        io.imsave(os.path.join(out_path[1], imgname_B), crop_B)
                        io.imsave(os.path.join(out_path[2], imgname_L), crop_L)
                    time.sleep(0.005)
                    pbar.update(1)
        except KeyboardInterrupt:
            pbar.close()
            raise
        pbar.close() 

 
def get_path(path):
    level1_path = os.listdir(path)
    
    return level1_path
 
filepath = r"G:\demo2\STANet\LEVIR"
for path in get_path(filepath):
    in_path = []
    out_path = []
    for l2_path in  ['A', 'B', 'label']:
        out_path_temp = os.path.join(r'G:\demo2\STANet\LEVIR256', path, l2_path)
        out_path.append(out_path_temp)
        in_path.append(os.path.join(filepath, path, l2_path, "*.png"))
        if not os.path.exists(out_path_temp):
            os.makedirs(out_path_temp)
    crop_(in_path, out_path, blockLength=256, stride=0)
    # crop_(coll[i], coll.files[i].split("\\")[-1], out_path, blockLength=256, stride=64) # 循环保存图片
  • 6
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 23
    评论
评论 23
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值