init.py
import importlib
from models.base_model import BaseModel
def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called <model_name>Model() will be
    instantiated. It has to be a subclass of BaseModel, and the
    name match is case-insensitive.

    Parameters:
        model_name (str) -- the model's name, e.g. 'CDF0'

    Returns:
        The matching model class (a subclass of BaseModel).
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    # e.g. 'CD_F0' -> target 'cdf0model', compared case-insensitively below
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        # isinstance(cls, type) guards issubclass() against non-class module
        # attributes (functions, constants) that happen to match the name,
        # which would otherwise raise TypeError
        if name.lower() == target_model_name.lower() \
           and isinstance(cls, type) and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        # bug fix: exit with a non-zero status so shell scripts can detect
        # the failure (the original exit(0) reported success on error)
        exit(1)

    return model
def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_cls = find_model_using_name(model_name)
    return model_cls.modify_commandline_options
def create_model(opt):
    """Create a model given the options.

    This is the main interface between this package and 'train.py'/'test.py'.

    Parameters:
        opt -- parsed options; opt.model selects the model class

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model_cls = find_model_using_name(opt.model)
    instance = model_cls(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
这段代码实现了一个用于创建模型的函数 create_model,主要包含以下几个部分:

- find_model_using_name:根据模型名称找到对应的模型类。首先根据模型名称构造模型文件名,然后动态加载该文件对应的模块,并遍历模块中的所有属性,找到类名与模型名称匹配且继承自 BaseModel 的模型类。
- get_option_setter:返回模型类的 modify_commandline_options 静态方法,用于修改命令行参数。
- create_model:根据命令行参数 opt 创建模型实例。首先调用 find_model_using_name 函数找到对应的模型类,然后创建该模型类的实例 instance 并返回。
backbone.py
# coding: utf-8
import torch.nn as nn
import torch
from .mynet3 import F_mynet3
from .BAM import BAM
from .PAM2 import PAM as PAM
def define_F(in_c, f_c, type='unet'):
    """Create the feature-extractor network F.

    Parameters:
        in_c (int) -- number of input channels
        f_c (int)  -- number of output feature channels
        type (str) -- backbone identifier; only 'mynet3' is implemented.
                      (The name shadows the builtin, but is kept because
                      callers pass it as a keyword argument.)

    Returns:
        An F_mynet3 instance when type == 'mynet3'.

    Raises:
        NotImplementedError: for any other type.
    """
    if type == 'mynet3':
        print("using mynet3 backbone")
        return F_mynet3(backbone='resnet18', in_c=in_c, f_c=f_c, output_stride=32)
    # bug fix: the exception was previously only instantiated, never raised,
    # so unknown types silently returned None
    raise NotImplementedError('no such F type!')
def weights_init(m):
    """DCGAN-style weight initialisation, intended for nn.Module.apply().

    Convolution layers: weights drawn from N(0, 0.02).
    BatchNorm layers:   weights drawn from N(1, 0.02), biases set to 0.
    """
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
class CDSA(nn.Module):
    """Self-attention module for change detection.

    Concatenates two feature maps along the width axis (dim 3), runs a
    self-attention backbone (BAM or PAM) over the joint map, then splits
    the result back into the two halves.
    """

    def __init__(self, in_c, ds=1, mode='BAM'):
        """
        Parameters:
            in_c (int) -- number of input channels
            ds (int)   -- downsampling factor forwarded to the attention module
            mode (str) -- 'BAM' or 'PAM', selects the attention backbone

        Raises:
            NotImplementedError: for an unknown mode. (The original code left
            self.Self_Att unset and only crashed later inside forward().)
        """
        super(CDSA, self).__init__()
        self.in_C = in_c
        self.ds = ds
        print('ds: ', self.ds)
        self.mode = mode
        if self.mode == 'BAM':
            self.Self_Att = BAM(self.in_C, ds=self.ds)
        elif self.mode == 'PAM':
            self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C, sizes=[1, 2, 4, 8], ds=self.ds)
        else:
            # fail fast instead of raising AttributeError later in forward()
            raise NotImplementedError('no such attention mode: %s' % mode)
        self.apply(weights_init)

    def forward(self, x1, x2):
        # concatenate along the width axis, attend over the joint map,
        # then split the output back into the two original halves
        height = x1.shape[3]
        x = torch.cat((x1, x2), 3)
        x = self.Self_Att(x)
        return x[:, :, :, 0:height], x[:, :, :, height:]
这段代码实现了一个用于变化检测的自注意力模块 CDSA(Change Detection Self Attention),主要包含以下几个部分:

- define_F:定义特征提取网络 F,根据 type 参数选择不同的 backbone;当 type 为 'mynet3' 时使用 F_mynet3 网络。
- weights_init:网络权重初始化函数,卷积层权重服从均值为 0、标准差为 0.02 的正态分布;批归一化层权重服从均值为 1、标准差为 0.02 的正态分布,偏置初始化为 0。
- __init__:构造函数,初始化输入通道数 in_c、降采样因子 ds、自注意力类型 mode 等变量;当 mode 为 'BAM' 时使用 BAM 模块,为 'PAM' 时使用 PAM 模块,最后用 weights_init 初始化网络权重。
- forward:前向传播函数,输入两个特征图 x1、x2,将它们沿宽度维度(dim=3)拼接后传入 Self_Att 模块得到输出特征图 x,再将 x 沿同一维度切分为两部分返回。
BAM.py
import torch
import torch.nn.functional as F
from torch import nn
class BAM(nn.Module):
    """Basic self-attention module (scaled dot-product over spatial positions).

    The input is average-pooled by a factor of `ds`, attention is computed
    on the pooled map, and the result is upsampled back to the input size
    and added to the input (residual connection).
    """

    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        """
        Parameters:
            in_dim (int) -- number of input channels
            ds (int)     -- spatial downsampling factor for the attention
            activation   -- unused; kept for interface compatibility
        """
        super(BAM, self).__init__()
        self.chanel_in = in_dim  # (sic) misspelled name kept for compatibility
        self.key_channel = self.chanel_in // 8
        self.activation = activation  # stored but never used in forward()
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        print('ds: ', ds)
        # 1x1 projections: query/key use C/8 channels, value keeps C channels
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # NOTE: gamma is never used in forward(); kept so that state_dicts
        # saved with older versions of this class still load
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """
        inputs :
            input : input feature maps (B X C X H X W)
        returns :
            out : self attention value + input feature (same shape as input)
        """
        x = self.pool(input)
        m_batchsize, C, width, height = x.size()  # names are (B, C, H, W) in practice
        # B x N x C' and B x C' x N, with N = pooled H*W and C' = C/8
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        # scale by 1/sqrt(d_k), as in standard scaled dot-product attention
        energy = (self.key_channel ** -.5) * energy
        attention = self.softmax(energy)  # B x N x N
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        # bug fix: upsample to the *actual* input size instead of
        # [width*ds, height*ds], which made the residual add below crash
        # whenever the input spatial size was not divisible by ds
        out = F.interpolate(out, size=input.shape[2:])
        out = out + input
        return out
这段代码实现了一个基础的自注意力模块(BAM),主要包含以下几个部分:

- __init__:构造函数,初始化输入通道数 in_dim、降采样因子 ds、激活函数 activation 等变量;然后定义对应 Query、Key、Value 的三个 1x1 卷积层,其中 Query 和 Key 的输出通道数为输入通道数的 1/8,Value 的输出通道数等于输入通道数;最后定义可学习参数 gamma 和 softmax 函数。
- forward:前向传播函数,先将输入 x 经过平均池化层降采样,再分别传入 Query、Key、Value 三个卷积层,得到对应的特征张量 proj_query、proj_key、proj_value;计算 proj_query 与 proj_key 的矩阵乘积得到能量矩阵 energy,将 energy 除以 key_channel 的平方根,再经 softmax 得到注意力矩阵 attention;最后将 proj_value 与 attention 的转置相乘得到输出张量 out,经插值调整大小后与输入相加,得到最终输出。
CDF0.py
import torch
import itertools
from .base_model import BaseModel
from . import backbone
import torch.nn.functional as F
from . import loss
class CDF0Model(BaseModel):
"""
change detection module:
feature extractor
contrastive loss
"""
@staticmethod
def modify_commandline_options(parser, is_train=True):
    """Add model-specific commandline options (none for this model).

    Parameters:
        parser          -- the original option parser
        is_train (bool) -- whether training or test mode; unused here

    Returns:
        the parser, unchanged.
    """
    return parser
def __init__(self, opt):
    """Initialize the CDF0 change-detection model.

    Parameters:
        opt -- option object; this method reads opt.istest, opt.f_c,
               opt.arch, opt.lr and opt.beta1 (presumably parsed
               commandline options — verify against the option parser)
    """
    BaseModel.__init__(self, opt)
    self.istest = opt.istest
    # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
    self.loss_names = ['f']
    # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
    self.visual_names = ['A', 'B', 'L', 'pred_L_show']  # visualizations for A and B
    if self.istest:
        # no ground-truth label 'L' is available at test time
        self.visual_names = ['A', 'B', 'pred_L_show']
    # feature maps exposed for visualization/inspection
    self.visual_features = ['feat_A', 'feat_B']
    # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
    if self.isTrain:
        self.model_names = ['F']
    else:  # during test time, only load Gs
        # NOTE(review): both branches load the same networks; the if/else is
        # kept to mirror the template this model was derived from
        self.model_names = ['F']
    self.ds=1
    # define networks (both Generators and discriminators)
    self.n_class = 2
    self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device)
    if self.isTrain:
        # define loss functions
        self.criterionF = loss.BCL()
        # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.netF.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
self.A = input['A'].to(self.device)
self.B = input['B'].to(self.device)
if not self.istest:
self.L = input['L'].to(self.device).long()
self.image_paths = input['A_paths']
if self.isTrain:
self.L_s = self.L.float()
self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape