import torch
from torch.autograd import Variable
from torch import nn
def vgg_block(num_convs, input_channels, output_channels):
    """Build one VGG stage: ``num_convs`` 3x3 conv + ReLU pairs, then a 2x2 max-pool.

    The first convolution maps ``input_channels`` -> ``output_channels``; every
    later convolution maps ``output_channels`` -> ``output_channels``. With
    ``kernel_size=3`` and ``padding=1`` the convolutions keep the spatial size,
    and the closing ``MaxPool2d(2, 2)`` halves height and width.

    Args:
        num_convs: Number of convolution layers in the stage; must be >= 1.
        input_channels: Channel count of the tensor fed into the stage.
        output_channels: Channel count produced by every convolution in the stage.

    Returns:
        An ``nn.Sequential`` of ``2 * num_convs + 1`` modules
        (Conv2d, ReLU, ..., Conv2d, ReLU, MaxPool2d).

    Raises:
        ValueError: If ``num_convs`` is less than 1.
    """
    if num_convs < 1:
        # Without this check a non-positive count would still yield one conv,
        # because the first conv used to be added unconditionally.
        raise ValueError(f"num_convs must be >= 1, got {num_convs}")

    layers = []
    in_channels = input_channels
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels, output_channels, kernel_size=3, padding=1))
        # Use the same (in-place) activation everywhere; the original mixed
        # nn.ReLU() for the first conv with nn.ReLU(True) for the rest.
        layers.append(nn.ReLU(inplace=True))
        in_channels = output_channels  # every conv after the first keeps the width
    layers.append(nn.MaxPool2d(2, 2))
    return nn.Sequential(*layers)
def vgg_stack(numconvs, channels):
    """Chain several VGG stages (built by ``vgg_block``) into one ``nn.Sequential``.

    Args:
        numconvs: Per-stage convolution counts, one entry per stage.
        channels: Per-stage channel pairs aligned with ``numconvs``; element 0 of
            each pair is the stage's input channels, element 1 its output channels.

    Returns:
        An ``nn.Sequential`` whose i-th child is the i-th stage.

    Raises:
        ValueError: If ``numconvs`` and ``channels`` have different lengths.
    """
    # Materialize both so any iterable is accepted and lengths can be compared;
    # a bare zip() would silently drop the unmatched tail of the longer one.
    numconvs = tuple(numconvs)
    channels = tuple(channels)
    if len(numconvs) != len(channels):
        raise ValueError(
            f"numconvs has {len(numconvs)} entries but channels has {len(channels)}"
        )
    stages = [
        vgg_block(n, pair[0], pair[1])
        for n, pair in zip(numconvs, channels)
    ]
    return nn.Sequential(*stages)
# Demo: five stages with 1, 1, 2, 2, 2 convolutions and widths 64 -> 512,
# i.e. the convolutional part of the VGG-11 (configuration A) layout.
vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
print(vgg_net)
print(vgg_net[0])  # first stage only

# Shape check with a dummy (N, C, H, W) = (1, 3, 256, 256) input.
# torch.autograd.Variable is deprecated (tensors support autograd directly),
# so a plain tensor is used; no_grad() avoids building a graph for this check.
test_x = torch.zeros(1, 3, 256, 256)
with torch.no_grad():
    test_y = vgg_net(test_x)
# Each of the five stages halves H and W: 256 / 2**5 = 8,
# so this prints torch.Size([1, 512, 8, 8]).
print(test_y.shape)
# 构建一个简单的VGG网络 (Build a simple VGG network)
# 最新推荐文章于 2024-08-18 22:00:00 发布 (Latest recommended article published 2024-08-18 22:00:00)