1. B站（哔哩哔哩）视频链接：
https://www.bilibili.com/video/BV1TU4y1H7Mz?spm_id_from=333.1007.top_right_bar_window_custom_collection.content.click
2. GitHub 源码链接：
https://github.com/dragen1860/Deep-Learning-with-PyTorch-Tutorials/blob/master/lesson57-WGAN%E5%AE%9E%E6%88%98/wgan_gp.py
3. 我的代码（附注释）：
import torch
import torch.nn as nn
import numpy as np
import random
import visdom
import matplotlib.pyplot as plt
# Prefer the CUDA device when PyTorch reports one is available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Width of every hidden Linear layer in both the Generator and Discriminator MLPs.
h_dim = 400
# Batch size; presumably the number of samples per training step (its use is outside this view).
batch_size = 512

# Visdom client built with default arguments (NOTE(review): how it is used is outside this view).
viz = visdom.Visdom()
class Generator(nn.Module):
    """Fully connected generator network.

    The network is Linear -> ReLU repeated three times (three hidden layers of
    equal width), followed by a final Linear layer with no activation, so the
    outputs are unbounded. Called with no arguments, it builds exactly the
    original layout: 2 -> h_dim -> h_dim -> h_dim -> 2.

    Args:
        in_dim: Size of the last dimension of the input ``z``. Defaults to 2.
        hidden_dim: Width of each hidden layer. ``None`` (the default) uses
            the module-level ``h_dim``.
        out_dim: Size of the last dimension of the output. Defaults to 2.
    """

    def __init__(self, in_dim=2, hidden_dim=None, out_dim=2):
        super().__init__()
        # Resolve at construction time so plain ``Generator()`` keeps picking
        # up the module-level h_dim, exactly as the hard-coded version did.
        if hidden_dim is None:
            hidden_dim = h_dim
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            # No activation on the output layer: generated values are unbounded.
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, z):
        """Map ``z`` (last dimension ``in_dim``) to an output whose last dimension is ``out_dim``."""
        return self.net(z)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
nn.Linear(2, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, h_dim),
nn.ReLU(inplace=True),
nn.Linear(h_dim, h_dim),
nn.ReLU(inplace=True),
nn.Linear(h_dim, 1),
nn.Sigmoid()
)
def forward(self, z