The code in this article comes from the blogger Bubbliiiing, to whom I am grateful.
I. Network Overview
II. VGG16 Backbone Feature Extraction Network
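This article uses VGG16 as the backbone feature extraction network. The backbone uses only two types of layers: convolutional layers and max-pooling layers (each convolution is followed by a ReLU activation).
The code modifies some of these layers; see the network structure below for details, and use the figure below as an aid to understanding.
The complete code is as follows: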
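To make the "only convolutions and max-pooling" statement concrete, here is a minimal sketch that builds a VGG16-style backbone from a configuration list. It assumes the standard, unmodified VGG16 layout used by torchvision. The names `VGG16_CFG` and `make_layers` are illustrative, and the author's modified layers are not reproduced here.

```python
import torch
import torch.nn as nn

# Illustrative standard VGG16 configuration (an assumption, not the author's
# modified version): an integer is the output channel count of a 3x3 conv,
# "M" is a 2x2 max-pool with stride 2.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]


def make_layers(cfg, in_channels=3):
    """Build a feature extractor from conv (+ReLU) and max-pool layers only."""
    layers = []
    for v in cfg:
        if v == "M":
            # 2x2 max-pool, stride 2: halves the height and width
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            # 3x3 conv with padding 1 keeps the spatial size unchanged
            layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            in_channels = v
    return nn.Sequential(*layers)


if __name__ == "__main__":
    features = make_layers(VGG16_CFG)
    out = features(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 512, 7, 7])
    n_conv = sum(isinstance(m, nn.Conv2d) for m in features)
    n_pool = sum(isinstance(m, nn.MaxPool2d) for m in features)
    print(n_conv, n_pool)  # 13 5
```

With a 224×224 input, the 13 padded 3×3 convolutions preserve the spatial size, and each of the 5 max-pools halves it, giving 224 / 2^5 = 7. The backbone therefore outputs a 512×7×7 map, which is why the classifier's first linear layer in the code takes 512 * 7 * 7 = 25088 inputs. For other input sizes, `nn.AdaptiveAvgPool2d((7, 7))` resizes the map to 7×7 before flattening.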
1. vggnet
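import torch
import torch.nn as nn
# torchvision.models.utils was removed in newer torchvision releases;
# torch.hub provides the same load_state_dict_from_url function
from torch.hub import load_state_dict_from_url


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        # Backbone: the stack of conv + max-pool layers
        self.features = features
        # Pool the feature map to a fixed 7x7 spatial size
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Classification head: the standard VGG fully connected layers
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)        # (N, 512, H/32, W/32)
        x = self.avgpool(x)         # (N, 512, 7, 7)
        x = torch.flatten(x, 1)     # (N, 512 * 7 * 7)
        x = self.classifier(x)      # (N, num_classes)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming (He) initialization, suited to ReLU activations
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not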