当我们需要使用 VGG 等通用预训练模型时，服务器从 PyTorch 官网下载权重有时会非常慢。此时可以先手动下载对应的权重文件（例如 VGG19 的权重地址为 https://download.pytorch.org/models/vgg19-dcbb9e9d.pth）并放到本地，再对代码做如下修改，直接加载本地模型，从而避免在运行时下载。
原始代码如下（调用 models.vgg19(pretrained=True) 时会从网络下载预训练权重）：
import torch
from torchvision import models
import numpy as np
class Vgg19(torch.nn.Module):
    """VGG19 feature extractor intended for a perceptual loss.

    This variant asks torchvision for the ImageNet-pretrained VGG19, so the
    weights are fetched from the network when torchvision needs them.
    """

    def __init__(self, requires_grad=False):
        # NOTE(review): requires_grad is not used in the lines shown here;
        # presumably the omitted remainder of __init__ consumes it.
        super().__init__()

        # pretrained=True is the call that triggers the (possibly slow) download.
        vgg_model = models.vgg19(pretrained=True)
        vgg_pretrained_features = vgg_model.features

        # Five empty containers for the feature slices; populating them from
        # vgg_pretrained_features is not part of this excerpt.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
修改后的代码如下（注意：torch.load 中的相对路径是相对于当前工作目录解析的，而不是相对于脚本所在目录；必要时请改用权重文件的绝对路径）：
import torch
from torchvision import models
import numpy as np
class Vgg19(torch.nn.Module):
    """VGG19 feature extractor for a perceptual loss, loaded from a local file.

    The network is built without pretrained weights, so torchvision never
    attempts a download. The ImageNet weights are then loaded from a
    checkpoint that already exists on disk.

    Args:
        requires_grad: Not used by the lines shown here.
            NOTE(review): presumably consumed by the remainder of __init__
            (e.g. to freeze parameters) -- confirm.
        weights_path: Path to a VGG19 state dict whose keys match
            torchvision's ``vgg19`` (e.g. the official ``vgg19-dcbb9e9d.pth``).
            The default is resolved relative to the current working
            directory, not to this source file.
    """

    def __init__(self, requires_grad=False,
                 weights_path='vgg19-dcbb9e9d.pth'):
        super(Vgg19, self).__init__()
        # vgg19() with no arguments builds a randomly initialised network and
        # never downloads: the default is pretrained=False on older torchvision
        # and weights=None on torchvision >= 0.13. It also avoids the
        # deprecation warning newer versions emit for the `pretrained` argument.
        vgg_model = models.vgg19()
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on
        # a CPU-only server; load_state_dict copies the values into the
        # model's own parameters either way.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location='cpu')
        vgg_model.load_state_dict(state_dict)
        vgg_pretrained_features = vgg_model.features

        # Five empty containers for the feature slices; populating them from
        # vgg_pretrained_features is not part of this excerpt.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()