很多情形下,我们需要拿在ImageNet
上经过fine-tuning
的ResNet
来作为任务的backbone
,通常我们只会选取预训练好的ResNet
的CNN
部分而非线性层,具体代码操作如下:
import torch
import torch.nn as nn
from torchvision import models
class ResNet152(nn.Module):
def __init__(self, pretrained, path):
super(ResNet152, self).__init__()
self.resnet152 = models.resnet152(pretrained=False)
self.layers = ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4']]
if pretrained == True:
self.resnet152.load_state_dict(torch.load(path))
else:
pass
def forward(self, x):
for name, module in self.resnet152._modules.items():
if name in self.layers:
x = module(x)
return x
这里我是预先在特定path
中下载好了pretrained
的ResNet152
。