PyTorch 加载某一层网络的权重
在修改官方网络结构后(例如改变第一层卷积的输入通道数),往往需要为被修改的层手动重新赋值对应的预训练权重:
# Build an ImageNet-pretrained ResNet-18; every layer taken from `backbone`
# below already carries its pretrained parameters.
backbone = models.resnet18(weights='DEFAULT')

# Replacement for the first convolution: same kernel/stride/padding as the
# original stem conv, but it accepts 6 input channels.
stem_conv = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Initialise the 6-channel kernel from the pretrained one that is already in
# memory (no second checkpoint download): duplicate it along the
# input-channel axis and halve it, so that summing over twice as many input
# channels keeps the response at the pretrained scale.
# `no_grad` + `copy_` is used instead of assigning to `.weight.data`.
with torch.no_grad():
    pretrained_kernel = backbone.conv1.weight
    stem_conv.weight.copy_(
        torch.cat([pretrained_kernel, pretrained_kernel], dim=1) / 2
    )

self.layer0 = nn.Sequential(
    stem_conv,
    # Reuse the pretrained BatchNorm that followed conv1 rather than a
    # freshly initialised one, so its affine parameters and running
    # statistics are kept.
    backbone.bn1,
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

# The residual stages are taken from the pretrained backbone unchanged.
self.layer1 = backbone.layer1
self.layer2 = backbone.layer2
self.layer3 = backbone.layer3
self.layer4 = backbone.layer4