PyTorch 实现生成对抗网络 GAN（Generative Adversarial Network）

# Import necessary packages.
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters.
latent_size = 64        # length of the noise vector fed to the generator
hidden_size = 256       # width of the hidden layers in both networks
image_size = 784        # flattened MNIST image: 28 * 28 pixels
num_epochs = 200
batch_size = 100
sample_dir = 'samples'  # directory that receives the saved sample images

# Create the sample directory if it does not exist yet. exist_ok=True avoids
# the race between a separate "exists?" check and the creation call.
os.makedirs(sample_dir, exist_ok=True)
# Image processing.
# MNIST images are single-channel (greyscale), so Normalize takes a single
# mean and std; an RGB data set would need three of each. With mean = std = 0.5
# the normalized pixels span [-1, 1], the same range as the generator's Tanh
# output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Load the MNIST training split; download=True fetches it into ../../data/
# when it is not already present there.
mnist = torchvision.datasets.MNIST(root='../../data/',
                                   train=True,
                                   transform=transform,
                                   download=True)

# Data loader: yields shuffled mini-batches of batch_size (image, label) pairs.
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)
# Discriminator: a three-layer MLP that takes a flattened image
# (image_size values) and ends in a Sigmoid, producing a score in (0, 1)
# read as "probability that this image is real".
D = nn.Sequential(
    nn.Linear(image_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1), nn.Sigmoid(),
).to(device)

# Generator: a three-layer MLP that maps a noise vector (latent_size values)
# to a flattened image. The final Tanh keeps pixels in [-1, 1], the range the
# real images have after Normalize(0.5, 0.5).
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh(),
).to(device)

# Loss and optimizers: binary cross-entropy on D's sigmoid output, plus one
# Adam optimizer per network so each can be stepped independently.
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(params=D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(params=G.parameters(), lr=0.0002)
def denorm(x):
    """Map a tensor from the [-1, 1] range back to [0, 1] for saving as an image.

    The values are shifted up by one and halved. The result is then clamped,
    so anything that was outside [-1, 1] saturates at 0 or 1.
    """
    shifted = x.add(1).div(2)
    return shifted.clamp(min=0, max=1)

def reset_grad():
    """Clear the accumulated gradients of both the discriminator and generator.

    Called before every backward pass, so gradients left over from the
    previous D or G update cannot leak into the next one.
    """
    for optimizer in (d_optimizer, g_optimizer):
        optimizer.zero_grad()

$$\mathrm{BCE}(y, \hat{y}) = -\frac{1}{N}\sum_{i=1}^{N}\left[ y_i \log(\hat{y}_i) + (1 - y_i)\log(1 - \hat{y}_i) \right]$$

# Start training.
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Use the size of *this* batch, not the batch_size hyper-parameter.
        # The final batch of an epoch is smaller whenever the data set size is
        # not a multiple of batch_size, and reshaping or labelling with the
        # hyper-parameter would then fail.
        cur_batch_size = images.size(0)
        images = images.reshape(cur_batch_size, -1).to(device)

        # Create the labels which are later used as targets for the BCE loss:
        # 1 marks "real", 0 marks "fake".
        real_labels = torch.ones(cur_batch_size, 1, device=device)
        fake_labels = torch.zeros(cur_batch_size, 1, device=device)

        # =============================================== #
        #            Train  the discriminator             #
        # =============================================== #

        # Compute BCE_Loss using the real images, where
        # BCE_Loss(x, y) = - y * log(D(x)) - (1-y) * log(1 - D(x)).
        # The second term of the loss is always zero since real_labels == 1.
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Compute BCE_Loss using fake images.
        # The first term of the loss is always zero since fake_labels == 0.
        z = torch.randn(cur_batch_size, latent_size).to(device)
        fake_images = G(z)
        # detach(): this step only updates D, so there is no need to record,
        # or backpropagate through, the generator's part of the graph.
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Backprop and optimize.
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # =============================================== #
        #            Train  the generator                 #
        # =============================================== #

        # Compute the loss on a freshly generated batch of fake images.
        z = torch.randn(cur_batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)

        # We train G to maximize log(D(G(z))) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)

        # Backprop and optimize.
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        # Report progress every 200 steps. D(x) and D(G(z)) are the mean
        # discriminator scores on the real and fake batches of the D step.
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

    # Save one batch of real images once, after the first epoch, for reference.
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save the generator's output from the last batch of this epoch.
    # detach(): fake_images is still attached to G's autograd graph, and
    # only its pixel values are needed for saving.
    fake_images = fake_images.detach().reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
Epoch [0/200], Step [200/600], d_loss: 0.0585, g_loss: 3.9752, D(x): 0.99, D(G(z)): 0.05
Epoch [0/200], Step [400/600], d_loss: 0.1085, g_loss: 5.8210, D(x): 0.95, D(G(z)): 0.04
Epoch [0/200], Step [600/600], d_loss: 0.0209, g_loss: 5.5525, D(x): 0.99, D(G(z)): 0.01
Epoch [1/200], Step [200/600], d_loss: 0.0568, g_loss: 4.9637, D(x): 0.98, D(G(z)): 0.04
Epoch [1/200], Step [400/600], d_loss: 0.4625, g_loss: 3.9840, D(x): 0.93, D(G(z)): 0.25
Epoch [1/200], Step [600/600], d_loss: 0.5137, g_loss: 3.7713, D(x): 0.88, D(G(z)): 0.17
Epoch [2/200], Step [200/600], d_loss: 0.1986, g_loss: 4.9873, D(x): 0.93, D(G(z)): 0.06
Epoch [2/200], Step [400/600], d_loss: 0.2533, g_loss: 3.5324, D(x): 0.93, D(G(z)): 0.12
Epoch [2/200], Step [600/600], d_loss: 0.5095, g_loss: 2.6030, D(x): 0.88, D(G(z)): 0.26
Epoch [3/200], Step [200/600], d_loss: 0.8931, g_loss: 1.8008, D(x): 0.81, D(G(z)): 0.32
Epoch [3/200], Step [400/600], d_loss: 0.8408, g_loss: 3.1935, D(x): 0.75, D(G(z)): 0.20
Epoch [3/200], Step [600/600], d_loss: 0.7491, g_loss: 3.4022, D(x): 0.85, D(G(z)): 0.30
Epoch [4/200], Step [200/600], d_loss: 0.2172, g_loss: 3.1683, D(x): 0.92, D(G(z)): 0.07
Epoch [4/200], Step [400/600], d_loss: 0.3179, g_loss: 3.2968, D(x): 0.83, D(G(z)): 0.05
Epoch [4/200], Step [600/600], d_loss: 0.2941, g_loss: 3.4464, D(x): 0.89, D(G(z)): 0.06
Epoch [5/200], Step [200/600], d_loss: 0.1504, g_loss: 4.0243, D(x): 0.97, D(G(z)): 0.09
Epoch [5/200], Step [400/600], d_loss: 0.1915, g_loss: 4.0270, D(x): 0.95, D(G(z)): 0.08
Epoch [5/200], Step [600/600], d_loss: 0.1870, g_loss: 4.5529, D(x): 0.93, D(G(z)): 0.04
Epoch [6/200], Step [200/600], d_loss: 0.1215, g_loss: 4.8041, D(x): 0.93, D(G(z)): 0.02
Epoch [6/200], Step [400/600], d_loss: 0.1458, g_loss: 3.7724, D(x): 0.95, D(G(z)): 0.05
Epoch [6/200], Step [600/600], d_loss: 0.1466, g_loss: 5.0891, D(x): 0.97, D(G(z)): 0.05
Epoch [7/200], Step [200/600], d_loss: 0.0646, g_loss: 5.5297, D(x): 0.98, D(G(z)): 0.03
Epoch [7/200], Step [400/600], d_loss: 0.1897, g_loss: 3.9121, D(x): 0.94, D(G(z)): 0.07
Epoch [7/200], Step [600/600], d_loss: 0.1948, g_loss: 3.9607, D(x): 0.93, D(G(z)): 0.04
Epoch [8/200], Step [200/600], d_loss: 0.1386, g_loss: 4.8868, D(x): 0.96, D(G(z)): 0.06
Epoch [8/200], Step [400/600], d_loss: 0.1233, g_loss: 4.8441, D(x): 0.97, D(G(z)): 0.06
Epoch [8/200], Step [600/600], d_loss: 0.1527, g_loss: 6.3575, D(x): 0.96, D(G(z)): 0.03
Epoch [9/200], Step [200/600], d_loss: 0.4402, g_loss: 5.0937, D(x): 0.97, D(G(z)): 0.28
Epoch [9/200], Step [400/600], d_loss: 0.1842, g_loss: 4.7917, D(x): 0.93, D(G(z)): 0.03
Epoch [9/200], Step [600/600], d_loss: 0.1143, g_loss: 4.1180, D(x): 0.97, D(G(z)): 0.06
Epoch [10/200], Step [200/600], d_loss: 0.1018, g_loss: 5.7902, D(x): 0.97, D(G(z)): 0.04
Epoch [10/200], Step [400/600], d_loss: 0.2134, g_loss: 6.8942, D(x): 0.92, D(G(z)): 0.02
Epoch [10/200], Step [600/600], d_loss: 0.1666, g_loss: 5.9212, D(x): 0.97, D(G(z)): 0.03
Epoch [11/200], Step [200/600], d_loss: 0.1897, g_loss: 6.1356, D(x): 0.93, D(G(z)): 0.02
Epoch [11/200], Step [400/600], d_loss: 0.2510, g_loss: 5.1762, D(x): 0.91, D(G(z)): 0.03
Epoch [11/200], Step [600/600], d_loss: 0.2043, g_loss: 4.9043, D(x): 0.92, D(G(z)): 0.04
Epoch [12/200], Step [200/600], d_loss: 0.0971, g_loss: 5.1556, D(x): 0.97, D(G(z)): 0.05
Epoch [12/200], Step [400/600], d_loss: 0.1948, g_loss: 5.0707, D(x): 0.92, D(G(z)): 0.04
Epoch [12/200], Step [600/600], d_loss: 0.2029, g_loss: 3.6790, D(x): 0.93, D(G(z)): 0.07
Epoch [13/200], Step [200/600], d_loss: 0.4181, g_loss: 4.0421, D(x): 0.85, D(G(z)): 0.05
Epoch [13/200], Step [400/600], d_loss: 0.0749, g_loss: 4.5826, D(x): 0.98, D(G(z)): 0.03
Epoch [13/200], Step [600/600], d_loss: 0.1216, g_loss: 6.2095, D(x): 0.96, D(G(z)): 0.02
Epoch [14/200], Step [200/600], d_loss: 0.4418, g_loss: 4.6061, D(x): 0.97, D(G(z)): 0.24
Epoch [14/200], Step [400/600], d_loss: 0.2945, g_loss: 5.8899, D(x): 0.94, D(G(z)): 0.06
Epoch [14/200], Step [600/600], d_loss: 0.0906, g_loss: 8.8300, D(x): 0.97, D(G(z)): 0.03
Epoch [15/200], Step [200/600], d_loss: 0.2474, g_loss: 4.3673, D(x): 0.91, D(G(z)): 0.03
Epoch [15/200], Step [400/600], d_loss: 0.3510, g_loss: 5.0897, D(x): 0.95, D(G(z)): 0.17
Epoch [15/200], Step [600/600], d_loss: 0.4003, g_loss: 5.3677, D(x): 0.95, D(G(z)): 0.18
Epoch [16/200], Step [200/600], d_loss: 0.2543, g_loss: 4.2240, D(x): 0.93, D(G(z)): 0.06
Epoch [16/200], Step [400/600], d_loss: 0.2977, g_loss: 4.9955, D(x): 0.90, D(G(z)): 0.05
Epoch [16/200], Step [600/600], d_loss: 0.2927, g_loss: 6.8045, D(x): 0.86, D(G(z)): 0.01
Epoch [17/200], Step [200/600], d_loss: 0.1382, g_loss: 5.1378, D(x): 0.96, D(G(z)): 0.07
Epoch [17/200], Step [400/600], d_loss: 0.2166, g_loss: 5.3648, D(x): 0.94, D(G(z)): 0.07
Epoch [17/200], Step [600/600], d_loss: 0.3926, g_loss: 3.9812, D(x): 0.91, D(G(z)): 0.14
Epoch [18/200], Step [200/600], d_loss: 0.3653, g_loss: 3.9638, D(x): 0.94, D(G(z)): 0.17
Epoch [18/200], Step [400/600], d_loss: 0.2735, g_loss: 3.0825, D(x): 0.90, D(G(z)): 0.07
Epoch [18/200], Step [600/600], d_loss: 0.1941, g_loss: 4.0487, D(x): 0.93, D(G(z)): 0.06
Epoch [19/200], Step [200/600], d_loss: 0.3090, g_loss: 3.4005, D(x): 0.89, D(G(z)): 0.06
Epoch [19/200], Step [400/600], d_loss: 0.3249, g_loss: 2.3588, D(x): 0.88, D(G(z)): 0.08
Epoch [19/200], Step [600/600], d_loss: 0.7206, g_loss: 3.0939, D(x): 0.85, D(G(z)): 0.25
Epoch [20/200], Step [200/600], d_loss: 0.4898, g_loss: 3.8527, D(x): 0.89, D(G(z)): 0.13
Epoch [20/200], Step [400/600], d_loss: 0.2342, g_loss: 3.7278, D(x): 0.91, D(G(z)): 0.04
Epoch [20/200], Step [600/600], d_loss: 0.3043, g_loss: 3.1545, D(x): 0.88, D(G(z)): 0.05
Epoch [21/200], Step [200/600], d_loss: 0.2280, g_loss: 3.3986, D(x): 0.95, D(G(z)): 0.11
Epoch [21/200], Step [400/600], d_loss: 0.3255, g_loss: 3.9966, D(x): 0.87, D(G(z)): 0.03
Epoch [21/200], Step [600/600], d_loss: 0.2533, g_loss: 4.1965, D(x): 0.94, D(G(z)): 0.11
Epoch [22/200], Step [200/600], d_loss: 0.1942, g_loss: 4.1865, D(x): 0.92, D(G(z)): 0.04
Epoch [22/200], Step [400/600], d_loss: 0.3280, g_loss: 3.6098, D(x): 0.88, D(G(z)): 0.08
Epoch [22/200], Step [600/600], d_loss: 0.3993, g_loss: 4.7892, D(x): 0.93, D(G(z)): 0.18
Epoch [23/200], Step [200/600], d_loss: 0.3928, g_loss: 5.7443, D(x): 0.96, D(G(z)): 0.15
Epoch [23/200], Step [400/600], d_loss: 0.2431, g_loss: 4.3890, D(x): 0.92, D(G(z)): 0.04
Epoch [23/200], Step [600/600], d_loss: 0.3785, g_loss: 4.6090, D(x): 0.91, D(G(z)): 0.10
Epoch [24/200], Step [200/600], d_loss: 0.2587, g_loss: 4.8328, D(x): 0.92, D(G(z)): 0.06
Epoch [24/200], Step [400/600], d_loss: 0.2945, g_loss: 4.9799, D(x): 0.97, D(G(z)): 0.15
Epoch [24/200], Step [600/600], d_loss: 0.3235, g_loss: 2.5292, D(x): 0.91, D(G(z)): 0.06
Epoch [25/200], Step [200/600], d_loss: 0.4326, g_loss: 2.7018, D(x): 0.91, D(G(z)): 0.19
Epoch [25/200], Step [400/600], d_loss: 0.2255, g_loss: 2.8194, D(x): 0.94, D(G(z)): 0.09
Epoch [25/200], Step [600/600], d_loss: 0.3143, g_loss: 4.6071, D(x): 0.87, D(G(z)): 0.04
Epoch [26/200], Step [200/600], d_loss: 0.3394, g_loss: 4.2181, D(x): 0.89, D(G(z)): 0.08
Epoch [26/200], Step [400/600], d_loss: 0.2341, g_loss: 5.4280, D(x): 0.91, D(G(z)): 0.04
Epoch [26/200], Step [600/600], d_loss: 0.4442, g_loss: 3.7772, D(x): 0.88, D(G(z)): 0.09
Epoch [27/200], Step [200/600], d_loss: 0.1985, g_loss: 3.9530, D(x): 0.94, D(G(z)): 0.08
Epoch [27/200], Step [400/600], d_loss: 0.2166, g_loss: 3.7587, D(x): 0.94, D(G(z)): 0.08
Epoch [27/200], Step [600/600], d_loss: 0.4773, g_loss: 2.0019, D(x): 0.88, D(G(z)): 0.19
Epoch [28/200], Step [200/600], d_loss: 0.4933, g_loss: 3.2889, D(x): 0.82, D(G(z)): 0.06
Epoch [28/200], Step [400/600], d_loss: 0.4324, g_loss: 4.1905, D(x): 0.94, D(G(z)): 0.20
Epoch [28/200], Step [600/600], d_loss: 0.3501, g_loss: 4.7053, D(x): 0.91, D(G(z)): 0.13
Epoch [29/200], Step [200/600], d_loss: 0.6074, g_loss: 3.6883, D(x): 0.77, D(G(z)): 0.06
Epoch [29/200], Step [400/600], d_loss: 0.5943, g_loss: 3.1997, D(x): 0.92, D(G(z)): 0.28
Epoch [29/200], Step [600/600], d_loss: 0.6925, g_loss: 3.1762, D(x): 0.90, D(G(z)): 0.30
Epoch [30/200], Step [200/600], d_loss: 0.6968, g_loss: 3.0852, D(x): 0.81, D(G(z)): 0.15
Epoch [30/200], Step [400/600], d_loss: 0.4944, g_loss: 3.4308, D(x): 0.83, D(G(z)): 0.14
Epoch [30/200], Step [600/600], d_loss: 0.3504, g_loss: 2.7500, D(x): 0.90, D(G(z)): 0.16
Epoch [31/200], Step [200/600], d_loss: 0.4152, g_loss: 3.1492, D(x): 0.94, D(G(z)): 0.19
Epoch [31/200], Step [400/600], d_loss: 0.5767, g_loss: 2.9633, D(x): 0.81, D(G(z)): 0.09
Epoch [31/200], Step [600/600], d_loss: 0.4425, g_loss: 2.6821, D(x): 0.88, D(G(z)): 0.18
Epoch [32/200], Step [200/600], d_loss: 0.4573, g_loss: 3.1390, D(x): 0.87, D(G(z)): 0.16
Epoch [32/200], Step [400/600], d_loss: 0.4845, g_loss: 2.7255, D(x): 0.81, D(G(z)): 0.11
Epoch [32/200], Step [600/600], d_loss: 0.5380, g_loss: 3.0490, D(x): 0.91, D(G(z)): 0.26
Epoch [33/200], Step [200/600], d_loss: 0.4952, g_loss: 3.2663, D(x): 0.88, D(G(z)): 0.18
Epoch [33/200], Step [400/600], d_loss: 0.2189, g_loss: 3.7869, D(x): 0.95, D(G(z)): 0.12
Epoch [33/200], Step [600/600], d_loss: 0.4676, g_loss: 4.5630, D(x): 0.89, D(G(z)): 0.19
Epoch [34/200], Step [200/600], d_loss: 0.3048, g_loss: 4.8220, D(x): 0.86, D(G(z)): 0.05
Epoch [34/200], Step [400/600], d_loss: 0.4279, g_loss: 3.7690, D(x): 0.88, D(G(z)): 0.11
Epoch [34/200], Step [600/600], d_loss: 0.4852, g_loss: 3.5232, D(x): 0.86, D(G(z)): 0.13
Epoch [35/200], Step [200/600], d_loss: 0.3201, g_loss: 3.7420, D(x): 0.93, D(G(z)): 0.13
Epoch [35/200], Step [400/600], d_loss: 0.3716, g_loss: 3.9509, D(x): 0.86, D(G(z)): 0.07
Epoch [35/200], Step [600/600], d_loss: 0.2514, g_loss: 3.5438, D(x): 0.92, D(G(z)): 0.10
Epoch [36/200], Step [200/600], d_loss: 0.4011, g_loss: 4.3735, D(x): 0.89, D(G(z)): 0.11
Epoch [36/200], Step [400/600], d_loss: 0.3995, g_loss: 3.4112, D(x): 0.87, D(G(z)): 0.12
Epoch [36/200], Step [600/600], d_loss: 0.5100, g_loss: 3.3964, D(x): 0.80, D(G(z)): 0.09
Epoch [37/200], Step [200/600], d_loss: 0.4825, g_loss: 3.7107, D(x): 0.87, D(G(z)): 0.17
Epoch [37/200], Step [400/600], d_loss: 0.6256, g_loss: 3.1553, D(x): 0.75, D(G(z)): 0.05
Epoch [37/200], Step [600/600], d_loss: 0.3993, g_loss: 3.2951, D(x): 0.88, D(G(z)): 0.15
Epoch [38/200], Step [200/600], d_loss: 0.5285, g_loss: 4.4703, D(x): 0.81, D(G(z)): 0.07
Epoch [38/200], Step [400/600], d_loss: 0.4924, g_loss: 2.9399, D(x): 0.82, D(G(z)): 0.06
Epoch [38/200], Step [600/600], d_loss: 0.6094, g_loss: 2.8946, D(x): 0.76, D(G(z)): 0.10
Epoch [39/200], Step [200/600], d_loss: 0.4856, g_loss: 2.9929, D(x): 0.86, D(G(z)): 0.18
Epoch [39/200], Step [400/600], d_loss: 0.4159, g_loss: 2.6120, D(x): 0.87, D(G(z)): 0.14
Epoch [39/200], Step [600/600], d_loss: 0.4413, g_loss: 3.3187, D(x): 0.87, D(G(z)): 0.14
Epoch [40/200], Step [200/600], d_loss: 0.4556, g_loss: 3.1148, D(x): 0.86, D(G(z)): 0.15
Epoch [40/200], Step [400/600], d_loss: 0.4038, g_loss: 3.3679, D(x): 0.87, D(G(z)): 0.11
Epoch [40/200], Step [600/600], d_loss: 0.3836, g_loss: 2.6056, D(x): 0.92, D(G(z)): 0.19
Epoch [41/200], Step [200/600], d_loss: 0.4729, g_loss: 3.2607, D(x): 0.87, D(G(z)): 0.13
Epoch [41/200], Step [400/600], d_loss: 0.5237, g_loss: 3.3055, D(x): 0.82, D(G(z)): 0.09
Epoch [41/200], Step [600/600], d_loss: 0.4162, g_loss: 2.7583, D(x): 0.86, D(G(z)): 0.15
Epoch [42/200], Step [200/600], d_loss: 0.4786, g_loss: 3.0368, D(x): 0.84, D(G(z)): 0.14
Epoch [42/200], Step [400/600], d_loss: 0.4082, g_loss: 4.2617, D(x): 0.91, D(G(z)): 0.18
Epoch [42/200], Step [600/600], d_loss: 0.4506, g_loss: 3.1848, D(x): 0.84, D(G(z)): 0.08
Epoch [43/200], Step [200/600], d_loss: 0.5603, g_loss: 3.6630, D(x): 0.93, D(G(z)): 0.29
Epoch [43/200], Step [400/600], d_loss: 0.4943, g_loss: 2.9792, D(x): 0.90, D(G(z)): 0.20
Epoch [43/200], Step [600/600], d_loss: 0.4884, g_loss: 2.5247, D(x): 0.88, D(G(z)): 0.19
Epoch [44/200], Step [200/600], d_loss: 0.4935, g_loss: 4.3851, D(x): 0.87, D(G(z)): 0.11
Epoch [44/200], Step [400/600], d_loss: 0.3947, g_loss: 2.9426, D(x): 0.89, D(G(z)): 0.15
Epoch [44/200], Step [600/600], d_loss: 0.3659, g_loss: 2.9597, D(x): 0.87, D(G(z)): 0.11
Epoch [45/200], Step [200/600], d_loss: 0.4466, g_loss: 3.9753, D(x): 0.85, D(G(z)): 0.13
Epoch [45/200], Step [400/600], d_loss: 0.7773, g_loss: 3.0353, D(x): 0.92, D(G(z)): 0.34
Epoch [45/200], Step [600/600], d_loss: 0.3731, g_loss: 3.4807, D(x): 0.88, D(G(z)): 0.13
Epoch [46/200], Step [200/600], d_loss: 0.4586, g_loss: 3.8451, D(x): 0.83, D(G(z)): 0.11
Epoch [46/200], Step [400/600], d_loss: 0.6208, g_loss: 2.3623, D(x): 0.91, D(G(z)): 0.30
Epoch [46/200], Step [600/600], d_loss: 0.4697, g_loss: 3.1523, D(x): 0.84, D(G(z)): 0.11
Epoch [47/200], Step [200/600], d_loss: 0.6875, g_loss: 3.4645, D(x): 0.83, D(G(z)): 0.26
Epoch [47/200], Step [400/600], d_loss: 0.5879, g_loss: 2.5338, D(x): 0.86, D(G(z)): 0.22
Epoch [47/200], Step [600/600], d_loss: 0.7907, g_loss: 3.5667, D(x): 0.69, D(G(z)): 0.07
Epoch [48/200], Step [200/600], d_loss: 0.4863, g_loss: 2.5787, D(x): 0.86, D(G(z)): 0.21
Epoch [48/200], Step [400/600], d_loss: 0.6107, g_loss: 2.3552, D(x): 0.78, D(G(z)): 0.15
Epoch [48/200], Step [600/600], d_loss: 0.5523, g_loss: 2.9557, D(x): 0.78, D(G(z)): 0.09
Epoch [49/200], Step [200/600], d_loss: 0.7314, g_loss: 2.0927, D(x): 0.79, D(G(z)): 0.21
Epoch [49/200], Step [400/600], d_loss: 0.5593, g_loss: 1.9390, D(x): 0.79, D(G(z)): 0.14
Epoch [49/200], Step [600/600], d_loss: 0.4735, g_loss: 2.3364, D(x): 0.86, D(G(z)): 0.20
Epoch [50/200], Step [200/600], d_loss: 0.6261, g_loss: 3.0003, D(x): 0.76, D(G(z)): 0.12
Epoch [50/200], Step [400/600], d_loss: 0.4247, g_loss: 3.4822, D(x): 0.85, D(G(z)): 0.13
Epoch [50/200], Step [600/600], d_loss: 0.7176, g_loss: 2.3449, D(x): 0.77, D(G(z)): 0.19
Epoch [51/200], Step [200/600], d_loss: 0.7946, g_loss: 1.9011, D(x): 0.76, D(G(z)): 0.21
Epoch [51/200], Step [400/600], d_loss: 0.4546, g_loss: 2.7052, D(x): 0.85, D(G(z)): 0.16
Epoch [51/200], Step [600/600], d_loss: 0.7092, g_loss: 2.1256, D(x): 0.83, D(G(z)): 0.28
Epoch [52/200], Step [200/600], d_loss: 0.6237, g_loss: 2.5466, D(x): 0.86, D(G(z)): 0.27
Epoch [52/200], Step [400/600], d_loss: 0.5007, g_loss: 2.4074, D(x): 0.88, D(G(z)): 0.22
Epoch [52/200], Step [600/600], d_loss: 0.8593, g_loss: 3.2508, D(x): 0.71, D(G(z)): 0.13
Epoch [53/200], Step [200/600], d_loss: 0.4494, g_loss: 3.3097, D(x): 0.83, D(G(z)): 0.15
Epoch [53/200], Step [400/600], d_loss: 0.5978, g_loss: 2.7762, D(x): 0.87, D(G(z)): 0.26
Epoch [53/200], Step [600/600], d_loss: 0.5585, g_loss: 2.9809, D(x): 0.79, D(G(z)): 0.16
Epoch [54/200], Step [200/600], d_loss: 0.5217, g_loss: 3.2523, D(x): 0.81, D(G(z)): 0.15
Epoch [54/200], Step [400/600], d_loss: 0.3683, g_loss: 2.5997, D(x): 0.87, D(G(z)): 0.13
Epoch [54/200], Step [600/600], d_loss: 0.4830, g_loss: 3.4485, D(x): 0.83, D(G(z)): 0.14
Epoch [55/200], Step [200/600], d_loss: 0.3236, g_loss: 3.9002, D(x): 0.90, D(G(z)): 0.12
Epoch [55/200], Step [400/600], d_loss: 0.3393, g_loss: 2.8654, D(x): 0.91, D(G(z)): 0.16
Epoch [55/200], Step [600/600], d_loss: 0.2750, g_loss: 4.2370, D(x): 0.90, D(G(z)): 0.11
Epoch [56/200], Step [200/600], d_loss: 0.2641, g_loss: 3.5884, D(x): 0.89, D(G(z)): 0.07
Epoch [56/200], Step [400/600], d_loss: 0.3910, g_loss: 3.6760, D(x): 0.82, D(G(z)): 0.09
Epoch [56/200], Step [600/600], d_loss: 0.4612, g_loss: 3.8326, D(x): 0.87, D(G(z)): 0.18
Epoch [57/200], Step [200/600], d_loss: 0.6390, g_loss: 2.7359, D(x): 0.80, D(G(z)): 0.19
Epoch [57/200], Step [400/600], d_loss: 0.5400, g_loss: 2.5651, D(x): 0.86, D(G(z)): 0.21
Epoch [57/200], Step [600/600], d_loss: 0.6628, g_loss: 2.4713, D(x): 0.76, D(G(z)): 0.12
Epoch [58/200], Step [200/600], d_loss: 0.7499, g_loss: 2.5131, D(x): 0.84, D(G(z)): 0.29
Epoch [58/200], Step [400/600], d_loss: 0.4810, g_loss: 2.3437, D(x): 0.83, D(G(z)): 0.16
Epoch [58/200], Step [600/600], d_loss: 0.5723, g_loss: 2.5647, D(x): 0.83, D(G(z)): 0.20
Epoch [59/200], Step [200/600], d_loss: 0.6727, g_loss: 2.7894, D(x): 0.75, D(G(z)): 0.20
Epoch [59/200], Step [400/600], d_loss: 0.6247, g_loss: 3.5069, D(x): 0.75, D(G(z)): 0.13
Epoch [59/200], Step [600/600], d_loss: 0.6970, g_loss: 1.3711, D(x): 0.82, D(G(z)): 0.28
Epoch [60/200], Step [200/600], d_loss: 0.6095, g_loss: 2.5165, D(x): 0.75, D(G(z)): 0.15
Epoch [60/200], Step [400/600], d_loss: 0.4788, g_loss: 2.4885, D(x): 0.87, D(G(z)): 0.22
Epoch [60/200], Step [600/600], d_loss: 0.6563, g_loss: 2.0455, D(x): 0.80, D(G(z)): 0.21
Epoch [61/200], Step [200/600], d_loss: 0.7438, g_loss: 2.6623, D(x): 0.78, D(G(z)): 0.26
Epoch [61/200], Step [400/600], d_loss: 0.6483, g_loss: 2.1472, D(x): 0.77, D(G(z)): 0.19
Epoch [61/200], Step [600/600], d_loss: 0.5793, g_loss: 2.1237, D(x): 0.77, D(G(z)): 0.14
Epoch [62/200], Step [200/600], d_loss: 0.6211, g_loss: 2.5764, D(x): 0.78, D(G(z)): 0.19
Epoch [62/200], Step [400/600], d_loss: 0.4850, g_loss: 2.0738, D(x): 0.80, D(G(z)): 0.15
Epoch [62/200], Step [600/600], d_loss: 0.6165, g_loss: 2.6696, D(x): 0.81, D(G(z)): 0.25
Epoch [63/200], Step [200/600], d_loss: 0.9170, g_loss: 1.8426, D(x): 0.79, D(G(z)): 0.33
Epoch [63/200], Step [400/600], d_loss: 0.7281, g_loss: 2.8344, D(x): 0.79, D(G(z)): 0.22
Epoch [63/200], Step [600/600], d_loss: 0.6787, g_loss: 2.5681, D(x): 0.74, D(G(z)): 0.15
Epoch [64/200], Step [200/600], d_loss: 0.5552, g_loss: 1.8684, D(x): 0.80, D(G(z)): 0.18
Epoch [64/200], Step [400/600], d_loss: 0.6481, g_loss: 2.5333, D(x): 0.85, D(G(z)): 0.27
Epoch [64/200], Step [600/600], d_loss: 0.8527, g_loss: 2.2508, D(x): 0.69, D(G(z)): 0.16
Epoch [65/200], Step [200/600], d_loss: 0.7515, g_loss: 1.5135, D(x): 0.80, D(G(z)): 0.26
Epoch [65/200], Step [400/600], d_loss: 0.8019, g_loss: 2.2918, D(x): 0.80, D(G(z)): 0.27
Epoch [65/200], Step [600/600], d_loss: 0.6620, g_loss: 2.5662, D(x): 0.74, D(G(z)): 0.14
Epoch [66/200], Step [200/600], d_loss: 0.8183, g_loss: 2.8530, D(x): 0.81, D(G(z)): 0.27
Epoch [66/200], Step [400/600], d_loss: 0.7310, g_loss: 2.8917, D(x): 0.70, D(G(z)): 0.10
Epoch [66/200], Step [600/600], d_loss: 0.6924, g_loss: 2.5030, D(x): 0.76, D(G(z)): 0.22
Epoch [67/200], Step [200/600], d_loss: 0.5773, g_loss: 2.1719, D(x): 0.76, D(G(z)): 0.13
Epoch [67/200], Step [400/600], d_loss: 0.5260, g_loss: 2.1045, D(x): 0.82, D(G(z)): 0.21
Epoch [67/200], Step [600/600], d_loss: 0.5769, g_loss: 2.2418, D(x): 0.79, D(G(z)): 0.16
Epoch [68/200], Step [200/600], d_loss: 0.5832, g_loss: 2.3681, D(x): 0.82, D(G(z)): 0.23
Epoch [68/200], Step [400/600], d_loss: 0.6596, g_loss: 2.6279, D(x): 0.77, D(G(z)): 0.18
Epoch [68/200], Step [600/600], d_loss: 0.7504, g_loss: 2.3642, D(x): 0.82, D(G(z)): 0.30
Epoch [69/200], Step [200/600], d_loss: 0.6209, g_loss: 2.4386, D(x): 0.83, D(G(z)): 0.23
Epoch [69/200], Step [400/600], d_loss: 0.5216, g_loss: 2.4933, D(x): 0.82, D(G(z)): 0.19
Epoch [69/200], Step [600/600], d_loss: 0.7623, g_loss: 2.1849, D(x): 0.81, D(G(z)): 0.28
Epoch [70/200], Step [200/600], d_loss: 0.8145, g_loss: 2.8262, D(x): 0.81, D(G(z)): 0.29
Epoch [70/200], Step [400/600], d_loss: 0.6843, g_loss: 2.6297, D(x): 0.80, D(G(z)): 0.27
Epoch [70/200], Step [600/600], d_loss: 0.6958, g_loss: 1.7765, D(x): 0.79, D(G(z)): 0.27
Epoch [71/200], Step [200/600], d_loss: 0.5670, g_loss: 2.5836, D(x): 0.80, D(G(z)): 0.19
Epoch [71/200], Step [400/600], d_loss: 0.5921, g_loss: 3.5642, D(x): 0.81, D(G(z)): 0.20
Epoch [71/200], Step [600/600], d_loss: 0.7693, g_loss: 3.0274, D(x): 0.82, D(G(z)): 0.30
Epoch [72/200], Step [200/600], d_loss: 0.7627, g_loss: 2.3572, D(x): 0.78, D(G(z)): 0.25
Epoch [72/200], Step [400/600], d_loss: 0.6006, g_loss: 2.6322, D(x): 0.83, D(G(z)): 0.22
Epoch [72/200], Step [600/600], d_loss: 0.7293, g_loss: 1.8255, D(x): 0.77, D(G(z)): 0.22
Epoch [73/200], Step [200/600], d_loss: 0.7579, g_loss: 2.0830, D(x): 0.79, D(G(z)): 0.29
Epoch [73/200], Step [400/600], d_loss: 0.6425, g_loss: 2.0860, D(x): 0.80, D(G(z)): 0.22
Epoch [73/200], Step [600/600], d_loss: 0.7793, g_loss: 1.8010, D(x): 0.80, D(G(z)): 0.31
Epoch [74/200], Step [200/600], d_loss: 1.0133, g_loss: 2.1083, D(x): 0.67, D(G(z)): 0.30
Epoch [74/200], Step [400/600], d_loss: 0.6956, g_loss: 2.0004, D(x): 0.75, D(G(z)): 0.18
Epoch [74/200], Step [600/600], d_loss: 0.5860, g_loss: 2.9672, D(x): 0.77, D(G(z)): 0.17
Epoch [75/200], Step [200/600], d_loss: 0.8044, g_loss: 1.9753, D(x): 0.75, D(G(z)): 0.27
Epoch [75/200], Step [400/600], d_loss: 0.7504, g_loss: 2.0639, D(x): 0.78, D(G(z)): 0.26
Epoch [75/200], Step [600/600], d_loss: 0.9548, g_loss: 2.9962, D(x): 0.71, D(G(z)): 0.26
Epoch [76/200], Step [200/600], d_loss: 0.7734, g_loss: 1.9767, D(x): 0.80, D(G(z)): 0.28
Epoch [76/200], Step [400/600], d_loss: 0.7155, g_loss: 2.1253, D(x): 0.83, D(G(z)): 0.30
Epoch [76/200], Step [600/600], d_loss: 0.6867, g_loss: 1.9904, D(x): 0.80, D(G(z)): 0.25
Epoch [77/200], Step [200/600], d_loss: 0.7512, g_loss: 2.2303, D(x): 0.72, D(G(z)): 0.21
Epoch [77/200], Step [400/600], d_loss: 0.8411, g_loss: 1.3529, D(x): 0.79, D(G(z)): 0.29
Epoch [77/200], Step [600/600], d_loss: 0.9126, g_loss: 2.4320, D(x): 0.67, D(G(z)): 0.24
Epoch [78/200], Step [200/600], d_loss: 0.6438, g_loss: 2.1731, D(x): 0.80, D(G(z)): 0.24
Epoch [78/200], Step [400/600], d_loss: 0.5937, g_loss: 2.1632, D(x): 0.80, D(G(z)): 0.20
Epoch [78/200], Step [600/600], d_loss: 0.8671, g_loss: 2.3467, D(x): 0.82, D(G(z)): 0.35
Epoch [79/200], Step [200/600], d_loss: 0.8624, g_loss: 2.2962, D(x): 0.78, D(G(z)): 0.30
Epoch [79/200], Step [400/600], d_loss: 0.7307, g_loss: 1.7616, D(x): 0.81, D(G(z)): 0.32
Epoch [79/200], Step [600/600], d_loss: 0.9585, g_loss: 2.0750, D(x): 0.70, D(G(z)): 0.27
Epoch [80/200], Step [200/600], d_loss: 0.7243, g_loss: 2.1770, D(x): 0.71, D(G(z)): 0.18



---------------------------------------------------------------------------

KeyboardInterrupt                         Traceback (most recent call last)



KeyboardInterrupt: 

（图：训练过程中保存的样本图像）
训练到第 80 个 epoch 时（之后手动中断了训练，见上方的 KeyboardInterrupt），生成的图像如右图所示。

# Save the model checkpoints: each network's state_dict, generator first,
# then discriminator.
for _net, _ckpt_path in ((G, 'G.ckpt'), (D, 'D.ckpt')):
    torch.save(_net.state_dict(), _ckpt_path)
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值