构建一个简单的VGG网络

import torch
from torch.autograd import Variable
from torch import nn

def vgg_block(num_convs, input_channels, output_channels):
    net = [
        nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1),
        nn.ReLU()
    ]
    for i in range(num_convs - 1):
        net.append(nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1))
        net.append(nn.ReLU(True))
    net.append(nn.MaxPool2d(2,2))
    return nn.Sequential(*net)

def vgg_stack(numconvs, channels):
    net = []
    for n,c in zip(numconvs, channels):
        in_c = c[0]
        out_c = c[1]
        net.append(vgg_block(n, in_c, out_c))
    return  nn.Sequential(*net)

vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
print(vgg_net)
print(vgg_net[0])
test_x = Variable(torch.zeros(1, 3, 256, 256))
test_y = vgg_net(test_x)
print(test_y.shape)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值