基于深度强化学习的绘画智能体 代码分析(五)

本文深入剖析了pytorch中autograd.grad()函数的使用,详细解释了该函数的各个参数,如outputs、inputs、grad_outputs等,并结合代码实例进行分析,帮助理解深度学习中的自动求导过程。
摘要由CSDN通过智能技术生成

GIthub源码

  1. wgan.py
import torch
import torch.nn as nn
import numpy as np
from torch.optim import Adam, SGD
from torch import autograd
from torch.autograd import Variable
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torch.nn.utils.weight_norm as weightNorm
from utils.util import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dim = 128
LAMBDA = 10 # Gradient penalty lambda hyperparameter

class TReLU(nn.Module):
    def __init__(self):
            super(TReLU, self).__init__() #子类继承了父类的所有属性和方法,父类属性自然会用父类方法来进行初始化
            self.alpha = nn.Parameter(torch.FloatTensor(1), requires_grad=True) #将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面,self.alpha变成了模型的一部分,成为了模型中根据训练可以改动的参数了,让某些变量在学习的过程中不断的修改其值以达到最优化,Tensor可以通过参数 requires_grad=True 创建, 
            self.alpha.data.fill_(0)

    def forward(self, x):
        x = F.relu(x - self.alpha) + self.alpha
        return x

class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()

            self.conv0 = weightNorm(nn.Conv2d(6, 16, 5, 2, 2))
            self.conv1 = weightNorm(nn.Conv2d(16, 32, 5, 2, 2))
            self.conv2 = weightNorm(nn.Conv2d(32, 64, 5, 2, 2))
            self.conv3 = weightNorm(nn.Conv2d(64, 128, 5, 2, 2))
            self.conv4 = weightNorm(nn.Conv2d(128, 1, 1, 1, 0))
            self.relu0 = TReLU()
            self.relu1 = TReLU()
            self.relu2 = TReLU()
            self.relu3 = TReLU()

        def forward(self, x):
            x = self.conv0(x)
            x = self.relu0(x)
            x = self.conv1(x)
            x = self.relu1(x)
            x = self.conv2(x)
            x = self.relu2(x)
            x = self.conv3(x)
            x = self.relu3(x)
            x = self.conv4(x)
            x = x
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值