import torchvision.models as models
PyTorch源码解读之torchvision.models
使用 VGG19 构建 VGG loss 的代码解读
class Vgg19(nn.Module):
def __init__(self, args, requires_grad=False):
super(Vgg19, self).__init__()
self.args = args
self.vgg_pretrained_features = models.vgg19(pretrained=True).features
#pretrained是true,要导入预训练模型
print(self.vgg_pretrained_features)
if not requires_grad:
for param i