import torch
import torch.nn as nn
import torchvision
class UNetFactory(nn.Module):
"""
本质上就是一个U型的网络,先encode,后decode,中间可能有架bridge。
其中encoder需要输出skip到decode那边做concatenate,使得decode阶段能补充信息。
bridge不能存在下采样和上采样的操作。
"""
def __init__(self, encoder_blocks, decoder_blocks, bridge=None):
super(UNetFactory, self).__init__()
self.encoder = UNetEncoder(encoder_blocks) # 返回List
self.bridge = bridge #
self.decoder = UNetDecoder(decoder_blocks)
def forward(self, x):
res = self.encoder(x)
out, skips = res[0], res[1:] #将encoder输出与skips分开
if self.bridge is not None:
out = self.bridge(out) #
out = self.decoder(out, skips)#
return out
class UNetEncoder(nn.Module):
"""
encoder会有多次下采样,下采样前的feature map要作为skip缓存起来将来送到decoder用。
这里约定,以下采样为界线,将encoder分成多个block,其中第一个block无下采样操作,后面的每个block内都
含有一次下采样操作。
"""
def __init__(self, blocks): #bolcks = encoder_block
super(UNetEncoder, self).__init__()
assert len(blocks) > 0 #len = 5 [0.1.2.3.4]
self.blocks = nn.ModuleList(blocks)
#assert/nn.ModuleList :
'''
断言函数是对表达式布尔值的判断,要求表达式计算值必须为真。可用于自动调试。
如果表达式为假,触发异常;如果表达式为真,不执行任何操作。
深度学习框架中的应用:断言深度学习模型的模块数不为0。
assert len(blocks) > 0
详解PyTorch中的ModuleList和Sequential: https://zhuanlan.zhihu.com/p/75206669
'''
def forward(self, x):
skips = []
for i in range(len(self.blocks) - 1): #range(4) [0.1.2.3]
x = self.blocks[i](x)
skips.append(x)
res = [self.blocks[i+1](x)] #block[4]_output, e:List
res += skips
return res # 只能以这种方式返回多个tensor
class UNetDecoder(nn.Module):
"""
decoder会有多次上采样,每次上采样后,要跟相应的skip做concatenate。
这里约定,以上采样为界线,将decoder分成多个block,其中最后一个block无上采样操作,其他block内
都含有一次上采样。如此一来,除第一个block以外,其他block都先做concatenate。
"""
def __init__(self, blocks):
super(UNetDecoder, self).__init__()
assert len
Unet解析
最新推荐文章于 2023-03-15 22:45:51 发布