WGAN的代码实现

本文提供了WGAN的PyTorch实现教程,包括哔哩哔哩视频链接、GitHub代码仓库和详细注释,重点讲解了yield在代码中的作用。
摘要由CSDN通过智能技术生成

1.哔站视频链接:

https://www.bilibili.com/video/BV1TU4y1H7Mz?spm_id_from=333.1007.top_right_bar_window_custom_collection.content.click

2.github链接:

https://github.com/dragen1860/Deep-Learning-with-PyTorch-Tutorials/blob/master/lesson57-WGAN%E5%AE%9E%E6%88%98/wgan_gp.py

3.我的代码注释:


# 说明如果遇见以下错误TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
#  在此文件中File "E:\anaconda\lib\site-packages\torch\_tensor.py", line 678, in __array__(这个文件是你运行后爆红的最后一个错误提示)
#     第678行修改return self.numpy()为return self.cpu().numpy()


import torch
import torch.nn as nn
import numpy as np
import random
import visdom
import matplotlib.pyplot as plt

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 超参数
h_dim = 400
batch_size = 512
viz = visdom.Visdom()  # 记得cmd中    python -m visdom.server--》按照地址打开即可显示



# 创建模型结构   生成器  判别器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # z:[b,2]--->[b,2]
            # Linear--ReLU--Linear--ReLU--Linear--ReLU--Linear
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, 2)
        )

    def forward(self, z):
        output = self.net(z)
        return output


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            # z:[b,2]--->[b,1]
            # Linear--ReLU--Linear--ReLU--Linear--ReLU--Linear
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()  # 设置范围为【0,1】表示当前输入是真实分布的程度
        )

    def forward(self, z
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值