构造任意层数的金字塔(pytorch)

code

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBnReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, dilation=1):
        super(ConvBnReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias, dilation=dilation)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class FeatureNet(nn.Module):
    def __init__(self, inner_channels):
        super(FeatureNet, self).__init__()

        self.convs, self.inners, self.outputs = [], [], []

        for level in range(len(inner_channels)):
            if level == 0:
                self.convs.append(
                        ConvBnReLU(3, inner_channels[level], 3, 1, 1),)
                self.convs.append(
                    ConvBnReLU(inner_channels[level], inner_channels[level], 3, 1, 1),)

            else:
                self.convs.append(ConvBnReLU(inner_channels[level-1], inner_channels[level], 5, 2, 2),)
                self.convs.append(ConvBnReLU(inner_channels[level], inner_channels[level], 3, 1, 1),)
                self.convs.append(ConvBnReLU(inner_channels[level], inner_channels[level], 3, 1, 1),)
                self.outputs.append(nn.Conv2d(inner_channels[-1], inner_channels[level], 1, bias=False))

                if level != len(inner_channels)-1:
                    self.inners.append(nn.Conv2d(inner_channels[level], inner_channels[-1], 1, bias=True))

        self.inners.reverse()
        self.outputs.reverse()

        self.convs = nn.Sequential(*self.convs)
        self.inners = nn.Sequential(*self.inners)
        self.outputs = nn.Sequential(*self.outputs)

    def forward(self, x) :

        convx, outx = [], []
        for i, layer in enumerate(self.convs.children()):
            x = layer(x)
            if i == 1 or i%3 ==1:
                convx.append(x)

        convx.reverse()

        inner = convx[0]
        for level, x in enumerate(convx):
            if level == 0:
                outx.append(self.outputs[level](x))
            elif level != len(convx)-1:
                inner = F.interpolate(inner, scale_factor=2.0, mode="bilinear", align_corners=False)\
                            + self.inners[level-1](x)
                outx.append(self.outputs[level](inner))
            else:
                outx.append(x)

        outx.reverse()

        return outx #conv1, f1, f2, f3


if __name__=="__main__":
    x = torch.randint(1,255,(1,3,128,160)).float()
    net1 =  FeatureNet(inner_channels=[8, 16, 32, 64])
    net2 =  FeatureNet(inner_channels=[8, 16, 32, 64, 128])

    # print(net)
    outs1 = net1(x)
    outs2 = net2(x)

    for y in outs1:
        print(y.shape)

    for y in outs2:
        print(y.shape)

result

torch.Size([1, 8, 128, 160])
torch.Size([1, 16, 64, 80])
torch.Size([1, 32, 32, 40])
torch.Size([1, 64, 16, 20])
torch.Size([1, 8, 128, 160])
torch.Size([1, 16, 64, 80])
torch.Size([1, 32, 32, 40])
torch.Size([1, 64, 16, 20])
torch.Size([1, 128, 8, 10])

注:
网络构建好后,测试__init__()中的层是否可以被print出来,如果不能,无法加载到gpu

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值