1.引入VGG content loss来约束图片的内容差异
class VGG19_PercepLoss(nn.Module):
    """Perceptual (content) loss measured in VGG19 feature space.

    Both images are pushed through the convolutional part of a VGG19 network
    whose parameters are frozen, and the loss is the mean squared difference
    between the activations captured at one chosen layer.
    """

    # Position of each convolution inside ``vgg19().features``, keyed by the
    # usual ``conv<block>_<n>`` name. Indices are strings because that is how
    # ``nn.Sequential`` names its children.
    # NOTE(review): assumes torchvision's plain (non-batch-norm) VGG19 layout,
    # i.e. conv + ReLU pairs with a max-pool closing each block -- confirm if
    # the backbone is ever swapped.
    _CONV_INDEX = {
        'conv1_1': '0', 'conv1_2': '2',
        'conv2_1': '5', 'conv2_2': '7',
        'conv3_1': '10', 'conv3_2': '12', 'conv3_3': '14', 'conv3_4': '16',
        'conv4_1': '19', 'conv4_2': '21', 'conv4_3': '23', 'conv4_4': '25',
        'conv5_1': '28', 'conv5_2': '30', 'conv5_3': '32', 'conv5_4': '34',
    }

    def __init__(self, _pretrained_=True):
        """Build the frozen VGG19 feature extractor.

        Args:
            _pretrained_: forwarded to ``models.vgg19(pretrained=...)``.
        """
        super(VGG19_PercepLoss, self).__init__()
        # Use the GPU when there is one, but do not crash on CPU-only hosts
        # (an unconditional ``.cuda()`` raises there).
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.vgg = models.vgg19(pretrained=_pretrained_).features.to(device)
        # The network is only a fixed feature extractor: never update it.
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def get_features(self, image, layers=None):
        """Run ``image`` through the network and collect chosen activations.

        Args:
            image: input tensor fed to the first layer of ``self.vgg``.
            layers: dict mapping a child-module name of ``self.vgg`` (e.g.
                ``'30'``) to the key under which its output is stored.
                Defaults to ``{'30': 'conv5_2'}``.

        Returns:
            dict mapping each requested key to the captured tensor. Names in
            ``layers`` that do not exist in the network are simply absent.
        """
        if layers is None:
            layers = {'30': 'conv5_2'}  # may add other layers
        features = {}
        x = image
        # NOTE(review): the loop intentionally runs through every layer. If
        # the backbone uses in-place activations, the tensor stored below may
        # be overwritten by the following layer, so adding an early exit could
        # change the captured values -- confirm before optimising.
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
        return features

    def forward(self, pred, true, layer='conv5_2'):
        """Return the mean squared feature difference at ``layer``.

        Args:
            pred: predicted image tensor.
            true: reference image tensor.
            layer: name of the convolution whose activations are compared;
                one of the keys of ``_CONV_INDEX``.

        Returns:
            Scalar tensor: ``mean((feat(true) - feat(pred)) ** 2)``.

        Raises:
            KeyError: if ``layer`` is not a known VGG19 convolution name.
        """
        if layer not in self._CONV_INDEX:
            raise KeyError(
                "unknown VGG19 layer %r; expected one of %s"
                % (layer, sorted(self._CONV_INDEX))
            )
        # Ask for exactly the requested layer; previously only conv5_2 was
        # ever captured, so any other ``layer`` value failed on lookup.
        wanted = {self._CONV_INDEX[layer]: layer}
        true_f = self.get_features(true, wanted)
        pred_f = self.get_features(pred, wanted)
        return torch.mean((true_f[layer] - pred_f[layer]) ** 2)
以上代码运行时报错:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
解决办法:
input tensor 在 GPU 上(torch.cuda.FloatTensor),而 VGG 的 weight 仍在 CPU 上(torch.FloatTensor),两者所在的设备不匹配。
把模型也用 cuda() 移到 GPU 上,使输入和权重位于同一设备:
self.vgg = models.vgg19(pretrained=_pretrained_).features.cuda()