载入torchvision提供的预训练参数,替换最后一个全连接层
class ResNet_pretrained(nn.Module):
    """ResNet-50 backbone with torchvision's pretrained weights and a new head.

    The backbone is created with the pretrained parameters shipped by
    torchvision, and its final fully connected layer (``fc``) is replaced by
    a freshly initialised ``torch.nn.Linear`` that outputs ``num_classe``
    values.

    Args:
        num_classe: Number of output units of the replacement ``fc`` layer
            (i.e. the number of target classes).
    """

    def __init__(self, num_classe: int) -> None:
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; kept as-is so the code still runs on the
        # torchvision version this file was written against.
        self.backbone = torchvision.models.resnet50(pretrained=True)
        # Read the input width from the layer being replaced instead of
        # hard-coding 2048, and honour the requested number of classes
        # (the head was previously fixed to 10 outputs, ignoring num_classe).
        in_features = self.backbone.fc.in_features
        self.backbone.fc = torch.nn.Linear(in_features, num_classe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the new head's output."""
        return self.backbone(x)