import torch.nn as nn
# Build a stack of convolution / max-pooling layers from a VGG-style
# configuration table and print the resulting nn.Sequential.
#
# Each entry of a configuration list is either:
#   * an int  -> append a 3x3 nn.Conv2d with that many output channels, or
#   * 'M'     -> append a 2x2, stride-2 nn.MaxPool2d.
#
# NOTE: every configuration in `cfgs` is appended to the same list `a`, and
# `in_channels` carries over from one configuration to the next. With the
# single 'vgg11' entry this is harmless; if more entries are added, `a` and
# `in_channels` need to be reset per key.
a = []
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
}
# Channel count of the input fed to the first convolution
# (presumably 3-channel RGB images -- confirm against the data pipeline).
in_channels = 3
for key, value in cfgs.items():
    print("\nkey:" + key)
    print("value" + str(value))
    for v in value:
        if v == "M":
            a.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            # `padding` (an int) sets the amount of padding; `padding_mode`
            # expects a string such as 'zeros'. The script previously passed
            # `padding_mode=1`, which left padding at 0 (see the recorded
            # output after the script, which predates this fix) and is
            # rejected with a ValueError by PyTorch versions that validate
            # the argument. With kernel_size=3 and the default stride of 1,
            # padding=1 keeps the spatial size unchanged across each conv.
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            a.append(conv2d)
            # The next convolution consumes this layer's output channels.
            in_channels = v
b = nn.Sequential(*a)
print(b)
运行结果如下：
key:vgg11
value[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
(5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(7): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
(8): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
(11): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
(12): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
Process finished with exit code 0