编解码(encoder-decoder)结构与 UNet 等网络会对图像进行多次下采样,因此通常要求输入图像的高和宽是 2 的若干次幂(例如 8 或 32)的整数倍。
真实图像的尺寸往往不满足这一要求,因此可以使用 torch 自带的 pad 函数:在送入模型之前,用值为 0 的元素在图像的右侧和下方进行填充,使其形状符合要求;在输出之前,再通过切片索引裁剪回与原始输入相同的大小。
pad(input, pad, mode='constant', value=0)
from torch import nn
import torch.nn.functional as f
class pretransition(nn.Module):
    """Small encoder-decoder that accepts inputs of arbitrary height/width.

    The input is zero-padded on the bottom and right so that its height and
    width become multiples of ``block_size``, passed through three stride-2
    convolutions followed by three stride-2 transposed convolutions, and the
    result is cropped back to the original spatial size.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Number of channels of the output tensor.
        mid_channel: Number of channels used by the hidden layers.
        block_size: Height and width are padded up to a multiple of this
            value. Because the encoder halves the resolution three times and
            the decoder doubles it three times, this should be a multiple of
            8 for the output to match the padded input size.
    """

    def __init__(self, in_channel=3, out_channel=3, mid_channel=64, block_size=32):
        super().__init__()
        self.block_size = block_size
        # Encoder: three 3x3 convolutions with stride 2 and padding 1.
        self.con1 = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, 3, 2, 1),
            nn.Conv2d(mid_channel, mid_channel, 3, 2, 1),
            nn.Conv2d(mid_channel, mid_channel, 3, 2, 1),
        )
        # Decoder: three 3x3 transposed convolutions with stride 2.
        self.con2 = nn.Sequential(
            nn.ConvTranspose2d(mid_channel, mid_channel, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(mid_channel, mid_channel, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(mid_channel, out_channel, kernel_size=3, stride=2, padding=1, output_padding=1),
        )

    def unpad_tensor(self, x, pad_width):
        """Crop away the bottom/right padding added by ``pad_tensor``.

        Args:
            x: 4-D tensor indexed as ``x[batch, channel, row, column]``.
            pad_width: Padding tuple as returned by ``pad_tensor``; index 1
                holds the right padding and index 3 the bottom padding.

        Returns:
            ``x`` itself when no padding was applied, otherwise a slice of
            ``x`` with the padded rows and columns removed.
        """
        pad_bottom = pad_width[3]
        pad_right = pad_width[1]
        if pad_bottom == 0 and pad_right == 0:
            return x
        # Explicit end indices avoid the ``0:-0`` pitfall (an empty slice)
        # when only one of the two dimensions was padded.
        row_end = x.size(2) - pad_bottom
        col_end = x.size(3) - pad_right
        return x[:, :, :row_end, :col_end]

    def pad_tensor(self, x, block_size=32):
        """Zero-pad ``x`` so its height and width are multiples of ``block_size``.

        Args:
            x: 4-D tensor whose last two dimensions are height and width.
            block_size: Positive integer the padded height and width must be
                divisible by.

        Returns:
            A ``(padded, padding)`` pair, where ``padding`` is the 8-element
            tuple that was passed to ``torch.nn.functional.pad``.

        Raises:
            ValueError: If ``block_size`` is not positive.
        """
        if block_size <= 0:
            raise ValueError("block_size must be a positive integer, got {}".format(block_size))
        _, _, h, w = x.size()
        # Distance from each size to the next multiple of block_size
        # (0 when the size is already aligned).
        pad_bottom = (-h) % block_size
        pad_right = (-w) % block_size
        # F.pad reads (before, after) pairs starting from the LAST dimension:
        # width, height, channel, batch.
        padding = (
            0, pad_right,
            0, pad_bottom,
            0, 0,
            0, 0,
        )
        x = f.pad(x, padding)
        return x, padding

    def forward(self, x):
        """Pad ``x``, run the encoder and decoder, then crop to the input size."""
        x, padding = self.pad_tensor(x, block_size=self.block_size)
        x1 = self.con1(x)
        x2 = self.con2(x1)
        x2 = self.unpad_tensor(x2, padding)
        return x2
def get_modelsize(model, input):
    """Print the FLOP count and parameter count that thop reports for ``model``.

    Args:
        model: Module to profile.
        input: Single input passed to ``model`` during profiling.
    """
    # thop is imported lazily so the module can be imported without it.
    from thop import clever_format, profile

    raw_flops, raw_params = profile(model, inputs=(input,))
    # clever_format turns the raw numbers into formatted strings.
    flops_text, params_text = clever_format([raw_flops, raw_params], "%.3f")
    summary = 'Total flops:' + flops_text + '\t' + 'Total params:' + params_text
    print(summary)
def model_print():
    """Smoke-test ``pretransition`` on a non-aligned input and print its stats."""
    import torch

    net = pretransition(block_size=256)
    print(net)

    # 253 x 541 is deliberately not a multiple of the block size.
    sample = torch.ones(1, 3, 253, 541)
    result = net(sample)
    print(f'input: {sample.shape}')
    print(f'output: {result.shape}')

    get_modelsize(net, sample)
# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    model_print()