DATR论文阅读与实验复现

DATR实验环境配置

首先从GitHub - zhengxuJosh/DATRdownload code,实验环境目测和Trans4PASS一模一样,配置环境列表请看上篇文章Trans4PASS论文阅读与实验复现-CSDN博客的实验环境配置,这里不过多赘述,注意本篇文章还需要一个叫做Segformer的代码进行辅助实验,在本篇文章的Github中没有附带Segformer的代码,Segformer的代码链接GitHub - NVlabs/SegFormer: Official PyTorch implementation of SegFormer,具体该项目的实际训练过程是全部使用了Segformer的代码我还在探索,这里插个眼,将代码跑通后我来补充

数据集准备

Cityscapes dataset

image

Below are examples of the high quality dense pixel annotations from Cityscapes dataset. Overlayed colors encode semantic classes. Note that single instances of traffic participants are annotated individually.

The Cityscapes dataset is availabel at Cityscapes

SynPASS dataset

image

 SynPASS dataset contains 9080 panoramic images (1024x2048) and 22 categories.

The scenes include cloudy, foggy, rainy, sunny, and day-/night-time conditions.

The SynPASS dataset is availabel at Trans4PASS

DensePASS dataset

image

The DensePASS dataset is availabel at Trans4PASS,数据集组织方式如下:

Data Path:

datasets/
|--- cityscapes
|   |___ gtfine
|   |___ leftImg8bit
|--- SynPASS
|   |--- img
|   |   |___ cloud
|   |   |___ fog
|   |   |___ rain
|   |   |___ sun
|   |--- semantic
|   |   |___ cloud
|   |   |___ fog
|   |   |___ rain
|   |   |___ sun
|--- DensePASS
|   |___ gtfine
|   |___ leftImg8bit

注意这里的synPASS数据集是一个合成的数据集,在迁移学习中,这样的合成数据集因为好获得标注,所以在语义分割中时常用作辅助数据集来帮助实验进行,下面开始论文的阅读,数据集我目前对这方面没有研究,所以先搁置一下,具体问一下师兄。

论文阅读+实验复现

Look at the Neighbor: Distortion-aware Unsupervised Domain Adaptation for Panoramic Semantic Segmentation

但看这个标题,和Trans4PASS那篇文章你说没有关联鬼都不信,同样是文学性标题+work的标题,鉴定为标题纯纯的抄袭。但这篇文章我觉得做的最好的一个位置就是在首页就放了这个结果图,他把22年的SOTA超越后,给了一个这个对比图,拉踩是把,左边这个图(a)先不管,一看就是在特征提取的地方提出了一个新的方法,然后比之前的提取效果更好了,套路发文方式,改点小结构提升效果发顶会。

Abstract

看这个Abstract就是说之前使用全景影像语义分割UDA之前的网络结构泛化性差,其实把我看了之前的论文都是base在ERP投影模式下,科普一下ERP投影(维基百科)https://en.wikipedia.org/wiki/Equirectangular_projection,解释的比较详细,具体就有点将球体投影到平面上的一种投影变换方式(本科地图学没白学),然后这种投影方式会有误差产生,这个就很好解释了,球和平面不一样,会有变形,举个例子,你用美颜相机,你从不是很好看->漂亮相当于一次投影,变漂亮相当于你要经过一次变形,所以投影不可避免会产生变形,这篇文章的目的就是别的文章都是建立在ERP的基础上迁移的,但是这个文章卷这个投影的方式,通过自己的方法研究了一个能够较好解决这个ERP投影问题的结构来尽量规避变形引发的特征提取出现的问题。然后他说自己的方法更简单、更易于实现,而且计算效率更高。具体来说,我们提出了失真感知注意(DA),捕捉邻近像素分布,而无需使用任何几何约束。此外,我们提出了一个类别特征聚合(CFA)模块,通过内存库迭代更新特征表示。因此,两个域之间的特征相似性可以持续优化。大量实验表明,我们的方法实现了新的最先进性能,同时显着减少了80%的参数。这几句的真实性有待考证。

Introduction

这个introduction和TransPASS不能说基本一样,只能说除了方法这个位置,基本一毛一样,我觉得这段话有价值的是这个话:我们发现ERP中的像素邻域确实引入了较少的失真。由于ERP打乱了球形像素的等距分布,因此360°图像特定纬度上的任两个像素之间的距离与ERP图像(球形到平面投影)的距离不同。因此,通过减小感受野来捕获像素之间的位置分布更加高效地解决失真问题。因此,控制邻域大小对于平衡感受野和失真问题之间的权衡至关重要。这实际也比较废话,因为你变形后,距离肯定不一样,这里应该是指相对距离。然后我们来看看Related Work。这个图可以参考一下,还挺好看的

Related Work

全景语义分割的UDA可分为三大类:对抗学习、伪标签生成和特征原型适应。第一类方法倾向于通过从图像级别、特征级别和输出级别进行对齐来学习域的不变性。第二类方法为目标域训练生成伪标签,并利用自我训练对其进行改进。第三类方法,例如 Mutual Prototype Adaption (MPA),将特征嵌入与分别在源域和目标域获得的原型进行对齐。然而,这些方法利用多阶段训练策略,因此未能在每个小批量中关联特征。不同的是,我们提出了 CFA 模块来聚合类别原型并迭代更新它们,促使原型具有更全面地表示域的特殊性

ERP的畸变问题。先前的研究通过可变形核和根据球体几何先验设计适应性CNN来缓解畸变问题。特别地,[46]在补丁嵌入期间自适应调整感受野,以更好地保持语义一致性,并在特征解析过程中考虑畸变问题。然而,由于较大的感受野,此方法效率低下。另一种方式是设计畸变感知神经网络。如可变形补丁嵌入 (DPE) 和可变形MLP (DMLP) 等,这些变形组件被广泛应用于全景语义分割,因为它们可以在对输入数据进行补丁化时帮助学习全景特征的先验知识。尽管性能得到了显著提高,但数据泛化能力有限,主要依赖于先前的几何知识。不同的是,我们发现ERP的邻域确实引入了较少的畸变,有利于对像素分布的变异进行泛化。因此,我们提出了 DA 模块来通过更少的参数解决畸变问题。

自注意力 (SA) 被定义为在查询、键和值序列上的点积操作。Dosovitskiy等人首次提出在视觉领域中利用SA对图像补丁进行操作。最近,提出了丰富的注意力范式变体来解决视觉问题。对于全景语义分割,Multi-Head Self-Attention (MHSA) 和 Efficient Self-Attention (ESA) 广泛应用于捕获360°图像的长距离依赖性。然而,MHSA和ESA在缓解由全局特征提取策略引起的像素之间的畸变问题方面是不足的。相比之下,我们的工作以不同的精神聚焦于邻近像素,因此提出了DA模块来通过捕获不同域之间的不同像素分布来减少畸变问题。

感觉这个Related work说的很好了,就是为了解决什么问题提出了什么方法,他的整体思路是相当于,为了解决泛化能力差的问题,改进了DPE和DMLP,然后还能捕获在这个不同域的不同像素分布来减少畸变。

Method

又到了最喜欢的Method部分,首先是分析了ERP的畸变

Theoretical Analysis of ERP Distortion

这里和下图结合一起看,就是意思是上图最下边的两个黄色像素比如说(3,1)和(0,1)在投影到蓝色的像素时大小发生了变化,不符合本身的分布,由不均匀到了均匀产生了变形,因此在提取特征发生了错误不准确的情况,然后基于这个这个原本的位置和投影位置的差的失真系数,他类似于纵向投影均匀分布,所以影像步道,只是横向有误差,感觉这个地方纵向也可以做文章,+个新系数,插个眼,感觉这样可以再改进一下,整个二维失真系数。

Distortion-aware UDA Framework

他在原本的注意力中加了一个位置编码,相当于保存一个偏置来限制误差,公式看不懂斯密达,这部分感觉得结合代码一起来看。下面看看他的网络结构,又到了魔幻网络的部分

发现了重点,模型小且精度高,花小钱办大事,发文章的好方法。

下面更是一个玄学的域自适应的方法:

核心思路最小化分割损失函数,然后再optimize目标域函数,还是对抗的思路,然后这个CFA聚合类别,我懂了,相当于最小化同样的类别的特征中心MSE来达到源域与目标域的学习,感觉这个学习可能是基于相似类别特征中心靠近或者相似的想法,是一个比较好的想法。

话不多说看代码

代码结构

dataset文件夹

一看代码一股亲切之意铺面而来,和那个TransPASS的结构非常相似,下面从上到下按顺序来看看:

首先还是熟悉的list这里边是数据的列表,在训练实际读取数据路径进行实际实验,不赘述,然后在文件夹中,有着以dataset为名的各种py文件,这些函数应该是将原始影像数据通过处理将数据转化为网络可以进行处理的张量形式,相当于数据预处理的工作。

然后我们来看一下city文件夹

这几个.py文件都是对于city_scape进行一系列预处理工作,我觉得应该是他想用这些方法在实际训练之前对于数据进行一系列预处理工作,然后提升效果。

这里有一个eq2tanget函数:

实现了将360度全景图像转换为等角度切分的多个小图像的功能。具体实现步骤如下:

  1. 导入必要的库和模块。
  2. 定义了一个名为pair的函数,用于将输入转换为元组形式,如果输入已经是元组则直接返回,否则将其包装为元组形式。
  3. 定义了一个名为eq2tangent的函数,用于将全景图像转换为等角度切分的小图像。
  4. 函数参数img是输入的全景图像,heightwidth是输出小图像的高度和宽度,默认为224。函数首先对输入的图像进行预处理,将其归一化到0到1之间。
  5. 定义了一些常数和参数,如视场角(FOV)、PI(π)、PI_2(π/2)、PI2(2π)、行数(num_rows)、每行列数(num_cols)、中心点的经度和纬度(phi_centers)、phi_interval(纬度间隔)等。
  6. 根据参数配置,生成全景图像中每个小图像的位置和遮罩。遮罩是一个二维数组,表示每个小图像在全景图像中的位置。
  7. 利用PyTorch的张量和函数操作对输入图像进行处理,并生成目标小图像。具体包括坐标的转换、网格的生成以及利用grid_sample函数对全景图像进行透视变换。
  8. 最后返回转换后的小图像。

这段代码主要利用了PyTorch的张量操作和几何变换函数,将360度全景图像转换为等角度切分的多个小图像,方便后续在深度学习模型中应用。

这步的目的是:(保持空间一致性)

等角度切分全景图像为小图像的目的是为了实现全景图像的球面投影到平面上的转换。在全景图像处理中,通常会将全景图像投影到球面上,并根据特定的视角和方向,从球面上采样获取所需的图像信息。然后,通过将球面上的图像样本转换为平面上的小图像,可以更方便地进行处理和分析,例如用于训练深度学习模型或进行图像编辑等任务。等角度切分可以确保小图像之间的空间角度保持一致,有助于保持全景图像的空间连续性和一致性。等角切分实际上就是前文中提到的P,而且这个记录了小影像的位置,就是为了实现这个图中内容的RPE:

metrics文件夹

这个文件夹中主要是用于计算类似mIou这样的指标,因此也不需要过多研究,具体不理解各个指标可以百度查询,百度中有许多通俗的解释

models文件夹

包含DATR和pvt_v2,这两个文件夹是当前项目的重点,所以我们来仔细研究一下,先看DATR文件夹

import torch
import torch.nn as nn
import torch.nn.functional as F 
from .decoder import SegFormerHead
from .encoder import DATRM, DATRT, DATRS

class DATR(nn.Module):
    def __init__(self, backbone, num_classes=20, embedding_dim=256, pretrained=None):
        super().__init__()
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.feature_strides = [4, 8, 16, 32]

        if backbone == 'DATRM':
            self.encoder = DATRM()
            if pretrained:
                state_dict = torch.load('/hpc/users/CONNECT/tpan695/DATR/models/ptmodel/mit_b0.pth')
                state_dict.pop('head.weight')
                state_dict.pop('head.bias')
                self.encoder.load_state_dict(state_dict,strict=False)
        ## initilize encoder
        elif backbone == 'DATRT':
            self.encoder = DATRT()
            if pretrained:
                state_dict = torch.load('/hpc/users/CONNECT/tpan695/DATR/models/ptmodel/mit_b1.pth')
                state_dict.pop('head.weight')
                state_dict.pop('head.bias')
                self.encoder.load_state_dict(state_dict,strict=False)
        ## initilize encoder
        elif backbone == 'DATRS':
            self.encoder = DATRS()
            if pretrained:
                state_dict = torch.load('/hpc/users/CONNECT/tpan695/DATR/models/ptmodel/mit_b2.pth')
                state_dict.pop('head.weight')
                state_dict.pop('head.bias')
                self.encoder.load_state_dict(state_dict,strict=False)
        self.in_channels = self.encoder.embed_dims

        self.backbone = backbone
        self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, embedding_dim=self.embedding_dim, num_classes=self.num_classes)
        
        self.classifier = nn.Conv2d(in_channels=self.in_channels[-1], out_channels=self.num_classes, kernel_size=1, bias=False)

    def _forward_cam(self, x):
        
        cam = F.conv2d(x, self.classifier.weight)
        cam = F.relu(cam)
        
        return cam

    def get_param_groups(self):

        param_groups = [[], [], []] # 
        
        for name, param in list(self.encoder.named_parameters()):
            if "norm" in name:
                param_groups[1].append(param)
            else:
                param_groups[0].append(param)

        for param in list(self.decoder.parameters()):

            param_groups[2].append(param)
        
        param_groups[2].append(self.classifier.weight)

        return param_groups

    def forward(self, x):
        _, _, height, width = x.shape

        _x = self.encoder(x)

        feature =  self.decoder(_x)
        pred = F.interpolate(feature, size=(height,width), mode='bilinear', align_corners=False)

        return pred, _x

这段代码定义了一个名为DATR的神经网络模型,用于语义分割任务。让我逐步解释:

  1. 首先,导入了必要的PyTorch库和自定义的模块(decoderencoder),这些模块用于构建模型的解码器和编码器。

  2. DATR类继承自nn.Module,表示这是一个PyTorch模型。

  3. __init__方法中:

    • backbone参数指定了所使用的骨干网络类型,可以是DATRMDATRTDATRS
    • num_classes参数表示最终的输出类别数。
    • embedding_dim参数表示特征嵌入的维度。
    • 如果设置了pretrained参数为True,则会加载预训练的权重。
    • 根据选择的骨干网络类型,初始化对应的编码器(encoder)并加载预训练权重。
    • 初始化解码器(decoder)和分类器(classifier)。
  4. _forward_cam方法用于计算类激活图(CAM),它通过卷积操作将特征图与分类器的权重相乘,并应用ReLU激活函数。

  5. get_param_groups方法用于将模型参数分组,这些参数将用于优化器的不同学习率。

  6. forward方法用于前向传播:

    • 首先,通过编码器将输入x转换为特征表示_x
    • 然后,将特征表示传递给解码器,得到语义分割的预测结果pred
    • 最后,通过双线性插值将预测结果的尺寸调整为输入x的尺寸,并返回预测结果和特征表示。

这个是一个网络的概括流程,具体网络的细节还需要在model中的其他文件中体现。这里pretrained我猜测用了segformer的预训练模型,因为命名一模一样“m1_bit0.pth”和“m2_bit0.pth”,这里的encoder和decoder是这个项目的重点,下面将重点分析

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import math
from natten import NeighborhoodAttention # 报错
from natten import NeighborhoodAttention这个引用是报错的,说明有问题

查阅资料后需要从GitHub - SHI-Labs/NATTEN: Neighborhood Attention Extension. Bringing attention to a neighborhood near you!

配置一下natten,吐槽一句,这个作者非常离谱,不给.yml文件,环境不知道怎么搞的,然后源文件中没有segformer,还得试他的代码。

言归正传,在terminal运行这段命令可以安装:

pip install natten==0.14.4

解决了报错的问题,进入正题,我了一下代码,除了这个有三个类别的网络DATRM,DATRT,DATRS,然后在注意力用了一个neighborhood attention,然后其他和trans4PASS一模一样,所以就不过多解释了,那个neighobrhood attention可以看Neighborhood Attention Transformer - 知乎这个原文的解释,这里边需要注意的是只有在最大的embedded_dims上使用了natten的neighborhood attention。解码器对应了编码器,也没有太大的问题,使用了MLP

在ss_cfa_syn.py中介绍了该代码使用cfa代码进行域自适应的方法:

def IFV(feat_1, feat_2, target0, target1, cen_bank_1, cen_bank_2, epoch):
    #feat_T.detach()
    size_f = (feat_1.shape[2], feat_1.shape[3])
    tar_feat_0 = nn.Upsample(size_f, mode='nearest')(target1.unsqueeze(1).float()).expand(feat_1.size())
    tar_feat_1 = nn.Upsample(size_f, mode='nearest')(target0.unsqueeze(1).float()).expand(feat_2.size())
    center_feat_S = feat_1.clone()
    center_feat_T = feat_2.clone()
    for i in range(19):
        mask_feat_0 = (tar_feat_0 == i).float()
        mask_feat_1 = (tar_feat_1 == i).float()
        center_feat_S = (1 - mask_feat_0) * center_feat_S + mask_feat_0 * ((mask_feat_0 * feat_1).sum(-1).sum(-1) / (mask_feat_0.sum(-1).sum(-1) + 1e-6)).unsqueeze(-1).unsqueeze(-1)
        center_feat_T = (1 - mask_feat_1) * center_feat_T + mask_feat_1 * ((mask_feat_1 * feat_2).sum(-1).sum(-1) / (mask_feat_1.sum(-1).sum(-1) + 1e-6)).unsqueeze(-1).unsqueeze(-1)

    center_feat_S = ((1 - 1 / (epoch + 1)) * cen_bank_1 + center_feat_S * 1 / (epoch + 1)) * 0.5
    center_feat_T = ((1 - 1 / (epoch + 1)) * cen_bank_2 + center_feat_T * 1 / (epoch + 1)) * 0.5
    # cosinesimilarity along C
    cos = nn.CosineSimilarity(dim=1)
    pcsim_feat_S = cos(feat_1, center_feat_S)
    pcsim_feat_T = cos(feat_2, center_feat_T)

    # FA
    mse = nn.MSELoss()
    loss = mse(pcsim_feat_S, pcsim_feat_T)
    # fa sfmx
    # kl_loss = nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=True)
    # loss = kl_loss(F.log_softmax(pcsim_feat_T), F.softmax(pcsim_feat_S).detach())
    # center sfmx
    # kl_loss = nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=True)
    # loss = kl_loss(F.log_softmax(center_feat_S), F.softmax(center_feat_T).detach())
    # center 
    # kl_loss = nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=True)
    # loss = kl_loss(pcsim_feat_S, pcsim_feat_T.detach())
    return loss, center_feat_S, center_feat_T

这段代码定义了一个函数IFV,用于计算图像特征向量之间的距离并进行特征融合。让我逐步解释:

  1. 函数接受以下参数:

    • feat_1:源图像的特征向量。
    • feat_2:目标图像的特征向量。
    • target0:源图像的标签。
    • target1:目标图像的标签。
    • cen_bank_1:源图像的特征中心。
    • cen_bank_2:目标图像的特征中心。
    • epoch:当前的训练轮数。
  2. 首先,根据源图像和目标图像的尺寸,使用双线性插值将目标图像的标签调整为与源图像相同的尺寸。

  3. 接下来,计算源图像和目标图像各个类别的特征中心。

    • 对于每个类别,通过计算在该类别上的像素的平均特征向量来得到特征中心。
    • 这里使用了一个循环,对于19个类别分别计算特征中心。
  4. 更新源图像和目标图像的特征中心:

    • 使用滑动平均的方式更新特征中心,以适应模型的训练。
    • 公式中的 (1 - 1 / (epoch + 1)) 部分用于控制滑动平均的权重。
    • 最后乘以 0.5 是为了进行归一化操作。
  5. 计算源图像和目标图像特征向量之间的余弦相似度。

    • 使用 nn.CosineSimilarity 计算特征向量的余弦相似度。
  6. 计算损失:

    • 使用均方误差(MSE)损失函数计算余弦相似度的误差。
  7. 返回损失值以及更新后的源图像和目标图像的特征中心。

def adjust_learning_rate_poly(optimizer, epoch, num_epochs, base_lr, power):
    lr = base_lr * (1-epoch/num_epochs)**power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

这段代码定义了一个学习率调整函数 adjust_learning_rate_poly,用于根据当前训练的轮次来动态地调整学习率。下面是对代码的解释:

  1. 函数接收以下参数:

    • optimizer:优化器对象,用于更新模型的参数。
    • epoch:当前训练轮次。
    • num_epochs:总的训练轮次。
    • base_lr:基础学习率,即初始的学习率。
    • power:多项式调整的幂次。
  2. 计算当前轮次下的学习率:

    • 使用多项式衰减策略,学习率会随着训练轮次的增加而逐渐减小。
    • 公式为 lr = base_lr * (1-epoch/num_epochs)**power,其中 ** 表示幂运算。
  3. 更新优化器中每个参数组的学习率:

    • 历优化器中的每个参数组,将学习率更新为计算得到1的lr 值。
  4. 返回当前轮次下的学习率 lr

域自适应代码

然后他的域自适应代码实在是太抽象了,他需要先提供一个伪标签影像,没错他这个代码我没有,在运行前需要先有伪标签影像,然后在网络主体我没看他自己定义的model,用了Segformer网络来进行训练,我。。。。。。。。。,现在感觉他没有放真正的work来Github,吐了,如果要改先将分布式训练改为单卡训练,我这里把那个syn_cfa.py改为单卡的代码先仍出来:

import warnings
warnings.filterwarnings('ignore')

import argparse
import random
import os
import time,datetime
import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from torch.utils.tensorboard import SummaryWriter

from metrics.compute_iou import fast_hist, per_class_iu
import numpy as np
from torchvision import transforms
import torch.nn as nn
import torch
from models.segformer.segformer import Seg
from torch.utils import data

from dataset.adaption.sp13_dataset import synpass13DataSet
from dataset.adaption.dp13_dataset import densepass13TestDataSet
from dataset.adaption.dp13_dataset_ss import densepass13DataSet

import tqdm

#NAME_CLASSES = ["road", "sidewalk", "building", "wall", "fence", "pole","light","sign","vegetation","terrain","sky","person","rider","car","truck","bus","train","motocycle","bicycle"]
NAME_CLASSES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
                'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'car']

def IFV(feat_1, feat_2, target0, target1, cen_bank_1, cen_bank_2, epoch):
    #feat_T.detach()
    size_f = (feat_1.shape[2], feat_1.shape[3])
    tar_feat_0 = nn.Upsample(size_f, mode='nearest')(target1.unsqueeze(1).float()).expand(feat_1.size())
    tar_feat_1 = nn.Upsample(size_f, mode='nearest')(target0.unsqueeze(1).float()).expand(feat_2.size())
    center_feat_S = feat_1.clone()
    center_feat_T = feat_2.clone()
    for i in range(19):
        mask_feat_0 = (tar_feat_0 == i).float()
        mask_feat_1 = (tar_feat_1 == i).float()
        center_feat_S = (1 - mask_feat_0) * center_feat_S + mask_feat_0 * ((mask_feat_0 * feat_1).sum(-1).sum(-1) / (mask_feat_0.sum(-1).sum(-1) + 1e-6)).unsqueeze(-1).unsqueeze(-1)
        center_feat_T = (1 - mask_feat_1) * center_feat_T + mask_feat_1 * ((mask_feat_1 * feat_2).sum(-1).sum(-1) / (mask_feat_1.sum(-1).sum(-1) + 1e-6)).unsqueeze(-1).unsqueeze(-1)

    center_feat_S = ((1 - 1 / (epoch + 1)) * cen_bank_1 + center_feat_S * 1 / (epoch + 1)) * 0.5
    center_feat_T = ((1 - 1 / (epoch + 1)) * cen_bank_2 + center_feat_T * 1 / (epoch + 1)) * 0.5
    # cosinesimilarity along C
    cos = nn.CosineSimilarity(dim=1)
    pcsim_feat_S = cos(feat_1, center_feat_S)
    pcsim_feat_T = cos(feat_2, center_feat_T)

    # FA
    mse = nn.MSELoss()
    loss = mse(pcsim_feat_S, pcsim_feat_T)
    return loss, center_feat_S, center_feat_T

def adjust_learning_rate_poly(optimizer, epoch, num_epochs, base_lr, power):
    lr = base_lr * (1-epoch/num_epochs)**power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def main():
    parser = argparse.ArgumentParser(description='pytorch implemention')
    parser.add_argument('--batch-size', type=int, default=1, metavar='N',
                        help='input batch size for training (default: 6)')
    parser.add_argument('--iterations', type=int, default=30000, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=6e-5, metavar='LR',
                        help='learning rate (default: 6e-5)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save_root', default = '',
                        help='Please add your model save directory')
    parser.add_argument('--exp_name', default = '',
                        help='')
    parser.add_argument('--backbone',  type=str, default = '',
                        help='')
    parser.add_argument('--sup_set', type=str, default='train', help='supervised training set')
    parser.add_argument('--cutmix', default =False, help='cutmix')
    #================================hyper parameters================================#
    parser.add_argument('--alpha', type=float, default =0.5, help='alpha')
    parser.add_argument('--lamda', type=float, default =0.001, help='lamda')
    parser.add_argument('--dis_lr', type=float, default =0.001, help='dis_lr')
    #================================================================================#
    args = parser.parse_args()
    best_performance_dp, best_performance_sp = 0.0, 0.0

    args.save_root = "/home/yaowen-chang/DATR/DATR-main/"
    args.exp_name = "result"
    save_path = "{}{}".format(args.save_root,args.exp_name)
    cur_time = str(datetime.datetime.now().strftime("%y%m%d-%H:%M:%S"))
    writer = SummaryWriter(log_dir=save_path)

    if os.path.exists(save_path):
        pass
    else:
        os.makedirs(save_path)

    torch.cuda.set_device(0)  # Set the GPU device to be used
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SynPASS dataset
    syn_h, syn_w = 2048, 400
    root_syn = '/media/yaowen-chang/新加卷/Trans4PASS/trans4pass-data/SynPASS/SynPASS'
    list_path = '/home/yaowen-chang/DATR/DATR-main/dataset/adaption/synpass_list/train.txt'
    syn_dataset = synpass13DataSet(root_syn, list_path, crop_size=(syn_h, syn_w), set='train')
    syn_train_loader = torch.utils.data.DataLoader(syn_dataset, batch_size=args.batch_size, shuffle=True, num_workers=12, pin_memory=True)

    val_list = '/home/yaowen-chang/DATR/DATR-main/dataset/adaption/synpass_list/val.txt'
    syn_val = synpass13DataSet(root_syn, val_list, crop_size=(1024,512))
    syn_val_loader = torch.utils.data.DataLoader(syn_val, batch_size=1, num_workers=1, pin_memory=True)

    test_list = '/home/yaowen-chang/DATR/DATR-main/dataset/adaption/synpass_list/test.txt'
    syn_test = synpass13DataSet(root_syn, test_list, crop_size=(syn_h, syn_w))
    syn_test_loader = torch.utils.data.DataLoader(syn_test, batch_size=1, num_workers=1, pin_memory=True)

    # DensePASS dataset
    root_dp = '/media/yaowen-chang/新加卷/Trans4PASS/trans4pass-data/DensePASS'
    list_path = '/home/yaowen-chang/DATR/DATR-main/dataset/adaption/densepass_list/val.txt'
    train_root = '/media/yaowen-chang/新加卷/Trans4PASS/trans4pass-data/DensePASS'
    train_list = '/home/yaowen-chang/DATR/DATR-main/dataset/adaption/densepass_list/train.txt'
    pass_train = densepass13DataSet(train_root, train_list, crop_size=(2048,400))
    pass_train_loader = torch.utils.data.DataLoader(pass_train, batch_size=args.batch_size, shuffle=True, num_workers=12, pin_memory=True)

    pass_dataset = densepass13TestDataSet(root_dp, list_path, crop_size=(2048,400), set='val')
    testloader = torch.utils.data.DataLoader(pass_dataset, batch_size=1, shuffle=True, num_workers=1, pin_memory=True)

    num_classes = 13
    NUM_CLASSES = 13
    w, h = 2048, 400

    # Models
    # Models
    args.backbone = 'b1'
    model1 = Seg(num_classes=num_classes, phi= args.backbone, pretrained=True)
    model_path = "/home/yaowen-chang/DATR/DATR-main/models/segformer/backbone/mit_b1.pth"
    model1.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")), strict=False)
    print('Model is', args.backbone)
    print('Load Model from', model_path)

    # 移除多 GPU 支持代码
    model1 = model1.to(device)

    # Iterative dataloader
    syn_length = len(syn_train_loader)
    pass_length = len(pass_train_loader)

    print(f'Panoramic Dataset length:{len(pass_train)};')
    print(f'SynPASS Dataset length:{len(syn_dataset)};')


    # Training Details
    criterion_sup = nn.CrossEntropyLoss(reduction='mean', ignore_index=255)
    optimizer1 = optim.AdamW(model1.parameters(), lr=args.lr, weight_decay=0.0001)

    it = 1
    epoch = 1
    cen_bank_1, cen_bank_2 = torch.zeros(1, 512, 13, 64).to(device), torch.zeros(1, 512, 13, 64).to(device)
    syn_sup_loader = iter(syn_train_loader)  # 初始化,不初始化给我报错
    pass_img_loader = iter(pass_train_loader)

    for it in range(1, args.iterations + 1):
        if it % syn_length == 0:
            syn_sup_loader = iter(syn_train_loader)
        if it % pass_length == 0:
            pass_img_loader = iter(pass_train_loader)

        # 获取SynPASS数据
        try:
            s_img, s_gt, _, _ = next(syn_sup_loader)
        except StopIteration:
            syn_sup_loader = iter(syn_train_loader)
            s_img, s_gt, _, _ = next(syn_sup_loader)

        # 获取Panoramic数据
        try:
            p_img, p_gt, _, _ = next(pass_img_loader)
        except StopIteration:
            pass_img_loader = iter(pass_train_loader)
            p_img, p_gt, _, _ = next(pass_img_loader)

        s_img, s_gt = s_img.to(device), s_gt.to(device)
        p_img, p_gt = p_img.to(device), p_gt.to(device)

        syn_pred, syn_feat = model1(s_img)
        pass_pred, pass_feat = model1(p_img)

        loss_sup_1 = criterion_sup(syn_pred, s_gt)
        loss_sup_2 = criterion_sup(pass_pred, p_gt)

        loss_fa_3, cen_1, cen_2 = IFV(pass_feat[3], syn_feat[3], p_gt, s_gt, cen_bank_1, cen_bank_2, epoch)
        cen_bank_1 = cen_1.detach().clone()
        cen_bank_2 = cen_2.detach().clone()
        loss_fa = loss_fa_3

        loss_1 = loss_sup_1 + loss_sup_2 + loss_fa_3

        optimizer1.zero_grad()
        loss_1.backward()
        optimizer1.step()

        base_lr = args.lr
        if it <= 1500:
            lr_ = base_lr * (it / 1500)
            for param_group in optimizer1.param_groups:
                param_group['lr'] = lr_
        else:
            lr_ = adjust_learning_rate_poly(optimizer1, it - 1500, it, args.lr, 1)

        if it % pass_length == 0 or it == 1:
            print(f'iter:{it};Model1 Total loss: {loss_1:.4f}')
            print(f'iter:{it};Model1 Sup loss: {loss_sup_1:.4f}')
            print(f'iter:{it};Model1 Pseudo Sup Loss: {loss_sup_2:.4f}')

            with torch.no_grad():
                print(f'[Validation it: {it}] lr: {lr_}')
                model1.eval()
                best_performance_dp = validation(num_classes, NUM_CLASSES, NAME_CLASSES, device, testloader, model1,
                                                 best_performance_dp, save_path, epoch, 'densepass')
                epoch += 1
                model1.train()

    # validation 函数不需要修改
def validation(num_classes, NUM_CLASSES, NAME_CLASSES, device, testloader, model1, best_performance, save_path, epoch, name):
    writer = SummaryWriter(log_dir=save_path)
    hist = np.zeros((num_classes, num_classes))
    for index, batch in enumerate(testloader):
        # if index % 100 == 0:
        #     print ('%d processd' % index)
        image, label, _, _ = batch
        image, label = image.to(device), label.to(device)
        with torch.no_grad():
            output, _ = model1(image)
        output = torch.argmax(output, 1).squeeze(0).cpu().data.numpy()

        label = label.cpu().data[0].numpy()
        hist += fast_hist(label.flatten(), output.flatten(), num_classes)

    mIoUs = per_class_iu(hist)
    for ind_class in range(num_classes):
        print('===>{:<15}:\t{}'.format(NAME_CLASSES[ind_class], str(round(mIoUs[ind_class] * 100, 2))))
    bestIoU = round(np.nanmean(mIoUs) * 100, 2)
    print('===> mIoU: ' + str(bestIoU))
    if name == 'densepass':
        if bestIoU > best_performance:
            best_performance = bestIoU
        torch.save(model1.module.state_dict(),save_path+"/best_densepass.pth")
        print('epoch:',epoch,name,'val_mIoU',bestIoU, 'best_model:', best_performance)
        writer.add_scalar('[DensePASS] val_mIOU:',bestIoU,epoch)
    if name == 'synpass':
        if bestIoU > best_performance:
            best_performance = bestIoU
        torch.save(model1.module.state_dict(),save_path+"/best_synpass.pth")
        print('epoch:',epoch,name,'val_mIoU',bestIoU, 'best_model:', best_performance)
        writer.add_scalar('[SynPASS] val_mIOU:',bestIoU,epoch)
    return best_performance

if __name__ == '__main__':
    main()

这个实验我感觉他在缺少环境需要自己按照他的想法来复现,这篇文章作者在github缺少了核心的环境配置和能执行的代码!!!

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值