Week G3: Getting Started with CGAN | Generating Hand-Gesture Images

Basic Tasks

  1. Understand the basic principle of the Conditional Generative Adversarial Network (CGAN)
  2. Understand how CGAN achieves conditional control (see the objective sketch right after this list)
  3. Study the CGAN code in this post and run it end to end
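
As a brief reference for tasks 1 and 2 (my paraphrase of the formulation in Mirza & Osindero, 2014, not part of the original post): a CGAN conditions both the generator and the discriminator on extra information y (here, the gesture class label), turning the GAN minimax game into

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]$$

In the code below, y enters the generator as an embedded label plane concatenated with the latent feature map, and enters the discriminator as an embedded label "image" concatenated with the input channels.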

I. Preparation

1. Import libraries

from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt   # used for visualization below
import numpy as np
import os                          # used later to create output directories
import torch.nn as nn
import torch.optim as optim
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

2. Load the data

batch_size = 128

train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
])

train_dataset = datasets.ImageFolder(root="F:/365data/G3/rps/",
                                     transform=train_transform)

train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=6)
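
Since the labels that condition the CGAN come straight from ImageFolder, it is worth checking the class-to-index mapping it infers from the sub-directory names (a small check I added; the exact names depend on your copy of the rps dataset):

# Inspect the classes ImageFolder discovered; these indices are the
# labels that will condition both the generator and the discriminator.
print(train_dataset.classes)        # e.g. ['paper', 'rock', 'scissors']
print(train_dataset.class_to_idx)   # e.g. {'paper': 0, 'rock': 1, 'scissors': 2}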

3. Visualize the data

def show_images(images):
    # Arrange a batch of images into a grid and display it
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))

def show_batch(dl):
    # Show the first batch from a DataLoader
    for images, _ in dl:
        show_images(images)
        break
show_batch(train_loader)

II. Build the Model

# Custom weight initialization, applied to both netG and netD
def weights_init(m):
    # Get the class name of the current layer
    classname = m.__class__.__name__
    # Convolutional layers: normal init with mean 0, std 0.02
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    # BatchNorm layers: weights ~ N(1, 0.02), biases set to 0
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Define the generator

The following hyperparameters are needed by both models (their values follow from the printed model summaries below):

latent_dim = 100     # length of the noise vector
n_classes = 3        # three gesture classes: rock / paper / scissors
embedding_dim = 100  # size of the label embedding

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Embed the class label, then project it to 16 values (one 4x4 plane)
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16)
        )
        # Project the noise vector to a 512-channel 4x4 feature map
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Upsample 4x4 -> 128x128; the input has 513 channels
        # (512 latent channels + 1 label channel)
        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, inputs):
        noise_vector, label = inputs
        # Label -> single 4x4 feature plane
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        # Noise -> 512-channel 4x4 feature map
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)
        # Concatenate along the channel dimension and decode to a 3x128x128 image
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image
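
To make the conditioning mechanism concrete, here is a minimal shape check on an untrained generator (my own sketch, not from the original post); it assumes the hyperparameters defined above:

# Quick shape check (CPU is fine for this)
g = Generator()
z = torch.randn(4, latent_dim)            # 4 noise vectors
y = torch.randint(0, n_classes, (4, 1))   # 4 integer class labels; Embedding needs long dtype
print(g((z, y)).shape)                    # expected: torch.Size([4, 3, 128, 128])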

Create the generator

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)
Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
from torchinfo import summary
summary(generator)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Generator                                --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       1,616
├─Sequential: 1-2                        --
│    └─Linear: 2-3                       827,392
│    └─LeakyReLU: 2-4                    --
├─Sequential: 1-3                        --
│    └─ConvTranspose2d: 2-5              4,202,496
│    └─BatchNorm2d: 2-6                  1,024
│    └─ReLU: 2-7                         --
│    └─ConvTranspose2d: 2-8              2,097,152
│    └─BatchNorm2d: 2-9                  512
│    └─ReLU: 2-10                        --
│    └─ConvTranspose2d: 2-11             524,288
│    └─BatchNorm2d: 2-12                 256
│    └─ReLU: 2-13                        --
│    └─ConvTranspose2d: 2-14             131,072
│    └─BatchNorm2d: 2-15                 128
│    └─ReLU: 2-16                        --
│    └─ConvTranspose2d: 2-17             3,072
│    └─Tanh: 2-18                        --
=================================================================
Total params: 7,789,308
Trainable params: 7,789,308
Non-trainable params: 0
=================================================================

Build the discriminator

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Embed the class label and project it to a 3x128x128 "label image"
        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3*128*128)
        )
        # Input has 6 channels: 3 image channels + 3 label channels
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid()
        )
    def forward(self, inputs):
        img, label = inputs
        # Turn the label into a 3-channel 128x128 map and stack it with the image
        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)
        concat = torch.cat((img, label_output), dim=1)
        output = self.model(concat)
        return output
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)
Discriminator(
  (label_condition_disc): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=49152, bias=True)
  )
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (9): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Flatten(start_dim=1, end_dim=-1)
    (12): Dropout(p=0.4, inplace=False)
    (13): Linear(in_features=4608, out_features=1, bias=True)
    (14): Sigmoid()
  )
)
summary(discriminator)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       4,964,352
├─Sequential: 1-2                        --
│    └─Conv2d: 2-3                       6,144
│    └─LeakyReLU: 2-4                    --
│    └─Conv2d: 2-5                       131,072
│    └─BatchNorm2d: 2-6                  256
│    └─LeakyReLU: 2-7                    --
│    └─Conv2d: 2-8                       524,288
│    └─BatchNorm2d: 2-9                  512
│    └─LeakyReLU: 2-10                   --
│    └─Conv2d: 2-11                      2,097,152
│    └─BatchNorm2d: 2-12                 1,024
│    └─LeakyReLU: 2-13                   --
│    └─Flatten: 2-14                     --
│    └─Dropout: 2-15                     --
│    └─Linear: 2-16                      4,609
│    └─Sigmoid: 2-17                     --
=================================================================
Total params: 7,729,709
Trainable params: 7,729,709
Non-trainable params: 0
=================================================================
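
As a quick sanity check (my own sketch, mirroring the generator check above), the conditioned discriminator should map a batch of images plus labels to one probability per sample:

x = torch.randn(4, 3, 128, 128, device=device)   # a fake batch of RGB 128x128 images
y = torch.randint(0, n_classes, (4, 1), device=device)
print(discriminator((x, y)).shape)               # expected: torch.Size([4, 1])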

III. Train the Model

Define the loss functions

adversarial_loss = nn.BCELoss()

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss
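
A small aside (my suggestion, not part of the original code): since the discriminator ends in Sigmoid, nn.BCELoss is correct here, but a common and more numerically stable alternative is to drop the final nn.Sigmoid() and compute the loss on raw logits:

# Alternative design, shown for reference only
adversarial_loss_logits = nn.BCEWithLogitsLoss()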

Define the optimizers

learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(), lr = learning_rate,betas=(0.5,0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5,0.999))

Train

num_epochs = 300
D_loss_plot, G_loss_plot = [], []

# Make sure the output directories exist before training
os.makedirs('./images', exist_ok=True)
os.makedirs('./training_weights', exist_ok=True)

for epoch in range(1, num_epochs+1):
    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        # ---- Train the discriminator ----
        D_optimizer.zero_grad()
        real_images = real_images.to(device)
        labels = labels.to(device)
        labels = labels.unsqueeze(1).long()
        # Targets: 1 for real images, 0 for generated ones
        real_target = torch.ones(real_images.size(0), 1, device=device)
        fake_target = torch.zeros(real_images.size(0), 1, device=device)

        # Loss on real images
        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        # Generate fake images conditioned on the real labels
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        generated_image = generator((noise_vector, labels))

        # Loss on fake images (detach so no gradient flows into G here)
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)
        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss.item())
        D_total_loss.backward()
        D_optimizer.step()

        # ---- Train the generator ----
        G_optimizer.zero_grad()
        # The generator tries to make D output "real" on its fakes
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss.item())
        G_loss.backward()
        G_optimizer.step()

    print("Epoch: [%d/%d], D_loss: %.3f, G_loss: %.3f" % (
        epoch, num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),
        torch.mean(torch.FloatTensor(G_loss_list))
    ))

    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    # Save sample images and checkpoints every 10 epochs
    if epoch % 10 == 0:
        save_image(generated_image.data[:50], './images/sample_%d.png' % epoch, nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch%d.pth' % epoch)
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch%d.pth' % epoch)
Epoch: [1/300], D_loss: 0.296, G_loss: 1.792
Epoch: [2/300], D_loss: 0.122, G_loss: 3.659
Epoch: [3/300], D_loss: 0.192, G_loss: 3.285
Epoch: [4/300], D_loss: 0.361, G_loss: 2.839
Epoch: [5/300], D_loss: 0.338, G_loss: 2.263
Epoch: [6/300], D_loss: 0.358, G_loss: 2.184
Epoch: [7/300], D_loss: 0.446, G_loss: 2.147
Epoch: [8/300], D_loss: 0.434, G_loss: 1.684
Epoch: [9/300], D_loss: 0.372, G_loss: 1.877
Epoch: [10/300], D_loss: 0.354, G_loss: 1.887
Epoch: [11/300], D_loss: 0.463, G_loss: 2.194
Epoch: [12/300], D_loss: 0.504, G_loss: 2.106
Epoch: [13/300], D_loss: 0.581, G_loss: 1.766
Epoch: [14/300], D_loss: 0.477, G_loss: 1.651
Epoch: [15/300], D_loss: 0.455, G_loss: 1.566
Epoch: [16/300], D_loss: 0.505, G_loss: 1.603
Epoch: [17/300], D_loss: 0.463, G_loss: 1.734
Epoch: [18/300], D_loss: 0.421, G_loss: 1.788
Epoch: [19/300], D_loss: 0.415, G_loss: 1.777
Epoch: [20/300], D_loss: 0.432, G_loss: 1.912
Epoch: [21/300], D_loss: 0.427, G_loss: 1.995
Epoch: [22/300], D_loss: 0.462, G_loss: 1.995
Epoch: [23/300], D_loss: 0.441, G_loss: 1.952
Epoch: [24/300], D_loss: 0.407, G_loss: 1.921
Epoch: [25/300], D_loss: 0.425, G_loss: 1.890
Epoch: [26/300], D_loss: 0.424, G_loss: 1.958
Epoch: [27/300], D_loss: 0.439, G_loss: 1.699
Epoch: [28/300], D_loss: 0.469, G_loss: 1.680
Epoch: [29/300], D_loss: 0.427, G_loss: 1.586
Epoch: [30/300], D_loss: 0.506, G_loss: 1.603
Epoch: [31/300], D_loss: 0.512, G_loss: 1.546
Epoch: [32/300], D_loss: 0.469, G_loss: 1.512
Epoch: [33/300], D_loss: 0.466, G_loss: 1.603
Epoch: [34/300], D_loss: 0.456, G_loss: 1.575
Epoch: [35/300], D_loss: 0.466, G_loss: 1.551
Epoch: [36/300], D_loss: 0.427, G_loss: 1.552
Epoch: [37/300], D_loss: 0.450, G_loss: 1.610
Epoch: [38/300], D_loss: 0.450, G_loss: 1.636
Epoch: [39/300], D_loss: 0.426, G_loss: 1.589
Epoch: [40/300], D_loss: 0.453, G_loss: 1.674
Epoch: [41/300], D_loss: 0.411, G_loss: 1.657
Epoch: [42/300], D_loss: 0.461, G_loss: 1.749
Epoch: [43/300], D_loss: 0.420, G_loss: 1.746
Epoch: [44/300], D_loss: 0.428, G_loss: 1.703
Epoch: [45/300], D_loss: 0.423, G_loss: 1.703
Epoch: [46/300], D_loss: 0.482, G_loss: 1.791
Epoch: [47/300], D_loss: 0.384, G_loss: 1.723
Epoch: [48/300], D_loss: 0.405, G_loss: 1.722
Epoch: [49/300], D_loss: 0.403, G_loss: 1.739
Epoch: [50/300], D_loss: 0.437, G_loss: 1.788
Epoch: [51/300], D_loss: 0.378, G_loss: 1.724
Epoch: [52/300], D_loss: 0.608, G_loss: 2.277
Epoch: [53/300], D_loss: 0.366, G_loss: 1.718
Epoch: [54/300], D_loss: 0.383, G_loss: 1.775
Epoch: [55/300], D_loss: 0.358, G_loss: 1.786
Epoch: [56/300], D_loss: 0.442, G_loss: 1.911
Epoch: [57/300], D_loss: 0.434, G_loss: 1.877
Epoch: [58/300], D_loss: 0.364, G_loss: 1.846
Epoch: [59/300], D_loss: 0.338, G_loss: 1.856
Epoch: [60/300], D_loss: 0.417, G_loss: 1.965
Epoch: [61/300], D_loss: 0.356, G_loss: 1.856
Epoch: [62/300], D_loss: 0.455, G_loss: 2.133
Epoch: [63/300], D_loss: 0.349, G_loss: 1.925
Epoch: [64/300], D_loss: 0.360, G_loss: 2.018
Epoch: [65/300], D_loss: 0.365, G_loss: 2.021
Epoch: [66/300], D_loss: 0.391, G_loss: 2.098
Epoch: [67/300], D_loss: 0.329, G_loss: 2.009
Epoch: [68/300], D_loss: 0.343, G_loss: 2.069
Epoch: [69/300], D_loss: 0.382, G_loss: 2.188
Epoch: [70/300], D_loss: 0.383, G_loss: 2.064
Epoch: [71/300], D_loss: 0.335, G_loss: 2.108
Epoch: [72/300], D_loss: 0.314, G_loss: 2.130
Epoch: [73/300], D_loss: 0.335, G_loss: 2.199
Epoch: [74/300], D_loss: 0.294, G_loss: 2.165
Epoch: [75/300], D_loss: 0.500, G_loss: 2.476
Epoch: [76/300], D_loss: 0.345, G_loss: 2.149
Epoch: [77/300], D_loss: 0.311, G_loss: 2.173
Epoch: [78/300], D_loss: 0.304, G_loss: 2.153
Epoch: [79/300], D_loss: 0.352, G_loss: 2.315
Epoch: [80/300], D_loss: 0.325, G_loss: 2.279
Epoch: [81/300], D_loss: 0.302, G_loss: 2.282
Epoch: [82/300], D_loss: 0.353, G_loss: 2.290
Epoch: [83/300], D_loss: 0.276, G_loss: 2.331
Epoch: [84/300], D_loss: 0.297, G_loss: 2.365
Epoch: [85/300], D_loss: 0.350, G_loss: 2.436
Epoch: [86/300], D_loss: 0.363, G_loss: 2.551
Epoch: [87/300], D_loss: 0.341, G_loss: 2.353
Epoch: [88/300], D_loss: 0.290, G_loss: 2.402
Epoch: [89/300], D_loss: 0.302, G_loss: 2.401
Epoch: [90/300], D_loss: 0.269, G_loss: 2.440
Epoch: [91/300], D_loss: 0.334, G_loss: 2.519
Epoch: [92/300], D_loss: 0.284, G_loss: 2.511
Epoch: [93/300], D_loss: 0.280, G_loss: 2.453
Epoch: [94/300], D_loss: 0.299, G_loss: 2.520
Epoch: [95/300], D_loss: 0.291, G_loss: 2.577
Epoch: [96/300], D_loss: 0.288, G_loss: 2.525
Epoch: [97/300], D_loss: 0.284, G_loss: 2.551
Epoch: [98/300], D_loss: 0.237, G_loss: 2.593
Epoch: [99/300], D_loss: 0.690, G_loss: 3.163
Epoch: [100/300], D_loss: 0.313, G_loss: 2.644
Epoch: [101/300], D_loss: 0.249, G_loss: 2.514
Epoch: [102/300], D_loss: 0.252, G_loss: 2.594
Epoch: [103/300], D_loss: 0.247, G_loss: 2.576
Epoch: [104/300], D_loss: 0.279, G_loss: 2.634
Epoch: [105/300], D_loss: 0.263, G_loss: 2.588
Epoch: [106/300], D_loss: 0.274, G_loss: 2.652
Epoch: [107/300], D_loss: 0.233, G_loss: 2.718
Epoch: [108/300], D_loss: 0.381, G_loss: 2.787
Epoch: [109/300], D_loss: 0.291, G_loss: 2.767
Epoch: [110/300], D_loss: 0.252, G_loss: 2.653
Epoch: [111/300], D_loss: 0.229, G_loss: 2.749
Epoch: [112/300], D_loss: 0.364, G_loss: 2.845
Epoch: [113/300], D_loss: 0.253, G_loss: 2.777
Epoch: [114/300], D_loss: 0.242, G_loss: 2.811
Epoch: [115/300], D_loss: 0.231, G_loss: 2.822
Epoch: [116/300], D_loss: 0.223, G_loss: 2.857
Epoch: [117/300], D_loss: 0.497, G_loss: 3.042
Epoch: [118/300], D_loss: 0.325, G_loss: 2.691
Epoch: [119/300], D_loss: 0.225, G_loss: 2.723
Epoch: [120/300], D_loss: 0.218, G_loss: 2.919
Epoch: [121/300], D_loss: 0.266, G_loss: 2.878
Epoch: [122/300], D_loss: 0.252, G_loss: 2.850
Epoch: [123/300], D_loss: 0.217, G_loss: 2.934
Epoch: [124/300], D_loss: 0.246, G_loss: 2.944
Epoch: [125/300], D_loss: 0.231, G_loss: 2.914
Epoch: [126/300], D_loss: 0.224, G_loss: 2.962
Epoch: [127/300], D_loss: 0.310, G_loss: 3.137
Epoch: [128/300], D_loss: 0.261, G_loss: 2.873
Epoch: [129/300], D_loss: 0.212, G_loss: 2.928
Epoch: [130/300], D_loss: 0.219, G_loss: 3.036
Epoch: [131/300], D_loss: 0.463, G_loss: 3.204
Epoch: [132/300], D_loss: 0.486, G_loss: 2.846
Epoch: [133/300], D_loss: 0.269, G_loss: 2.816
Epoch: [134/300], D_loss: 0.221, G_loss: 2.900
Epoch: [135/300], D_loss: 0.220, G_loss: 3.012
Epoch: [136/300], D_loss: 0.209, G_loss: 2.979
Epoch: [137/300], D_loss: 0.233, G_loss: 3.002
Epoch: [138/300], D_loss: 0.208, G_loss: 3.065
Epoch: [139/300], D_loss: 0.226, G_loss: 3.078
Epoch: [140/300], D_loss: 0.231, G_loss: 3.147
Epoch: [141/300], D_loss: 0.224, G_loss: 3.168
Epoch: [142/300], D_loss: 0.241, G_loss: 3.099
Epoch: [143/300], D_loss: 0.225, G_loss: 3.111
Epoch: [144/300], D_loss: 0.247, G_loss: 3.177
Epoch: [145/300], D_loss: 0.253, G_loss: 3.154
Epoch: [146/300], D_loss: 0.210, G_loss: 3.175
Epoch: [147/300], D_loss: 0.239, G_loss: 3.150
Epoch: [148/300], D_loss: 0.234, G_loss: 3.144
Epoch: [149/300], D_loss: 0.377, G_loss: 3.300
Epoch: [150/300], D_loss: 0.250, G_loss: 3.033
Epoch: [151/300], D_loss: 0.203, G_loss: 3.229
Epoch: [152/300], D_loss: 0.247, G_loss: 3.193
Epoch: [153/300], D_loss: 0.246, G_loss: 3.221
Epoch: [154/300], D_loss: 0.203, G_loss: 3.175
Epoch: [155/300], D_loss: 0.253, G_loss: 3.189
Epoch: [156/300], D_loss: 0.217, G_loss: 3.307
Epoch: [157/300], D_loss: 0.207, G_loss: 3.364
Epoch: [158/300], D_loss: 0.240, G_loss: 3.207
Epoch: [159/300], D_loss: 0.209, G_loss: 3.360
Epoch: [160/300], D_loss: 0.205, G_loss: 3.310
Epoch: [161/300], D_loss: 0.245, G_loss: 3.284
Epoch: [162/300], D_loss: 0.216, G_loss: 3.310
Epoch: [163/300], D_loss: 0.249, G_loss: 3.514
Epoch: [164/300], D_loss: 0.869, G_loss: 2.968
Epoch: [165/300], D_loss: 0.240, G_loss: 2.965
Epoch: [166/300], D_loss: 0.226, G_loss: 3.088
Epoch: [167/300], D_loss: 0.190, G_loss: 3.152
Epoch: [168/300], D_loss: 0.186, G_loss: 3.269
Epoch: [169/300], D_loss: 0.177, G_loss: 3.243
Epoch: [170/300], D_loss: 0.272, G_loss: 3.389
Epoch: [171/300], D_loss: 0.188, G_loss: 3.281
Epoch: [172/300], D_loss: 0.317, G_loss: 3.416
Epoch: [173/300], D_loss: 0.243, G_loss: 3.218
Epoch: [174/300], D_loss: 0.182, G_loss: 3.340
Epoch: [175/300], D_loss: 0.171, G_loss: 3.364
Epoch: [176/300], D_loss: 0.246, G_loss: 3.449
Epoch: [177/300], D_loss: 0.211, G_loss: 3.385
Epoch: [178/300], D_loss: 0.165, G_loss: 3.483
Epoch: [179/300], D_loss: 0.321, G_loss: 3.457
Epoch: [180/300], D_loss: 0.187, G_loss: 3.415
Epoch: [181/300], D_loss: 0.182, G_loss: 3.467
Epoch: [182/300], D_loss: 0.163, G_loss: 3.511
Epoch: [183/300], D_loss: 0.205, G_loss: 3.477
Epoch: [184/300], D_loss: 0.250, G_loss: 3.516
Epoch: [185/300], D_loss: 0.243, G_loss: 3.556
Epoch: [186/300], D_loss: 0.180, G_loss: 3.492
Epoch: [187/300], D_loss: 0.512, G_loss: 3.561
Epoch: [188/300], D_loss: 0.238, G_loss: 3.272
Epoch: [189/300], D_loss: 0.186, G_loss: 3.421
Epoch: [190/300], D_loss: 0.172, G_loss: 3.398
Epoch: [191/300], D_loss: 0.232, G_loss: 3.455
Epoch: [192/300], D_loss: 0.184, G_loss: 3.509
Epoch: [193/300], D_loss: 0.150, G_loss: 3.576
Epoch: [194/300], D_loss: 0.213, G_loss: 3.499
Epoch: [195/300], D_loss: 0.183, G_loss: 3.567
Epoch: [196/300], D_loss: 0.166, G_loss: 3.577
Epoch: [197/300], D_loss: 0.502, G_loss: 3.610
Epoch: [198/300], D_loss: 0.221, G_loss: 3.346
Epoch: [199/300], D_loss: 0.173, G_loss: 3.482
Epoch: [200/300], D_loss: 0.157, G_loss: 3.584
Epoch: [201/300], D_loss: 0.306, G_loss: 3.548
Epoch: [202/300], D_loss: 0.164, G_loss: 3.621
Epoch: [203/300], D_loss: 0.219, G_loss: 3.569
Epoch: [204/300], D_loss: 0.162, G_loss: 3.651
Epoch: [205/300], D_loss: 0.277, G_loss: 3.681
Epoch: [206/300], D_loss: 0.169, G_loss: 3.629
Epoch: [207/300], D_loss: 0.160, G_loss: 3.649
Epoch: [208/300], D_loss: 0.157, G_loss: 3.714
Epoch: [209/300], D_loss: 0.185, G_loss: 3.675
Epoch: [210/300], D_loss: 0.687, G_loss: 3.859
Epoch: [211/300], D_loss: 0.606, G_loss: 2.301
Epoch: [212/300], D_loss: 0.258, G_loss: 3.049
Epoch: [213/300], D_loss: 0.206, G_loss: 3.401
Epoch: [214/300], D_loss: 0.172, G_loss: 3.412
Epoch: [215/300], D_loss: 0.201, G_loss: 3.600
Epoch: [216/300], D_loss: 0.153, G_loss: 3.634
Epoch: [217/300], D_loss: 0.227, G_loss: 3.630
Epoch: [218/300], D_loss: 0.144, G_loss: 3.730
Epoch: [219/300], D_loss: 0.285, G_loss: 3.649
Epoch: [220/300], D_loss: 0.169, G_loss: 3.614
Epoch: [221/300], D_loss: 0.272, G_loss: 3.799
Epoch: [222/300], D_loss: 0.290, G_loss: 3.509
Epoch: [223/300], D_loss: 0.161, G_loss: 3.618
Epoch: [224/300], D_loss: 0.176, G_loss: 3.710
Epoch: [225/300], D_loss: 0.129, G_loss: 3.848
Epoch: [226/300], D_loss: 0.141, G_loss: 3.807
Epoch: [227/300], D_loss: 0.164, G_loss: 3.798
Epoch: [228/300], D_loss: 0.142, G_loss: 3.788
Epoch: [229/300], D_loss: 0.160, G_loss: 3.981
Epoch: [230/300], D_loss: 0.347, G_loss: 3.743
Epoch: [231/300], D_loss: 0.457, G_loss: 3.495
Epoch: [232/300], D_loss: 0.182, G_loss: 3.494
Epoch: [233/300], D_loss: 0.163, G_loss: 3.638
Epoch: [234/300], D_loss: 0.161, G_loss: 3.768
Epoch: [235/300], D_loss: 0.221, G_loss: 3.856
Epoch: [236/300], D_loss: 0.211, G_loss: 3.649
Epoch: [237/300], D_loss: 0.153, G_loss: 3.876
Epoch: [238/300], D_loss: 0.576, G_loss: 3.531
Epoch: [239/300], D_loss: 0.194, G_loss: 3.331
Epoch: [240/300], D_loss: 0.152, G_loss: 3.604
Epoch: [241/300], D_loss: 0.156, G_loss: 3.717
Epoch: [242/300], D_loss: 0.130, G_loss: 3.763
Epoch: [243/300], D_loss: 0.152, G_loss: 3.917
Epoch: [244/300], D_loss: 0.199, G_loss: 3.757
Epoch: [245/300], D_loss: 0.175, G_loss: 3.899
Epoch: [246/300], D_loss: 0.152, G_loss: 3.884
Epoch: [247/300], D_loss: 0.251, G_loss: 3.800
Epoch: [248/300], D_loss: 0.207, G_loss: 3.849
Epoch: [249/300], D_loss: 0.173, G_loss: 3.845
Epoch: [250/300], D_loss: 0.570, G_loss: 3.885
Epoch: [251/300], D_loss: 0.769, G_loss: 2.066
Epoch: [252/300], D_loss: 0.344, G_loss: 2.922
Epoch: [253/300], D_loss: 0.201, G_loss: 3.341
Epoch: [254/300], D_loss: 0.164, G_loss: 3.410
Epoch: [255/300], D_loss: 0.138, G_loss: 3.625
Epoch: [256/300], D_loss: 0.167, G_loss: 3.692
Epoch: [257/300], D_loss: 0.181, G_loss: 3.777
Epoch: [258/300], D_loss: 0.190, G_loss: 3.714
Epoch: [259/300], D_loss: 0.136, G_loss: 3.922
Epoch: [260/300], D_loss: 0.166, G_loss: 3.846
Epoch: [261/300], D_loss: 0.199, G_loss: 3.818
Epoch: [262/300], D_loss: 0.143, G_loss: 3.819
Epoch: [263/300], D_loss: 0.127, G_loss: 4.045
Epoch: [264/300], D_loss: 0.156, G_loss: 3.901
Epoch: [265/300], D_loss: 0.175, G_loss: 4.005
Epoch: [266/300], D_loss: 0.174, G_loss: 3.901
Epoch: [267/300], D_loss: 0.134, G_loss: 3.992
Epoch: [268/300], D_loss: 0.161, G_loss: 4.118
Epoch: [269/300], D_loss: 0.384, G_loss: 3.904
Epoch: [270/300], D_loss: 0.230, G_loss: 3.751
Epoch: [271/300], D_loss: 0.154, G_loss: 3.896
Epoch: [272/300], D_loss: 0.168, G_loss: 3.994
Epoch: [273/300], D_loss: 0.221, G_loss: 3.870
Epoch: [274/300], D_loss: 0.153, G_loss: 4.041
Epoch: [275/300], D_loss: 0.168, G_loss: 3.963
Epoch: [276/300], D_loss: 0.166, G_loss: 3.919
Epoch: [277/300], D_loss: 0.552, G_loss: 3.722
Epoch: [278/300], D_loss: 0.223, G_loss: 3.575
Epoch: [279/300], D_loss: 0.149, G_loss: 3.875
Epoch: [280/300], D_loss: 0.144, G_loss: 3.898
Epoch: [281/300], D_loss: 0.179, G_loss: 3.969
Epoch: [282/300], D_loss: 0.144, G_loss: 3.985
Epoch: [283/300], D_loss: 0.166, G_loss: 4.027
Epoch: [284/300], D_loss: 0.152, G_loss: 4.183
Epoch: [285/300], D_loss: 0.132, G_loss: 4.089
Epoch: [286/300], D_loss: 0.356, G_loss: 3.784
Epoch: [287/300], D_loss: 0.138, G_loss: 3.905
Epoch: [288/300], D_loss: 0.214, G_loss: 3.891
Epoch: [289/300], D_loss: 0.139, G_loss: 4.167
Epoch: [290/300], D_loss: 0.248, G_loss: 3.944
Epoch: [291/300], D_loss: 0.340, G_loss: 3.687
Epoch: [292/300], D_loss: 0.152, G_loss: 3.954
Epoch: [293/300], D_loss: 0.150, G_loss: 4.053
Epoch: [294/300], D_loss: 0.142, G_loss: 4.140
Epoch: [295/300], D_loss: 0.200, G_loss: 4.058
Epoch: [296/300], D_loss: 0.136, G_loss: 4.074
Epoch: [297/300], D_loss: 0.153, G_loss: 4.142
Epoch: [298/300], D_loss: 0.125, G_loss: 4.108
Epoch: [299/300], D_loss: 0.171, G_loss: 4.147
Epoch: [300/300], D_loss: 0.167, G_loss: 4.302
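
The per-epoch mean losses collected in D_loss_plot and G_loss_plot are never visualized above; here is a minimal plotting sketch (labels and styling are my own choices):

# Convert the stored 0-dim tensors to floats and plot the loss curves
plt.figure(figsize=(10, 4))
plt.plot([float(v) for v in D_loss_plot], label='D_loss')
plt.plot([float(v) for v in G_loss_plot], label='G_loss')
plt.xlabel('epoch')
plt.ylabel('mean BCE loss')
plt.legend()
plt.show()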

Model Analysis

generator.load_state_dict(torch.load("D:/桌面/365/training_weights/generator_epoch300.pth"),strict=False)
generator.eval()
from numpy import asarray
from numpy.random import randint
from numpy.random import randn
from numpy import linspace
from matplotlib import pyplot as plt
from matplotlib import gridspec

def generate_latent_points(latent_dim, n_samples, n_classes=3):
    # Sample n_samples points from the standard-normal latent space
    x_input = randn(latent_dim * n_samples)
    z_input = x_input.reshape(n_samples, latent_dim)
    return z_input

def interpolate_points(p1, p2, n_steps=10):
    # Linearly interpolate between two latent points in n_steps steps
    ratios = linspace(0, 1, num=n_steps)
    vectors = list()
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return asarray(vectors)

pts = generate_latent_points(100, 2)
interpolated = interpolate_points(pts[0], pts[1])

interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

output = None

# For each class, decode the same 10 interpolated latent points
for label in range(3):
    labels = torch.ones(10) * label
    labels = labels.to(device).unsqueeze(1).long()
    predictions = generator((interpolated, labels))
    predictions = predictions.permute(0, 2, 3, 1).detach().cpu()
    if output is None:
        output = predictions
    else:
        output = torch.cat((output, predictions), dim=0)

# Plot a 3x10 grid: one row per class, one column per interpolation step
nrow = 3
ncol = 10
fig = plt.figure(figsize=(15, 4))
gs = gridspec.GridSpec(nrow, ncol)

k = 0
for i in range(nrow):
    for j in range(ncol):
        # Map Tanh output from [-1, 1] to [0, 255]
        pred = (output[k, :, :, :] + 1) * 127.5
        pred = np.array(pred)
        ax = plt.subplot(gs[i, j])
        ax.imshow(pred.astype(np.uint8))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        k += 1

plt.show()

[Figure: 3x10 grid of generated gestures — one row per class, one column per interpolation step]
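
Beyond interpolation, the point of the conditional setup is that we can request a specific class at sampling time. A minimal sketch (my own addition; it reuses show_images and the hyperparameters defined earlier):

# Generate 5 samples of class index 0 (whichever gesture ImageFolder mapped to 0)
with torch.no_grad():
    z = torch.randn(5, latent_dim, device=device)
    y = torch.zeros(5, 1, dtype=torch.long, device=device)
    imgs = generator((z, y)).cpu()
show_images((imgs + 1) / 2)   # map Tanh output [-1, 1] back to [0, 1] for display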

Summary

  • Compared with the GANs of the previous two weeks, the key difference in CGAN is that it injects external information (the class label) into both networks, turning unsupervised generation into a supervised, conditional model
  • Conditioning on the label gives control over which class is generated and helps improve the realism of the generated images
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值