导入所需库,定义数据
import torch
from torch import nn
定义网络框架
class VGG16(nn.Module):
def __init__(self):
super().__init__()
self.features_=nn.Sequential(
nn.Conv2d(3,64,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(64,64,3,padding=1),nn.ReLU(inplace=True)
,nn.MaxPool2d(2)
,nn.Conv2d(64,128,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(128,128,3,padding=1),nn.ReLU(inplace=True)
,nn.MaxPool2d(2)
,nn.Conv2d(128,256,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(256,256,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(256,256,3,padding=1),nn.ReLU(inplace=True)
,nn.MaxPool2d(2)
,nn.Conv2d(256,512,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(512,512,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(512,512,3,padding=1),nn.ReLU(inplace=True)
,nn.MaxPool2d(2)
,nn.Conv2d(512,512,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(512,512,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(512,512,3,padding=1),nn.ReLU(inplace=True)
,nn.MaxPool2d(2)
)
self.clf_=nn.Sequential(nn.Dropout(0.5)
,nn.Linear(512*7*7,4096),nn.ReLU(inplace=True)
,nn.Dropout(0.5)
,nn.Linear(4096,4096),nn.ReLU(inplace=True)
,nn.Linear(4096,1000),nn.Softmax(dim=1)
)
def forward(self,x):
x=self.features_(x)
x=x.view(-1,512*7*7)
output=self.clf_(x)
return output
注意第一个全连接层的输入参数的大小为前一层的特征图通道数*特征图长*特征图宽,这个值会因为输入图片的大小变化而改变,可通过如下代码计算得到
net=nn.Sequential(
nn.Conv2d(3,64,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(64,64,3,padding=1),nn.ReLU(inplace=True)
,nn.MaxPool2d(2)
,nn.Conv2d(64,128,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(128,128,3,padding=1),nn.ReLU(inplace=True)
,nn.MaxPool2d(2)
,nn.Conv2d(128,256,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(256,256,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(256,256,3,padding=1),nn.ReLU(inplace=True)
,nn.MaxPool2d(2)
,nn.Conv2d(256,512,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(512,512,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(512,512,3,padding=1),nn.ReLU(inplace=True)
,nn.MaxPool2d(2)
,nn.Conv2d(512,512,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(512,512,3,padding=1),nn.ReLU(inplace=True)
,nn.Conv2d(512,512,3,padding=1),nn.ReLU(inplace=True)
,nn.MaxPool2d(2)
)
data=torch.ones(size=(10,3,224,224))
net(data).shape
运行结果如下,可见通道数为512,特征图大小为7*7
实例化模型并通过summary查看模型架构
from torchinfo import summary
vgg=VGG16()
summary(vgg,input_size=(10,3,224,224))