定义
:将对象/实例封装在工厂里面,封装了对象的细节。
功能
:
工厂模式,顾名思义就是我们可以通过一个指定的“工厂”获得需要的“产品”,在设计模式中主要用于抽象对象的创建过程,让用户可以指定自己想要的对象而不必关心对象的实例化过程。这样做的好处是
用户只需通过固定的接口而不是直接去调用类的实例化方法来获得一个对象的实例,隐藏了实例创建过程的复杂度,解耦了生产实例和使用实例的代码,降低了维护的复杂性
。
示例代码:
Unet网络模型搭建
#UNetFactory:创建了一个工厂模式。
class UNetFactory(nn.Module):
"""
本质上就是一个U型的网络,先encode,后decode,中间可能有架bridge。
其中encoder需要输出skip到decode那边做concatenate,使得decode阶段能补充信息。
bridge不能存在下采样和上采样的操作。
"""
#初始化参数
def __init__(self, encoder_blocks, decoder_blocks, bridge=None):
#初始化继承模型的参数
super(UNetFactory, self).__init__()
#初始化子模型的参数
self.encoder = UNetEncoder(encoder_blocks)
self.bridge = bridge
self.decoder = UNetDecoder(decoder_blocks)
#构建前向传播函数,forward函数实现前向传播过程,其输入可以是一个或者多个variable
def forward(self, x):
res = self.encoder(x)
#res[0]:从res列表中取出第一个元素 res[1:]:从列表中从1开始取出剩余的元素
out, skips = res[0], res[1:]
if self.bridge is not None:
out = self.bridge(out)
out = self.decoder(out, skips)
return out
class UNetEncoder(nn.Module):
"""
encoder会有多次下采样,下采样前的feature map要作为skip缓存起来将来送到decoder用。
这里约定,以下采样为界线,将encoder分成多个block,其中第一个block无下采样操作,后面的每个block内都
含有一次下采样操作。
"""
def __init__(self, blocks):
super(UNetEncoder, self).__init__()
assert len(blocks) > 0
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
skips = []
for i in range(len(self.blocks) - 1):
x = self.blocks[i](x)
skips.append(x)
res = [self.blocks[i+1](x)]
res += skips
return res # 只能以这种方式返回多个tensor
class UNetDecoder(nn.Module):
"""
decoder会有多次上采样,每次上采样后,要跟相应的skip做concatenate。
这里约定,以上采样为界线,将decoder分成多个block,其中最后一个block无上采样操作,其他block内
都含有一次上采样。如此一来,除第一个block以外,其他block都先做concatenate。
"""
def __init__(self, blocks):
super(UNetDecoder, self).__init__()
assert len(blocks) > 1
self.blocks = nn.ModuleList(blocks)
def _center_crop(self, skip, x):
"""
skip和x的关于h和w的size,谁比较大,就裁剪谁
"""
_, _, h1, w1 = skip.shape
_, _, h2, w2 = x.shape
ht, wt = min(h1, h2), min(w1, w2)
dh1 = (h1 - ht) // 2 if h1 > ht else 0
dw1 = (w1 - wt) // 2 if w1 > wt else 0
dh2 = (h2 - ht) // 2 if h2 > ht else 0
dw2 = (w2 - wt) // 2 if w2 > wt else 0
#返回经过裁剪以后的图像
return skip[:, :, dh1: (dh1 + ht), dw1: (dw1 + wt)], \
x[:, :, dh2: (dh2 + ht), dw2: (dw2 + wt)]
def forward(self, x, skips, reverse_skips=True):
assert len(skips) == len(self.blocks) - 1
if reverse_skips:
#
skips = skips[::-1]
x = self.blocks[0](x)
for i in range(1, len(self.blocks)):
skip, x = self._center_crop(skips[i-1], x)
x = torch.cat([skip, x], dim=1)
x = self.blocks[i](x)
return x
def unet_convs(in_channels, out_channels, padding=0):
"""
unet论文里出现次数最多的2个conv3x3(non-padding)的结构
"""
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def unet(in_channels, out_channels):
"""
构造跟论文一致的unet网络
https://arxiv.org/abs/1505.04597
"""
# encoder
encoder_blocks = [
# two conv3x3
unet_convs(in_channels, 64),
# max pool 2x2, two conv3x3
nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
unet_convs(64, 128)
),
# max pool 2x2, two conv3x3
nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
unet_convs(128, 256)
),
# max pool 2x2, two conv3x3
nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
unet_convs(256, 512)
),
# max pool 2x2
nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
]
# bridge
bridge = nn.Sequential(
# two conv3x3
unet_convs(512, 1024)
)
# decoder
decoder_blocks = [
# up-conv2x2
nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
# two conv3x3, up-conv2x2
nn.Sequential(
unet_convs(1024, 512),
nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
),
# two conv3x3, up-conv2x2
nn.Sequential(
unet_convs(512, 256),
nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
),
# two conv3x3, up-conv2x2
nn.Sequential(
unet_convs(256, 128),
nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
),
# two conv3x3, conv1x1
nn.Sequential(
unet_convs(128, 64),
nn.Conv2d(64, out_channels, kernel_size=1)
)
]
#直接引用该工厂模式
return UNetFactory(encoder_blocks, decoder_blocks, bridge)
参考链接:
https://www.cnblogs.com/ppap/p/11103324.html