# coding:utf8
import torch
from torch import nn
# Every VGG variant shares the same five stage widths; the variants differ
# only in how many 3x3 convolutions each stage stacks.
_STAGE_WIDTHS = (64, 128, 256, 512, 512)
_CONVS_PER_STAGE = {
    11: (1, 1, 2, 2, 2),
    13: (2, 2, 2, 2, 2),
    16: (2, 2, 3, 3, 3),
    19: (2, 2, 4, 4, 4),
}

# Maps network depth -> list of (num_conv, out_channels) pairs, one per stage.
vgg = {
    depth: list(zip(conv_counts, _STAGE_WIDTHS))
    for depth, conv_counts in _CONVS_PER_STAGE.items()
}
def block(num_conv, in_channels, out_channels):
    """Build one VGG stage: ``num_conv`` conv+ReLU pairs, then a 2x2 max-pool.

    Args:
        num_conv: Number of 3x3 convolution layers in the stage.
        in_channels: Channel count of the tensor entering the stage.
        out_channels: Channel count produced by every convolution in the stage.

    Returns:
        An ``nn.Sequential`` holding ``2 * num_conv + 1`` modules. The
        convolutions use stride 1 and padding 1, so only the trailing
        ``MaxPool2d(kernel_size=2, stride=2)`` changes the spatial size.
    """
    layers = []
    for _ in range(num_conv):
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1))
        layers.append(nn.ReLU(inplace=True))
        # Only the first convolution sees the stage's input width; every
        # following one consumes the previous convolution's output.
        in_channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)
def vgg_block(num_layer, *, in_channels=3, num_classes=10, input_size=256):
    """Assemble a complete VGG classifier of the requested depth.

    Args:
        num_layer: Network depth; must be a key of the module-level ``vgg``
            table (11, 13, 16 or 19).
        in_channels: Channels of the input image. Defaults to 3.
        num_classes: Width of the final linear layer. Defaults to 10.
        input_size: Spatial size of the input, either one int for square
            images or a ``(height, width)`` pair. Defaults to 256. It is used
            only to size the first fully connected layer.

    Returns:
        An ``nn.Sequential`` containing one child per convolutional stage,
        then ``Flatten`` and a three-layer fully connected head with dropout.

    Raises:
        KeyError: If ``num_layer`` is not a supported depth.
        ValueError: If ``input_size`` is too small to survive the pooling.
    """
    if num_layer not in vgg:
        raise KeyError(
            f"unsupported VGG depth {num_layer!r}; choose one of {sorted(vgg)}"
        )
    arch = vgg[num_layer]

    if isinstance(input_size, int):
        height, width = input_size, input_size
    else:
        height, width = input_size

    stages = []
    for num_conv, out_channels in arch:
        stages.append(block(num_conv, in_channels, out_channels))
        in_channels = out_channels
        # Each stage ends in a stride-2 max-pool, which floors odd sizes.
        height //= 2
        width //= 2

    if height < 1 or width < 1:
        raise ValueError(
            f"input_size {input_size!r} is too small for {len(arch)} pooling stages"
        )

    # After the loop `in_channels` holds the width of the last stage, so the
    # flattened feature vector has in_channels * height * width elements
    # (512 * 8 * 8 for the default 256x256 input).
    return nn.Sequential(
        *stages,
        nn.Flatten(),
        nn.Linear(in_channels * height * width, 4096),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(4096, num_classes),
    )
if __name__ == '__main__':
    # Smoke test: feed a random batch through VGG-19 one top-level child at a
    # time and print the output shape after each child.
    x = torch.randn(4, 3, 256, 256)
    model = vgg_block(19)
    # Only shapes are inspected, so skip building the autograd graph.
    with torch.no_grad():
        for name, layer in model.named_children():
            x = layer(x)
            print(name, x.shape)
# VGG network construction (11, 13, 16 and 19 layers)
# Latest recommended article published 2024-03-05 18:44:43