# 感知损失计算 (perceptual loss computation)
from torchvision.models import vgg16
import torch
import torch.nn.functional as F
class LossNetwork(torch.nn.Module):
    """Perceptual loss: mean MSE between intermediate activations of a feature network.

    Both inputs are pushed through the child modules of ``vgg_model`` in order,
    the activations of the selected children are captured, and
    ``F.mse_loss`` is computed per captured layer; the per-layer losses are
    averaged with equal weight.

    Args:
        vgg_model: Module whose direct children are applied one after another
            (e.g. an ``nn.Sequential``). The default layer keys '3', '8', '15'
            assume torchvision's ``vgg16().features`` indexing.
        layer_name_mapping: Optional mapping ``{child_name: readable_name}``
            selecting which activations to compare. Defaults to the
            relu1_2 / relu2_2 / relu3_3 layers of VGG16.
    """

    def __init__(self, vgg_model, layer_name_mapping=None):
        super().__init__()
        self.vgg_layers = vgg_model
        if layer_name_mapping is None:
            layer_name_mapping = {
                '3': "relu1_2",
                '8': "relu2_2",
                '15': "relu3_3",
            }
        # Copy so later mutation of the caller's dict cannot change this loss.
        self.layer_name_mapping = dict(layer_name_mapping)

    def output_features(self, x):
        """Run ``x`` through ``vgg_layers`` and return the mapped activations.

        Returns:
            List of captured activations, ordered as their modules appear in
            ``vgg_layers``.
        """
        output = {}
        wanted = len(self.layer_name_mapping)
        # Iterate _modules directly (as nn.Sequential.forward does) rather than
        # named_children(), which would skip a module instance reused twice.
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
                # All requested layers captured: deeper layers only cost compute.
                if len(output) == wanted:
                    break
        return list(output.values())

    def forward(self, dehaze, gt):
        """Return the perceptual loss between ``dehaze`` and ``gt``.

        Args:
            dehaze: Generated / restored image batch.
            gt: Reference (ground-truth) image batch.
            NOTE(review): presumably NCHW images normalized the way the
            feature network expects (ImageNet stats for VGG) — confirm.

        Returns:
            Scalar tensor: unweighted mean of the per-layer MSE losses.

        Raises:
            ValueError: if no key of ``layer_name_mapping`` matched a child of
                ``vgg_layers`` (previously a ZeroDivisionError).
        """
        dehaze_features = self.output_features(dehaze)
        gt_features = self.output_features(gt)
        losses = [
            F.mse_loss(dehaze_feature, gt_feature)
            for dehaze_feature, gt_feature in zip(dehaze_features, gt_features)
        ]
        if not losses:
            raise ValueError(
                "no layer of layer_name_mapping matched a child of vgg_layers"
            )
        return sum(losses) / len(losses)
# Build the VGG16 perceptual loss and evaluate it on the current batch.
# NOTE(review): `device`, `generate_img` and `label_img` are not defined in
# this chunk — presumably set up earlier in the training script; confirm.
criterion = []
# features[:16] ends at index 15, the deepest layer LossNetwork reads by default.
# NOTE: newer torchvision prefers `weights=VGG16_Weights.IMAGENET1K_V1` over
# `pretrained=True`; kept as-is for compatibility with older versions.
vgg_model = vgg16(pretrained=True).features[:16]
vgg_model = vgg_model.to(device)
# The loss network is a fixed feature extractor: freeze it so backward() does
# not compute and accumulate gradients for the VGG weights.
vgg_model.eval()
for param in vgg_model.parameters():
    param.requires_grad_(False)
perceptual_loss = LossNetwork(vgg_model).to(device)
criterion.append(perceptual_loss)
# Call through the named reference: the original `criterion[2]` indexed past
# the single entry this block appends and raised IndexError.
vgg_loss = perceptual_loss(generate_img, label_img)