"""
该代码实现一个带有跳跃连接的编码器和解码器神经网络结构
1:定义一个函数skip,接受一系列参数来构建网络
2:通过assert 确保参数num_channels_down、num_channels_up、num_channels_skip的长度相同,否则引发异常
注:这三个参数用于指定通道的参数.
num_channels_down代表下采样路径中每个卷积层的输出通道数,num_channels_down=[16, 32, 64, 128, 128]表示第一层下采样路径输出16个通道,第二层输出32个通道,以此类推,最后一层输出128个通道。
num_channels_up代表上采样路径中每个卷积层的输出通道数
num_channels_skip代表跳跃连接部分的通道数
在神经网络中,通道(Channel)是指特征图的维度或深度。
在卷积神经网络中,每个卷积层会输出一定数量的特征图,每个特征图对应一个通道。这些通道可以被看作是网络学习到的不同特征的表示。例如,第一个卷积层可能输出16个特征图,每个特征图对应一个通道,代表了16种不同的低级特征。
通道的数量决定了网络能够捕捉和表示的特征种类和复杂性。增加通道数可以提高网络的表达能力,但也会增加模型的参数量和计算复杂度。
3:使用nn.Sequential()创建一个空的神经网络模型model
4:循环创建网络结构,每次循环包含一个下采样路径和一个上采样路径。
·在下采样路径中,使用卷积层进行特征提取,并添加批归一化和激活函数。
·在上采样路径中,使用上采样模块对特征进行上采样,然后与跳跃连接部分的特征进行拼接,再进行卷积、批归一化和激活函数操作。
5:根据是否需要使用1×1卷积层进行上采样,决定是否添加该层。
6:最后使用卷积层进行输出,并可选择添加Sigmoid函数进行归一化处理。
返回构建好的模型。
总结:该网络结构可以用于图像重建和去噪等任务,通过跳跃连接可以提高网络的性能和准确性。
"""
import torch
import torch.nn as nn
from .common import *
def skip(
num_input_channels=2, num_output_channels=3,
num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4],
filter_size_down=3, filter_size_up=3, filter_skip_size=1,
need_sigmoid=True, need_bias=True,
pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU',
need1x1_up=True):
"""Assembles encoder-decoder with skip connections.
Arguments:
act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU)
pad (string): zero|reflection (default: 'zero')
upsample_mode (string): 'nearest|bilinear' (default: 'nearest')
downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride')
"""
# 作用:为了确保网络的不同层次中特征图的通道数都被正确地设置和匹配
assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip)
# 获取网络的总层数
n_scales = len(num_channels_down)
if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) :
upsample_mode = [upsample_mode]*n_scales
if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)):
downsample_mode = [downsample_mode]*n_scales
if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) :
filter_size_down = [filter_size_down]*n_scales
if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) :
filter_size_up = [filter_size_up]*n_scales
# 因为python中的列表的索引从0开始 因此最后一层的索引应该为总层数减一:n_scales - 1 在模型训练和推理时,这个索引用于选择最后一层的特征图作为网络输出
last_scale = n_scales - 1
cur_depth = None
# 利用pytorch的nn.Sequential()来构建一个层次化的神经网络
model = nn.Sequential()
# 创建一个指向 model 的临时模型model_tmp 用于添加子模块
model_tmp = model
input_depth = num_input_channels
# 根据num_channels_down层数,循环创建神经网络的层次结构
for i in range(len(num_channels_down)):
deeper = nn.Sequential()
skip = nn.Sequential()
# skip 和 deeper 是两个张量,它们在维度1上具有相同的大小,使用 Concat(1, skip, deeper) 将它们在维度1上进行拼接,得到一个新的张量。
# 例如,如果 skip 的形状是 (batch_size, num_channels_skip[i], height, width)
# deeper 的形状是 (batch_size, num_channels_down[i], height, width)
# 那么 Concat(1, skip, deeper) 的结果形状将是 (batch_size, num_channels_skip[i] + num_channels_down[i], height, width)
# 其中,num_channels_skip[i] + num_channels_down[i] 表示沿着维度1上的通道数之和。
if num_channels_skip[i] != 0:
model_tmp.add(Concat(1, skip, deeper))
else:
model_tmp.add(deeper)
# 是将一个批归一化层添加到指向 model 的临时模型model_tmp中
# num_channels_skip[i]表示跳跃连接部分skip通道数、num_channels_up[i + 1] 如果 i 不是最后一层,则表示上采样路径中当前层的输出通道数、num_channels_down[i] 如果 i 是最后一层,则表示下采样路径中当前层的输出通道数。
# 下采样是指从输入图像中提取出一部分信息来,减少图像中的数据量,从而得到一个更小的图像 。 上采样则是指在图像中添加新信息使其变大,从而得到一个更大的图像。
model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i])))
if num_channels_skip[i] != 0:
# 创建一个卷积层,input_depth 是输入通道数,num_channels_skip[i] 是跳跃连接层的输出通道数,filter_skip_size 是卷积核的大小,bias 和 pad 分别表示是否添加偏置项和是否进行填充
# 然后,将卷积层连接到跳跃连接模块中,按照顺序添加批归一化层和激活函数。
skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad))
# 通过 skip.add(bn(num_channels_skip[i])) 和 skip.add(act(act_fun)) 分别添加批归一化层和激活函数
# bn(num_channels_skip[i]) 根据前面提到的 bn() 函数创建一个具有 num_channels_skip[i] 输出通道数的批归一化模块。
# 意味着在卷积层的输出结果上依次应用批归一化和激活函数。
# 在每个卷积层的输出上应用批归一化,可以通过对每个通道维度的特征进行标准化和调整来规范化网络的输出。这有助于提高网络的表达能力和鲁棒性。
# 激活函数则用于引入非线性性质,它对卷积层输出的特征进行非线性映射,以增加网络的表达能力。常见的激活函数有ReLU、sigmoid、tanh等。通过将卷积层的输出结果输入到激活函数模块列表中,可以对特征进行非线性变换,从而更好地捕捉数据中的复杂特征。
# 因此,将卷积层连接到批归一化和激活函数模块列表,是为了在深度卷积神经网络中引入归一化和非线性变换,提高网络性能并增强特征的表达能力。
skip.add(bn(num_channels_skip[i]))
skip.add(act(act_fun))
# skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part))
# 通过 conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i]) 创建一个卷积层,并将其加入到 deeper 中
# 卷积层的输入通道数为 input_depth,输出通道数为 num_channels_down[i],卷积核大小为 filter_size_down[i],步长为2,表示对输入图像进行降采样。bias 和 pad 分别表示是否添加偏置项和是否进行填充,downsample_mode[i] 表示降采样的方式
# 可以选择 "stride"(使用步长进行降采样)或 "avg"(使用平均池化进行降采样)。
deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i]))
# 依次添加批归一化层和激活函数
deeper.add(bn(num_channels_down[i]))
deeper.add(act(act_fun))
# 通过 conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad) 创建另一个卷积层,并将其加入到 deeper 中。
# 这个卷积层的输入和输出通道数都为 num_channels_down[i],卷积核大小为 filter_size_down[i],bias 和 pad 分别表示是否添加偏置项和是否进行填充。
deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad))
deeper.add(bn(num_channels_down[i]))
deeper.add(act(act_fun))
# 通过创建一个不包含任何层的顺序容器deeper_main 来保留之前构建好的deeper容器
deeper_main = nn.Sequential()
# 网络构建的后续部分
# 网络的后续部分进行特征上采样操作,并将上采样后的特征与之前的跳跃连接结果进行拼接。通过添加卷积、批归一化和激活函数等操作,可以进一步提取特征并增加网络的表达能力。
# 判断 i 是否等于 num_channels_down 列表的长度减1,如果成立,则表示当前层为最深层
# 如果成立,则表示当前层为最深层。在最深层时,将 k 的值设为 num_channels_down[i]。
# 如果不是最深层,则将之前构建好的 deeper_main 容器添加到 deeper 容器中,并将 k 的值设为 num_channels_up[i + 1]。这样做是为了将 deeper_main 和当前层之间进行连接。
if i == len(num_channels_down) - 1:
# The deepest
k = num_channels_down[i]
else:
deeper.add(deeper_main)
k = num_channels_up[i + 1]
# 通过 nn.Upsample 模块对输入进行上采样,将其尺寸放大2倍。
# scale_factor=2 表示上采样的尺度因子,mode=upsample_mode[i] 表示上采样的方式,可以选择 "nearest"(最近邻插值)或 "bilinear"(双线性插值)。
deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i]))
# 通过 conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad) 创建一个卷积层
# 并将其加入到 model_tmp 中。这个卷积层的输入通道数为 num_channels_skip[i] + k,输出通道数为 num_channels_up[i],卷积核大小为 filter_size_up[i],步长为1,表示对输入特征进行卷积操作。
# bias 和 pad 分别表示是否添加偏置项和是否进行填充。
model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad))
model_tmp.add(bn(num_channels_up[i]))
model_tmp.add(act(act_fun))
# 首先,判断是否需要使用 1×1 的卷积层进行上采样,如果需要则创建一个 conv 卷积层,并将其加入到 model_tmp 容器中。
# 这个卷积层的输入和输出通道数都为 num_channels_up[i],卷积核大小为 1,表示对输入特征进行卷积操作
if need1x1_up:
model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad))
model_tmp.add(bn(num_channels_up[i]))
model_tmp.add(act(act_fun))
# 重新设置当前输入图像的通道数为当前层的输出通道数,即 input_depth = num_channels_down[i],以便在下一层使用。
# 通过 conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad) 创建一个卷积层,并将其加入到 model 容器中。
# 这个卷积层的输入通道数为 num_channels_up[0](即最后一层的输出通道数),输出通道数为 num_output_channels,卷积核大小为 1,表示对输入特征进行卷积操作。
# 如果需要,可以在卷积层之后添加一个 nn.Sigmoid() 模块进行输出的归一化处理
input_depth = num_channels_down[i]
model_tmp = deeper_main
model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad))
if need_sigmoid:
model.add(nn.Sigmoid())
# 最终,通过添加卷积、批归一化、激活函数和上采样等操作,构建了一个深度卷积神经网络,用于对输入图像进行特征提取和特征重建。
return model
# common
import torch
import torch.nn as nn
import numpy as np
from .downsampler import Downsampler
def add_module(self, module):
self.add_module(str(len(self) + 1), module)
torch.nn.Module.add = add_module
class Concat(nn.Module):
def __init__(self, dim, *args):
super(Concat, self).__init__()
self.dim = dim
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def forward(self, input):
inputs = []
for module in self._modules.values():
inputs.append(module(input))
inputs_shapes2 = [x.shape[2] for x in inputs]
inputs_shapes3 = [x.shape[3] for x in inputs]
if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):
inputs_ = inputs
else:
target_shape2 = min(inputs_shapes2)
target_shape3 = min(inputs_shapes3)
inputs_ = []
for inp in inputs:
diff2 = (inp.size(2) - target_shape2) // 2
diff3 = (inp.size(3) - target_shape3) // 2
inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3])
return torch.cat(inputs_, dim=self.dim)
def __len__(self):
return len(self._modules)
class GenNoise(nn.Module):
def __init__(self, dim2):
super(GenNoise, self).__init__()
self.dim2 = dim2
def forward(self, input):
a = list(input.size())
a[1] = self.dim2
# print (input.data.type())
b = torch.zeros(a).type_as(input.data)
b.normal_()
x = torch.autograd.Variable(b)
return x
class Swish(nn.Module):
"""
https://arxiv.org/abs/1710.05941
The hype was so huge that I could not help but try it
"""
def __init__(self):
super(Swish, self).__init__()
self.s = nn.Sigmoid()
def forward(self, x):
return x * self.s(x)
def act(act_fun = 'LeakyReLU'):
'''
Either string defining an activation function or module (e.g. nn.ReLU)
'''
if isinstance(act_fun, str):
if act_fun == 'LeakyReLU':
return nn.LeakyReLU(0.2, inplace=True)
elif act_fun == 'Swish':
return Swish()
elif act_fun == 'ELU':
return nn.ELU()
elif act_fun == 'none':
return nn.Sequential()
else:
assert False
else:
return act_fun()
# nn.BatchNorm2d(num_features) 是一个 PyTorch 中的内置函数。
# 用于构建一个二维卷积层的批归一化(Batch Normalization)模块
# 它将输入数据在通道维度上进行标准化,使得每个通道上的均值为0,方差为1,并对标准化后的数据进行缩放和平移,以便网络能够更快地收敛。
# num_features 参数表示输入数据的通道数 批归一化操作是在通道维度上进行的
def bn(num_features):
return nn.BatchNorm2d(num_features)
def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'):
downsampler = None
if stride != 1 and downsample_mode != 'stride':
if downsample_mode == 'avg':
downsampler = nn.AvgPool2d(stride, stride)
elif downsample_mode == 'max':
downsampler = nn.MaxPool2d(stride, stride)
elif downsample_mode in ['lanczos2', 'lanczos3']:
downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True)
else:
assert False
stride = 1
padder = None
to_pad = int((kernel_size - 1) / 2)
if pad == 'reflection':
padder = nn.ReflectionPad2d(to_pad)
to_pad = 0
convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)
layers = filter(lambda x: x is not None, [padder, convolver, downsampler])
return nn.Sequential(*layers)
import numpy as np
import torch
import torch.nn as nn
class Downsampler(nn.Module):
'''
http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf
'''
def __init__(self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False):
super(Downsampler, self).__init__()
assert phase in [0, 0.5], 'phase should be 0 or 0.5'
if kernel_type == 'lanczos2':
support = 2
kernel_width = 4 * factor + 1
kernel_type_ = 'lanczos'
elif kernel_type == 'lanczos3':
support = 3
kernel_width = 6 * factor + 1
kernel_type_ = 'lanczos'
elif kernel_type == 'gauss12':
kernel_width = 7
sigma = 1/2
kernel_type_ = 'gauss'
elif kernel_type == 'gauss1sq2':
kernel_width = 9
sigma = 1./np.sqrt(2)
kernel_type_ = 'gauss'
elif kernel_type in ['lanczos', 'gauss', 'box']:
kernel_type_ = kernel_type
else:
assert False, 'wrong name kernel'
# note that `kernel width` will be different to actual size for phase = 1/2
self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma)
downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0)
downsampler.weight.data[:] = 0
downsampler.bias.data[:] = 0
kernel_torch = torch.from_numpy(self.kernel)
for i in range(n_planes):
downsampler.weight.data[i, i] = kernel_torch
self.downsampler_ = downsampler
if preserve_size:
if self.kernel.shape[0] % 2 == 1:
pad = int((self.kernel.shape[0] - 1) / 2.)
else:
pad = int((self.kernel.shape[0] - factor) / 2.)
self.padding = nn.ReplicationPad2d(pad)
self.preserve_size = preserve_size
def forward(self, input):
if self.preserve_size:
x = self.padding(input)
else:
x= input
self.x = x
return self.downsampler_(x)
def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None):
assert kernel_type in ['lanczos', 'gauss', 'box']
# factor = float(factor)
if phase == 0.5 and kernel_type != 'box':
kernel = np.zeros([kernel_width - 1, kernel_width - 1])
else:
kernel = np.zeros([kernel_width, kernel_width])
if kernel_type == 'box':
assert phase == 0.5, 'Box filter is always half-phased'
kernel[:] = 1./(kernel_width * kernel_width)
elif kernel_type == 'gauss':
assert sigma, 'sigma is not specified'
assert phase != 0.5, 'phase 1/2 for gauss not implemented'
center = (kernel_width + 1.)/2.
print(center, kernel_width)
sigma_sq = sigma * sigma
for i in range(1, kernel.shape[0] + 1):
for j in range(1, kernel.shape[1] + 1):
di = (i - center)/2.
dj = (j - center)/2.
kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj)/(2 * sigma_sq))
kernel[i - 1][j - 1] = kernel[i - 1][j - 1]/(2. * np.pi * sigma_sq)
elif kernel_type == 'lanczos':
assert support, 'support is not specified'
center = (kernel_width + 1) / 2.
for i in range(1, kernel.shape[0] + 1):
for j in range(1, kernel.shape[1] + 1):
if phase == 0.5:
di = abs(i + 0.5 - center) / factor
dj = abs(j + 0.5 - center) / factor
else:
di = abs(i - center) / factor
dj = abs(j - center) / factor
pi_sq = np.pi * np.pi
val = 1
if di != 0:
val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support)
val = val / (np.pi * np.pi * di * di)
if dj != 0:
val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support)
val = val / (np.pi * np.pi * dj * dj)
kernel[i - 1][j - 1] = val
else:
assert False, 'wrong method name'
kernel /= kernel.sum()
return kernel
#a = Downsampler(n_planes=3, factor=2, kernel_type='lanczos2', phase='1', preserve_size=True)
#################
# downsampler
# Learnable downsampler
# KS = 32
# dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor))
# class Apply(nn.Module):
# def __init__(self, what, dim, *args):
# super(Apply, self).__init__()
# self.dim = dim
# self.what = what
# def forward(self, input):
# inputs = []
# for i in range(input.size(self.dim)):
# inputs.append(self.what(input.narrow(self.dim, i, 1)))
# return torch.cat(inputs, dim=self.dim)
# def __len__(self):
# return len(self._modules)
# downs = Apply(dow, 1)
# downs.type(dtype)(net_input.type(dtype)).size()