SwinNet: Swin Transformer Drives Edge-Aware RGB-D and RGB-T Salient Object Detection

期刊:IEEE Transactions on Circuits and Systems for Video Technology(2022)

代码网址:GitHub - liuzywen/SwinNet

目录

一、论文阅读笔记

1、摘要

2、主要贡献点:

3、方法:

3.1 网络的总体结构图:

 3.2 Two-stream Swin Transformer backbone

3.3 Spatial alignment and channel re-calibration module

 3.4 Edge-aware module

3.5 Edge-guided decoder

 3.6 Loss function

4、实验

二、代码复现

1、实验细节:

2、数据集:

3、实验步骤:

3.1 将对应的代码和数据集上传到服务器上

3.2 将根目录下的options.py里的路径更改成自己的路径

3.3 上传预训练模型 swin_base_patch4_window12_384_22k.pth

3.4 生成边界图Edge

3.5 修改对应data.py中的内容

3.6 安装相应的module

3.7 训练模型

三、BUGS

BUG1:FileNotFoundError:

{Errno 2} No such file or directory:'./swin_base_patch4_window12_384_22k.pth'

BUG2:FileNotFoundError:

{Errno 2} No such file or directory:'/root/autodl-tmp/SwinNet/cpts/RGBDSwinTransNet.log'

BUG3:SystemError: tile cannot extend outside image​

BUG4:ZeroDivisionError:float division by zero​


一、论文阅读笔记

1、摘要

      卷积神经网络 (CNN) 擅长提取某些感受野内的上下文特征,而Transformers可以对全局远程依赖特征进行建模。通过吸收变压器的优势和CNN的优点,Swin Transformer具有较强的特征表示能力。在此基础上,我们提出了一种用于 RGB-D 和 RGB-T 显着目标检测的跨模态融合模型 SwinNet。由 Swin Transformer 驱动以提取分层特征,通过注意力机制增强来弥合两种模态之间的差距,并以边缘信息引导以锐化显着对象的轮廓。具体来说,双流 Swin Transformer 编码器首先提取多模态特征,然后提出空间对齐和通道重新校准模块来优化层内跨模态特征。为了澄清模糊边界,边缘引导解码器在边缘特征的指导下实现层间跨模态融合。所提出的模型在 RGB-D 和 RGB-T 数据集上优于最先进的模型,表明它提供了对跨模态互补任务的更多见解。 https://github.com/liuzzywen/SwinNet

2、主要贡献点:

        1、提出了一种基于Swin Transformer主干的RGB-D和RGB-T任务的新型SOD模型(SwinNet)。它从 Swin Transformer 主干中提取判别特征,该主干吸收卷积神经网络的局部优势和 Transformer 的远程依赖优点,优于最先进的 (SOTA) RGB-D 和 RGB-T SOD 模型。

        2、新设计的空间对齐和通道重新校准模块用于基于注意机制优化每个模态的特征,实现 层跨模态融合从空间和通道方面。

        3、该算法在边缘感知模块的引导下实现了层间跨模态融合,生成了更清晰的图像轮廓。

3、方法:

3.1 网络的总体结构图:

 3.2 Two-stream Swin Transformer backbone

        每个Swin Transformer首先通过块嵌入将输入的单模态图像分割成不重叠的块。颜色流中每个patch的特征被设置为原始像素RGB值的拼接,而深度流的特征被设置为三个复制深度值的拼接。然后,将它们送入多阶段特征变换中。随着网络深度的增加,通过补丁合并层逐渐减少令牌数,得到各模态的层次表示,分别为^{​{\left \{ ST_{_{i}}^{c} \right \}_{i=1}}^{4}}^{​{\left \{ ST_{_{i}}^{d} \right \}_{i=1}}^{4}}

3.3 Spatial alignment and channel re-calibration module

        首先对空间部分的两种模态进行对齐,然后对各自的通道部分进行重新校准,以更加关注每个模态中的突出内容。

 3.4 Edge-aware module

         众所周知,高层特征表达了更多的语义信息,而浅层特征承载了更多的细节信息。同时,在深度图像中,显著物体更容易呈现弹出结构。用深度对比法很容易描绘出物体的轮廓。因此,利用深度骨干的浅层特征来产生边缘特征。

 具体公式:

边缘感知模块输出边缘特征F_{e}^{'},用于指导模型的解码过程和增强细节。 

3.5 Edge-guided decoder

        解码器经过空间对准、通道重新校准和边缘特征提取后,将不同模态增强的层次特征与边缘特征结合,得到边缘引导的显著特征。

 3.6 Loss function

4、实验

二、代码复现

1、实验细节:

        在AutoDL平台上租的服务器,服务器型号为:RTX 3090(24GB),同时使用的相关配置为:PyTorch  1.9.0 Python  3.8(ubuntu20.04) Cuda  11.1

2、数据集:

        并未采用论文里的数据集,而是采用rsdds_1500数据集

3、实验步骤:

3.1 将对应的代码和数据集上传到服务器上

3.2 将根目录下的options.py里的路径更改成自己的路径

       具体是(更改的地方标红)

parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=100, help='epoch number')
parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
parser.add_argument('--batchsize', type=int, default=8, help='training batch size')
parser.add_argument('--trainsize', type=int, default=384, help='training dataset size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=100, help='every n epochs decay learning rate')
parser.add_argument('--load', type=str, default='./swin_base_patch4_window12_384_22k.pth', help='train from checkpoints')
parser.add_argument('--load_pre', type=str, default='/root/autodl-tmp/SwinNet/SwinTransNet_RGBD_cpts/SwinTransNet_epoch_best.pth', help='train from checkpoints')
parser.add_argument('--gpu_id', type=str, default='0', help='train use gpu')
parser.add_argument('--rgb_root', type=str, default='./rsdds_1500/TrainDataset/RGB/', help='the training rgb images root')
parser.add_argument('--depth_root', type=str, default='./rsdds_1500/TrainDataset/depth/', help='the training depth images root')
parser.add_argument('--gt_root', type=str, default='./rsdds_1500/TrainDataset/GT/', help='the training gt images root')
parser.add_argument('--edge_root', type=str, default='./rsdds_1500/TrainDataset/Edge/', help='the training edge images root')
parser.add_argument('--test_rgb_root', type=str, default='./rsdds_1500/TestDataset/RGB/', help='the test gt images root')
parser.add_argument('--test_depth_root', type=str, default='./rsdds_1500/TestDataset/depth/', help='the test gt images root')
parser.add_argument('--test_gt_root', type=str, default='./rsdds_1500/TestDataset/GT/', help='the test gt images root')
parser.add_argument('--test_edge_root', type=str, default='./rsdds_1500/TestDataset/Edge/', help='the test edge images root')
parser.add_argument('--save_path', type=str, default='./cpts/', help='the path to save models and logs')

3.3 上传预训练模型 swin_base_patch4_window12_384_22k.pth

        根据readme提供的网址下载

3.4 生成边界图Edge

1、修改gen_edge.py里面生成边界图原图和生成图的途径

if __name__ == '__main__':
    root = r'/root/autodl-tmp/SwinNet/rsdds_1500/TrainDataset'
    Edge_Extract(root)

2、运行脚本

python gen_edge.py

3.5 修改对应data.py中的内容

        由于我自己的数据集的RGB、GT、depth图片与论文作者使用的公开数据集中对应图片的格式不同,所以需要更改对应的格式。

# dataset for training
# The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps
# (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved.
class SalObjDataset(data.Dataset):
    def __init__(self, image_root, gt_root, depth_root, edge_root, trainsize):
        self.trainsize = trainsize
        self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.bmp')]

        self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
                    or f.endswith('.png')]

        self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.tiff')
                       or f.endswith('.png') or f.endswith('.jpg')]
        self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.bmp')
                       or f.endswith('.png') or f.endswith('.jpg')]

class test_dataset:
    def __init__(self, image_root, gt_root, depth_root, testsize):
        self.testsize = testsize
        self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.bmp')]
        self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
                    or f.endswith('.png')]
        self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.tiff')
                       or f.endswith('.png')or f.endswith('.jpg')]
        # self.edges = [edge_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
        #                or f.endswith('.png')or f.endswith('.jpg')]

3.6 安装相应的module

3.7 训练模型

实时运行:

python SwinNet_train.py

后台运行:

nohup python -u  SwinNet_train.py >train_newdatasets_output.log 2>&1 &

三、BUGS

BUG1:FileNotFoundError:

{Errno 2} No such file or directory:'./swin_base_patch4_window12_384_22k.pth'

 解决方法:

        从readme上找到swin_base_patch4_window12_384_22k.pth,把它下载下来放到根目录下就行。

BUG2:FileNotFoundError:

{Errno 2} No such file or directory:'/root/autodl-tmp/SwinNet/cpts/RGBDSwinTransNet.log'

解决办法:跟问题一类似,只要在指出的路径创建文件即可 (即在根目录下创建文件夹cpts,然后再此文件夹下再创建文件RGBDSwinTransNet.log) 

BUG3:SystemError: tile cannot extend outside image

解决办法: 由于data.py中的剪裁函数剪裁的区域出现问题,修改对应的代码段即可。

修改前:

def randomCrop(image, label, depth, edge):
    border = 30
    image_width = image.size[0]
    image_height = image.size[1]
    crop_win_width = np.random.randint(image_width - border, image_width)
    crop_win_height = np.random.randint(image_height - border, image_height)
    random_region = (
        (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
        (image_height + crop_win_height) >> 1)
    return image.crop(random_region), label.crop(random_region), depth.crop(random_region), edge.crop(random_region)

修改后:

def randomCrop(image, label, depth, edge):
    border = 30
    image_width = image.size[0]
    image_height = image.size[1]
    crop_win_width = np.random.randint(0, max(1, image_width - border))
    crop_win_height = np.random.randint(0, max(1,image_height- border))

    # 生成裁剪区域的左上角和右下角坐标
    left = np.random.randint(0, max(0, image_width - crop_win_width))
    top = np.random.randint(0, max(0, image_height - crop_win_height))
    right = min(image_width, left + crop_win_width)
    bottom = min(image_height, top + crop_win_height)

    random_region = (left, top, right, bottom)
    return image.crop(random_region), label.crop(random_region), depth.crop(random_region), edge.crop(random_region)

 BUG4:ZeroDivisionError:float division by zero

解决办法:找到对应文件夹中的ImgeStat.py,将对应的v.paaend(self.sum[i]/self.count[i])修改为v.paaend(self.sum[i]/(self.count[i]+0.01))(这个方法并没有深入搞懂到底为什么会出现这样的错误,只是加上一个小数是的除数不为0,跑通代码。所以如果有更好的办法希望大佬能够不吝赐教。)此方法的参考博客:已解决 ZeroDivisionError: float division by zero 。-CSDN博客 

  • 25
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一只懒洋洋

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值