期刊:IEEE Transactions on Circuits and Systems for Video Technology(2022)
代码网址:GitHub - liuzywen/SwinNet
目录
3.2 Two-stream Swin Transformer backbone
3.3 Spatial alignment and channel re-calibration module
3.2 将根目录下的options.py里的路径更改成自己的路径
3.3 上传预训练模型 swin_base_patch4_window12_384_22k.pth
{Errno 2} No such file or directory:'./swin_base_patch4_window12_384_22k.pth'
{Errno 2} No such file or directory:'/root/autodl-tmp/SwinNet/cpts/RGBDSwinTransNet.log'
BUG3:SystemError: tile cannot extend outside image
BUG4:ZeroDivisionError:float division by zero
一、论文阅读笔记
1、摘要
卷积神经网络 (CNN) 擅长提取某些感受野内的上下文特征,而Transformers可以对全局远程依赖特征进行建模。通过吸收变压器的优势和CNN的优点,Swin Transformer具有较强的特征表示能力。在此基础上,我们提出了一种用于 RGB-D 和 RGB-T 显着目标检测的跨模态融合模型 SwinNet。由 Swin Transformer 驱动以提取分层特征,通过注意力机制增强来弥合两种模态之间的差距,并以边缘信息引导以锐化显着对象的轮廓。具体来说,双流 Swin Transformer 编码器首先提取多模态特征,然后提出空间对齐和通道重新校准模块来优化层内跨模态特征。为了澄清模糊边界,边缘引导解码器在边缘特征的指导下实现层间跨模态融合。所提出的模型在 RGB-D 和 RGB-T 数据集上优于最先进的模型,表明它提供了对跨模态互补任务的更多见解。 https://github.com/liuzzywen/SwinNet
2、主要贡献点:
1、提出了一种基于Swin Transformer主干的RGB-D和RGB-T任务的新型SOD模型(SwinNet)。它从 Swin Transformer 主干中提取判别特征,该主干吸收卷积神经网络的局部优势和 Transformer 的远程依赖优点,优于最先进的 (SOTA) RGB-D 和 RGB-T SOD 模型。
2、新设计的空间对齐和通道重新校准模块用于基于注意机制优化每个模态的特征,实现 层跨模态融合从空间和通道方面。
3、该算法在边缘感知模块的引导下实现了层间跨模态融合,生成了更清晰的图像轮廓。
3、方法:
3.1 网络的总体结构图:
3.2 Two-stream Swin Transformer backbone
每个Swin Transformer首先通过块嵌入将输入的单模态图像分割成不重叠的块。颜色流中每个patch的特征被设置为原始像素RGB值的拼接,而深度流的特征被设置为三个复制深度值的拼接。然后,将它们送入多阶段特征变换中。随着网络深度的增加,通过补丁合并层逐渐减少令牌数,得到各模态的层次表示,分别为和。
3.3 Spatial alignment and channel re-calibration module
首先对空间部分的两种模态进行对齐,然后对各自的通道部分进行重新校准,以更加关注每个模态中的突出内容。
3.4 Edge-aware module
众所周知,高层特征表达了更多的语义信息,而浅层特征承载了更多的细节信息。同时,在深度图像中,显著物体更容易呈现弹出结构。用深度对比法很容易描绘出物体的轮廓。因此,利用深度骨干的浅层特征来产生边缘特征。
具体公式:
边缘感知模块输出边缘特征,用于指导模型的解码过程和增强细节。
3.5 Edge-guided decoder
解码器经过空间对准、通道重新校准和边缘特征提取后,将不同模态增强的层次特征与边缘特征结合,得到边缘引导的显著特征。
3.6 Loss function
4、实验
二、代码复现
1、实验细节:
在AutoDL平台上租的服务器,服务器型号为:RTX 3090(24GB),同时使用的相关配置为:PyTorch 1.9.0 Python 3.8(ubuntu20.04) Cuda 11.1
2、数据集:
并未采用论文里的数据集,而是采用rsdds_1500数据集
3、实验步骤:
3.1 将对应的代码和数据集上传到服务器上
3.2 将根目录下的options.py里的路径更改成自己的路径
具体是(更改的地方标红)
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=100, help='epoch number')
parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
parser.add_argument('--batchsize', type=int, default=8, help='training batch size')
parser.add_argument('--trainsize', type=int, default=384, help='training dataset size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=100, help='every n epochs decay learning rate')
parser.add_argument('--load', type=str, default='./swin_base_patch4_window12_384_22k.pth', help='train from checkpoints')
parser.add_argument('--load_pre', type=str, default='/root/autodl-tmp/SwinNet/SwinTransNet_RGBD_cpts/SwinTransNet_epoch_best.pth', help='train from checkpoints')
parser.add_argument('--gpu_id', type=str, default='0', help='train use gpu')
parser.add_argument('--rgb_root', type=str, default='./rsdds_1500/TrainDataset/RGB/', help='the training rgb images root')
parser.add_argument('--depth_root', type=str, default='./rsdds_1500/TrainDataset/depth/', help='the training depth images root')
parser.add_argument('--gt_root', type=str, default='./rsdds_1500/TrainDataset/GT/', help='the training gt images root')
parser.add_argument('--edge_root', type=str, default='./rsdds_1500/TrainDataset/Edge/', help='the training edge images root')
parser.add_argument('--test_rgb_root', type=str, default='./rsdds_1500/TestDataset/RGB/', help='the test gt images root')
parser.add_argument('--test_depth_root', type=str, default='./rsdds_1500/TestDataset/depth/', help='the test gt images root')
parser.add_argument('--test_gt_root', type=str, default='./rsdds_1500/TestDataset/GT/', help='the test gt images root')
parser.add_argument('--test_edge_root', type=str, default='./rsdds_1500/TestDataset/Edge/', help='the test edge images root')
parser.add_argument('--save_path', type=str, default='./cpts/', help='the path to save models and logs')
3.3 上传预训练模型 swin_base_patch4_window12_384_22k.pth
根据readme提供的网址下载
3.4 生成边界图Edge
1、修改gen_edge.py里面生成边界图原图和生成图的途径
if __name__ == '__main__':
root = r'/root/autodl-tmp/SwinNet/rsdds_1500/TrainDataset'
Edge_Extract(root)
2、运行脚本
python gen_edge.py
3.5 修改对应data.py中的内容
由于我自己的数据集的RGB、GT、depth图片与论文作者使用的公开数据集中对应图片的格式不同,所以需要更改对应的格式。
# dataset for training
# The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps
# (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved.
class SalObjDataset(data.Dataset):
def __init__(self, image_root, gt_root, depth_root, edge_root, trainsize):
self.trainsize = trainsize
self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.bmp')]self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
or f.endswith('.png')]self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.tiff')
or f.endswith('.png') or f.endswith('.jpg')]
self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.bmp')
or f.endswith('.png') or f.endswith('.jpg')]
class test_dataset:
def __init__(self, image_root, gt_root, depth_root, testsize):
self.testsize = testsize
self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.bmp')]
self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
or f.endswith('.png')]
self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.tiff')
or f.endswith('.png')or f.endswith('.jpg')]
# self.edges = [edge_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
# or f.endswith('.png')or f.endswith('.jpg')]
3.6 安装相应的module
3.7 训练模型
实时运行:
python SwinNet_train.py
后台运行:
nohup python -u SwinNet_train.py >train_newdatasets_output.log 2>&1 &
三、BUGS
BUG1:FileNotFoundError:
{Errno 2} No such file or directory:'./swin_base_patch4_window12_384_22k.pth'
解决方法:
从readme上找到swin_base_patch4_window12_384_22k.pth,把它下载下来放到根目录下就行。
BUG2:FileNotFoundError:
{Errno 2} No such file or directory:'/root/autodl-tmp/SwinNet/cpts/RGBDSwinTransNet.log'
解决办法:跟问题一类似,只要在指出的路径创建文件即可 (即在根目录下创建文件夹cpts,然后再此文件夹下再创建文件RGBDSwinTransNet.log)
BUG3:SystemError: tile cannot extend outside image
解决办法: 由于data.py中的剪裁函数剪裁的区域出现问题,修改对应的代码段即可。
修改前:
def randomCrop(image, label, depth, edge): border = 30 image_width = image.size[0] image_height = image.size[1] crop_win_width = np.random.randint(image_width - border, image_width) crop_win_height = np.random.randint(image_height - border, image_height) random_region = ( (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, (image_height + crop_win_height) >> 1) return image.crop(random_region), label.crop(random_region), depth.crop(random_region), edge.crop(random_region)
修改后:
def randomCrop(image, label, depth, edge):
border = 30
image_width = image.size[0]
image_height = image.size[1]
crop_win_width = np.random.randint(0, max(1, image_width - border))
crop_win_height = np.random.randint(0, max(1,image_height- border))# 生成裁剪区域的左上角和右下角坐标
left = np.random.randint(0, max(0, image_width - crop_win_width))
top = np.random.randint(0, max(0, image_height - crop_win_height))
right = min(image_width, left + crop_win_width)
bottom = min(image_height, top + crop_win_height)random_region = (left, top, right, bottom)
return image.crop(random_region), label.crop(random_region), depth.crop(random_region), edge.crop(random_region)
BUG4:ZeroDivisionError:float division by zero
解决办法:找到对应文件夹中的ImgeStat.py,将对应的v.paaend(self.sum[i]/self.count[i])修改为v.paaend(self.sum[i]/(self.count[i]+0.01))(这个方法并没有深入搞懂到底为什么会出现这样的错误,只是加上一个小数是的除数不为0,跑通代码。所以如果有更好的办法希望大佬能够不吝赐教。)此方法的参考博客:已解决 ZeroDivisionError: float division by zero 。-CSDN博客