STANet代码解读models部分

init.py

import importlib
from models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function warps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance

这段代码实现了一个用于创建模型的函数create_model,主要包含以下几个部分:

find_model_using_name:根据模型名称找到对应的模型类。首先根据模型名称构造模型文件名,然后动态加载模型文件对应的模块,并遍历模块中的所有类,找到类名与模型名称匹配的模型类。

get_option_setter:返回模型类的modify_commandline_options静态方法,用于修改命令行参数。

create_model:根据命令行参数opt创建模型实例。首先调用find_model_using_name函数找到对应的模型类,然后创建该模型类的实例instance,并返回。

backbone.py

# coding: utf-8
import torch.nn as nn
import torch
from .mynet3 import F_mynet3
from .BAM import BAM
from .PAM2 import PAM as PAM



def define_F(in_c, f_c, type='unet'):
    if type == 'mynet3':
        print("using mynet3 backbone")
        return F_mynet3(backbone='resnet18', in_c=in_c,f_c=f_c, output_stride=32)
    else:
        NotImplementedError('no such F type!')

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)



class CDSA(nn.Module):
    """self attention module for change detection

    """
    def __init__(self, in_c, ds=1, mode='BAM'):
        super(CDSA, self).__init__()
        self.in_C = in_c
        self.ds = ds
        print('ds: ',self.ds)
        self.mode = mode
        if self.mode == 'BAM':
            self.Self_Att = BAM(self.in_C, ds=self.ds)
        elif self.mode == 'PAM':
            self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C, sizes=[1,2,4,8],ds=self.ds)
        self.apply(weights_init)

    def forward(self, x1, x2):
        height = x1.shape[3]
        x = torch.cat((x1, x2), 3)
        x = self.Self_Att(x)
        return x[:,:,:,0:height], x[:,:,:,height:]

这段代码实现了一个用于变化检测的自注意力模块CDSA(Change Detection Self Attention)。主要包含以下几个部分:
define_F:定义一个特征提取网络F,根据type参数的不同可以选择不同的backbone,如果type'mynet3'则使用F_mynet3网络。

weights_init:初始化网络权重的函数,用于给网络的卷积层和批归一化层设置初始值(卷积层为均值为0,标准差为0.02的正态分布,批归一化层的权重为1,偏置为0)。

__init__:构造函数,初始化了一些变量,如输入通道数in_c、降采样因子ds、自注意力机制的类型mode等。然后根据mode的不同选择不同的自注意力模块,如果mode'BAM'则使用BAM模块,如果mode'PAM'则使用PAM模块。最后使用weights_init函数初始化网络权重。

forward:前向传播函数,输入两个特征图x1x2,将它们在通道维度上进行拼接,然后将拼接后的特征图传入Self_Att模块中,得到输出特征图x。最后将x沿着通道维度进行分离,分别得到x1x2两个特征图,返回它们。

BAM.py

import torch
import torch.nn.functional as F
from torch import nn


class BAM(nn.Module):
    """ Basic self-attention module
    """

    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.chanel_in = in_dim
        self.key_channel = self.chanel_in //8
        self.activation = activation
        self.ds = ds  #
        self.pool = nn.AvgPool2d(self.ds)
        print('ds: ',ds)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)  #

    def forward(self, input):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature
                attention: B X N X N (N is Width*Height)
        """
        x = self.pool(input)
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B X C X (N)/(ds*ds)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B X C x (*W*H)/(ds*ds)
        energy = torch.bmm(proj_query, proj_key)  # transpose check
        energy = (self.key_channel**-.5) * energy

        attention = self.softmax(energy)  # BX (N) X (N)/(ds*ds)/(ds*ds)

        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

        out = F.interpolate(out, [width*self.ds,height*self.ds])
        out = out + input

        return out

这段代码实现了一个基础的自注意力模块(BAM)。主要包含以下几个部分:

__init__:构造函数,初始化了一些变量,如输入通道数in_dim、降采样因子ds、激活函数activation等。然后定义了三个卷积层,分别对应QueryKeyValue,其中QueryKey的输出通道数为输入通道数的1/8,Value的输出通道数为输入通道数。最后定义了可学习的参数gammasoftmax函数。

forward:前向传播函数,主要功能是将输入数据x先经过一个平均池化层进行降采样,然后将降采样后的特征图x分别传入QueryKeyValue三个卷积层中,得到对应的特征张量proj_queryproj_keyproj_value。然后计算proj_queryproj_key的转置矩阵相乘,得到能量矩阵energy。接下来,将energy除以key_channel的平方根,再使用softmax函数得到注意力矩阵attention。最后,将proj_valueattention矩阵相乘得到输出张量out,经过插值调整大小后再加上输入张量x,得到最终输出。

CDF0.py

import torch
import itertools
from .base_model import BaseModel
from . import backbone
import torch.nn.functional as F
from . import loss


class CDF0Model(BaseModel):
    """
    change detection module:
    feature extractor
    contrastive loss
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.istest = opt.istest
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['f']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['A', 'B', 'L', 'pred_L_show']  # visualizations for A and B
        if self.istest:
            self.visual_names = ['A', 'B', 'pred_L_show']
        self.visual_features = ['feat_A', 'feat_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['F']
        else:  # during test time, only load Gs
            self.model_names = ['F']
        self.ds=1
        # define networks (both Generators and discriminators)
        self.n_class = 2
        self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device)

        if self.isTrain:
            # define loss functions
            self.criterionF = loss.BCL()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netF.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        self.A = input['A'].to(self.device)
        self.B = input['B'].to(self.device)
        if not self.istest:
            self.L = input['L'].to(self.device).long()
        self.image_paths = input['A_paths']
        if self.isTrain:
            self.L_s = self.L.float()
            self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape
  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
CycleGAN是一个无监督的图像转换模型,可以将一种领域的图像转换成另一种领域的图像,而无需手动标注数据集。其核心思想是通过两个生成器和两个判别器,来实现两个领域之间的图像转换。下面我们来看一下CycleGAN的源代码解读。 CycleGAN的主要代码在`models`文件夹下,其中`cycle_gan_model.py`定义了CycleGAN的模型结构,`networks.py`定义了生成器和判别器的网络结构。其中生成器采用U-Net结构,判别器采用PatchGAN结构。`options`文件夹下的`base_options.py`定义了模型的一些基本参数,包括训练数据路径、模型保存路径、学习率等。`train_options.py`继承了`base_options.py`,并添加了一些训练相关的参数,比如迭代次数、是否使用L1损失等。`test_options.py`同样继承了`base_options.py`,并添加了一些测试相关的参数,比如测试数据路径、输出结果路径等。 在`train.py`文件中,我们可以看到CycleGAN的训练流程。首先定义了模型、数据加载器、优化器等,然后开始训练。在训练过程中,先通过生成器将A领域的图片转换成B领域的图片,然后将转换后的图片与B领域的真实图片送入判别器,计算判别器的损失。同时,也计算生成器的损失,包括对抗损失、循环一致性损失和L1损失。最后通过反向传播更新生成器和判别器的参数。 在`test.py`文件中,我们可以看到CycleGAN的测试流程。首先定义了模型和数据加载器,然后通过生成器将A领域的图片转换成B领域的图片,并将转换后的图片保存到输出结果路径中。 总之,CycleGAN的源代码实现了一个完整的无监督图像转换模型,包括模型结构、数据加载、训练和测试流程。如果想要深入了解CycleGAN,可以从源代码入手,逐步理解其实现原理。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值