pytorch模型输入自动调整大小-pad

在网络中经常使用编码-解码(encoder-decoder)结构与UNet网络,这类网络要求输入图像的高和宽是2^n(n为下采样次数)的整数倍。
由于真实图像的尺寸不一定满足该要求,因此可以使用torch自带的pad函数,在模型输入前用值为0的元素把图像填充到合适的形状,并在输出前通过索引(切片)裁剪回与原始输入相同的大小。
pad(input, pad, mode='constant', value=0)

from torch import nn
import torch.nn.functional as f


class pretransition(nn.Module):
    """Small encoder/decoder that accepts inputs of arbitrary height and width.

    The input is zero-padded on the bottom and right so that its height and
    width become multiples of ``block_size``, passed through three stride-2
    convolutions and three stride-2 transposed convolutions, and finally
    cropped back to the original spatial size.

    Each ``Conv2d(kernel=3, stride=2, padding=1)`` maps a size ``n`` to
    ``floor((n - 1) / 2) + 1`` and each transposed convolution doubles the
    size, so the padded size is only restored exactly when it is a multiple
    of 8. ``block_size`` should therefore be a multiple of 8.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Number of channels of the output tensor.
        mid_channel: Number of channels used by the hidden layers.
        block_size: Height and width are padded up to a multiple of this.
    """

    def __init__(self, in_channel=3, out_channel=3, mid_channel=64, block_size=32):
        super().__init__()

        self.block_size = block_size

        # Encoder: three stride-2 convolutions (8x spatial down-sampling).
        self.con1 = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, 3, 2, 1),
            nn.Conv2d(mid_channel, mid_channel, 3, 2, 1),
            nn.Conv2d(mid_channel, mid_channel, 3, 2, 1),
        )

        # Decoder: three stride-2 transposed convolutions (8x up-sampling).
        self.con2 = nn.Sequential(
            nn.ConvTranspose2d(mid_channel, mid_channel, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(mid_channel, mid_channel, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(mid_channel, out_channel, kernel_size=3, stride=2, padding=1, output_padding=1),
        )

    def unpad_tensor(self, x, pad_width):
        """Crop away the bottom/right padding added by :meth:`pad_tensor`.

        Args:
            x: 4-D tensor laid out as ``(batch, channel, height, width)``.
            pad_width: Padding tuple as returned by :meth:`pad_tensor`;
                index 1 is the right padding, index 3 the bottom padding.

        Returns:
            ``x`` with the last ``pad_width[3]`` rows and the last
            ``pad_width[1]`` columns removed.
        """
        pad_w = pad_width[1]
        pad_h = pad_width[3]

        # A stop of ``None`` keeps the whole axis; a literal ``-0`` stop
        # would produce an empty slice instead.
        stop_h = -pad_h if pad_h else None
        stop_w = -pad_w if pad_w else None
        return x[:, :, :stop_h, :stop_w]

    def pad_tensor(self, x, block_size=32):
        """Zero-pad ``x`` so its height and width are multiples of ``block_size``.

        Padding is appended on the bottom and right only, so the original
        content stays at the top-left corner of the result.

        Args:
            x: 4-D tensor laid out as ``(batch, channel, height, width)``.
            block_size: Positive integer the spatial sizes are rounded up to.

        Returns:
            A ``(padded, padding)`` pair, where ``padding`` is the 8-tuple
            passed to ``torch.nn.functional.pad`` (ordered from the last
            dimension backwards: width, height, channel, batch).

        Raises:
            ValueError: If ``block_size`` is smaller than 1.
        """
        if block_size < 1:
            raise ValueError('block_size must be a positive integer, got {}'.format(block_size))

        _, _, h, w = x.size()

        # ``-n % k`` is the distance from n up to the next multiple of k,
        # and is 0 when n is already a multiple of k.
        pad_h = -h % block_size
        pad_w = -w % block_size

        padding = (
            0, pad_w,  # width: nothing on the left, pad_w columns on the right
            0, pad_h,  # height: nothing on the top, pad_h rows on the bottom
            0, 0,      # channel dimension is left untouched
            0, 0,      # batch dimension is left untouched
        )
        x = f.pad(x, padding)
        return x, padding

    def forward(self, x):
        """Pad ``x``, run the encoder and decoder, then crop the padding off."""
        x, padding = self.pad_tensor(x, block_size=self.block_size)
        x1 = self.con1(x)
        x2 = self.con2(x1)
        x2 = self.unpad_tensor(x2, padding)

        return x2


def get_modelsize(model, input):
    """Print the operation count and parameter count of ``model``.

    Both figures are measured by ``thop.profile`` for a single forward pass
    on ``input`` and formatted with ``thop.clever_format`` before printing.

    Args:
        model: The module to profile.
        input: Example input that is forwarded through ``model``.
    """
    from thop import clever_format, profile

    raw_flops, raw_params = profile(model, inputs=(input,))
    flops_text, params_text = clever_format([raw_flops, raw_params], "%.3f")

    report = ''.join(['Total flops:', flops_text, '\t', 'Total params:', params_text])
    print(report)


def model_print():
    """Smoke-test ``pretransition`` on an input whose size is not block-aligned.

    Builds the network with ``block_size=256``, prints its structure, runs a
    ``1 x 3 x 253 x 541`` tensor of ones through it, prints the input and
    output shapes, and finally reports the model size via ``get_modelsize``.
    """
    import torch

    net = pretransition(block_size=256)
    print(net)

    sample = torch.ones(1, 3, 253, 541)
    result = net(sample)

    input_line = 'input: {}'.format(sample.shape)
    print(input_line)

    output_line = 'output: {}'.format(result.shape)
    print(output_line)

    get_modelsize(net, sample)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    model_print()

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值