(二) PyTorch实现perceptual loss

6 篇文章 0 订阅
2 篇文章 0 订阅

另一个版本,但是本质是一样的:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
import numpy as np
from torchvision import models
import os,cv2
# Use the current CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Vgg19_out(nn.Module):
    """VGG19 convolutional trunk split into five slices; returns every slice output.

    ``vgg19().features`` is cut at the index ranges [0, 4), [4, 9), [9, 14),
    [14, 23) and [23, 32). ``forward`` returns the five intermediate activations,
    shallow to deep. Assuming torchvision's standard VGG19 layout (conv/ReLU pairs
    separated by max-pools), these slice ends are relu1_2, relu2_2, relu3_2,
    relu4_2 and relu5_2 — confirm against the torchvision version in use.

    NOTE(review): torchvision's pretrained VGG weights expect RGB input normalized
    with ImageNet mean/std. This module does no normalization; the caller must.

    Args:
        requires_grad: if False (default), all parameters are frozen so the
            network acts as a fixed feature extractor.
        weights_path: path to a VGG19 ``state_dict`` checkpoint. The default is
            the file name the original code hard-coded.
    """

    def __init__(self, requires_grad=False, weights_path='./vgg19-dcbb9e9d.pth'):
        super(Vgg19_out, self).__init__()
        # vgg19() with no arguments builds an untrained network in every
        # torchvision version. It avoids the `pretrained=` deprecation warning
        # (torchvision >= 0.13) and never triggers a download.
        vgg = models.vgg19().to(device)
        # map_location lets a checkpoint that was saved on a GPU load on a
        # CPU-only host (and vice versa).
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoint files.
        state_dict = torch.load(weights_path, map_location=device)
        vgg.load_state_dict(state_dict)
        vgg.eval()
        features = vgg.features
        self.requires_grad = requires_grad
        # Child modules keep their original indices as names ("0".."31"), so
        # state_dict keys such as "slice2.5.weight" are unchanged.
        self.slice1 = self._make_slice(features, 0, 4)
        self.slice2 = self._make_slice(features, 4, 9)
        self.slice3 = self._make_slice(features, 9, 14)
        self.slice4 = self._make_slice(features, 14, 23)
        self.slice5 = self._make_slice(features, 23, 32)
        if not self.requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    @staticmethod
    def _make_slice(features, start, stop):
        """Return a Sequential of features[start:stop], each layer named by its index."""
        block = torch.nn.Sequential()
        for idx in range(start, stop):
            block.add_module(str(idx), features[idx])
        return block

    def forward(self, X):
        """Feed X through the five slices in turn and return all five outputs."""
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        return [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
class Perceptual_loss134(nn.Module):
    def __init__(self):
        super(Perceptual_loss134, self).__init__()
        self.vgg = Vgg19_out().to(device)
        self.criterion = nn.MSELoss()
        #self.weights = [1.0/2.6, 1.0/16, 1.0/3.7, 1.0/5.6, 1.0]
        self.weights = [1.0, 1.0, 1.0, 1.0, 1.0]
    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        loss =  self.weights[0]* self.criterion(x_vgg[0], y_vgg[0].detach())+\
               self.weights[2]* self.criterion(x_vgg[2], y_vgg[2].detach())+\
               self.weights[3]*self.criterion(x_vgg[3], y_vgg[3].detach())
        return loss
class VGGLoss(nn.Module):
    """Weighted sum of MSE losses over all five Vgg19_out feature maps.

    Args:
        verbose: if True, print each layer's index (1-based), loss and feature
            size on every forward call. The original printed unconditionally;
            the default is now silent so training loops are not flooded.
    """

    def __init__(self, verbose=False):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19_out().to(device)
        self.criterion = nn.MSELoss()
        # Per-layer weights, one per Vgg19_out output. An alternative scheme
        # kept from the original: [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0].
        self.weights = [1.0, 1.0, 1.0, 1.0, 1.0]
        self.downsample = nn.AvgPool2d(2, stride=2, count_include_pad=False)
        self.verbose = verbose

    def forward(self, x, y):
        """Return sum_i weights[i] * MSE(vgg(x)[i], vgg(y)[i]).

        Both inputs are halved with 2x2 average pooling while the width
        (dim 3) exceeds 4096, presumably to bound VGG memory use. ``y`` is the
        target: its features are computed without autograd tracking, so no
        gradient flows into it.
        """
        while x.size(3) > 4096:
            x, y = self.downsample(x), self.downsample(y)
        x_vgg = self.vgg(x)
        with torch.no_grad():
            y_vgg = self.vgg(y)
        loss = 0.0
        for idx, (x_fea, y_fea, weight) in enumerate(zip(x_vgg, y_vgg, self.weights)):
            # Compute each term once and reuse it for both logging and the sum.
            layer_loss = self.criterion(x_fea, y_fea)
            if self.verbose:
                print(idx + 1, layer_loss, x_fea.size())
            loss += weight * layer_loss
        return loss

def _load_image_tensor(path):
    """Read an image with OpenCV; return a (1, C, H, W) float32 tensor in [0, 1] on `device`.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError("could not read image: %s" % path)
    # NOTE(review): OpenCV returns BGR channel order, and these [0, 1] values are
    # not ImageNet-normalized. torchvision's pretrained VGG expects normalized
    # RGB — confirm whether that matters for this use.
    chw = (img / 255.0).transpose((2, 0, 1))
    # The models live on `device`, so the inputs must be moved there too.
    return torch.from_numpy(chw).unsqueeze(0).float().to(device)


if __name__ == "__main__":
    fea_save_path = "./feature_save/"
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(fea_save_path, exist_ok=True)
    img1_torch = _load_image_tensor("./rain_pair/3in.png")
    img2_torch = _load_image_tensor("./rain_pair/3gt.png")
    print(tuple(img1_torch.shape[1:]), tuple(img2_torch.shape[1:]))

    vgg_fea = Vgg19_out()
    total_perceptual_loss = VGGLoss()
    perceptual_loss134 = Perceptual_loss134()

    # Inference-only demo: no gradients are needed.
    with torch.no_grad():
        img1_vggFea = vgg_fea(img1_torch)
        print(len(img1_vggFea), img1_vggFea[0].shape)
        loss1 = total_perceptual_loss(img1_torch, img2_torch)
        loss2 = perceptual_loss134(img1_torch, img2_torch)
    print(loss1, loss2)

 

  • 6
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 10
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值