Reading Notes: PyTorch Generative Adversarial Network Programming


I previously read Tariq Rashid's introductory book on building neural networks in Python, and I have now bought his newer book on building GANs with PyTorch. I plan to write these reading notes as I work through it and publish them once they are complete. They will include my own understanding of the concepts in the book, the process of implementing the code, and the problems I ran into along the way; feedback and corrections are very welcome. This will not be finished quickly, but I will keep re-reading the book and updating the code.

1. PyTorch and Neural Networks

1.1 Getting Started with PyTorch

My earlier blog posts cover getting started with PyTorch in more detail, so I will not repeat that here. The book's explanation is very accessible; the overall point is that PyTorch makes differentiation of tensor computations painless. The comments in the code below are my own hand-worked derivatives, and all of them agree with the results produced by the code.

import torch
x=3.5
y=x*x+2
print(x,y)

x=torch.tensor(3.5)
print(x)

y=x+3
print(y)

x=torch.tensor(3.5,requires_grad=True)
print(x)

y=(x-1)*(x-2)*(x-3)
# y = (x-1)(x-2)(x-3) = (x**2 - 3x + 2)(x - 3) = x**3 - 6x**2 + 11x - 6
print(y)

y.backward()
print(x.grad)
# dy/dx = 3x**2 - 12x + 11 = 5.75 at x = 3.5


x=torch.tensor(3.5,requires_grad=True)
y=x*x
z=2*y+3

z.backward()

print(x.grad)
# z = 2x**2 + 3, so dz/dx = 4x = 14 at x = 3.5

a=torch.tensor(2.0,requires_grad=True)
b=torch.tensor(1.0,requires_grad=True)

x=2*a+3*b
y=5*a*a+3*b*b*b
z=2*x+3*y

z.backward()

print(a.grad)
# z = 2x + 3y = 4a + 6b + 15a**2 + 9b**3, so dz/da = 4 + 30a = 64 at a = 2
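
To double-check the hand-worked dz/da above (my own addition, not from the book), a quick finite-difference estimate:

# numerical estimate of dz/da at a=2, b=1; it should come out close to 64
def z_of(a,b):
    x=2*a+3*b
    y=5*a*a+3*b*b*b
    return 2*x+3*y

eps=1e-4
print((z_of(2.0+eps,1.0)-z_of(2.0-eps,1.0))/(2*eps))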

1.2 First Steps with PyTorch Neural Networks

import pandas as pd
import matplotlib.pyplot as plt

df=pd.read_csv('./mnist_train.csv',header=None)
df.head()

row=13
data=df.iloc[row]

label=data[0]

img=data[1:].values.reshape(28,28)
plt.title("lable="+str(label))
plt.imshow(img,interpolation='none',cmap='Blues')
plt.show()

This part loads and inspects the data; it is the same handwritten-digit dataset used in the earlier Python neural network book.

import torch
import torch.nn as nn
import pandas

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()

        self.model=nn.Sequential(nn.Linear(784,200),nn.Sigmoid(),nn.Linear(200,10),nn.Sigmoid())
        #improvement: replace nn.Sigmoid() with nn.LeakyReLU(0.02)
        #improvement: add nn.LayerNorm(200) between nn.LeakyReLU and the second nn.Linear
        
        self.loss_function=nn.MSELoss()
        #self.loss_function=nn.BCELoss()
        self.optimiser=torch.optim.SGD(self.parameters(),lr=0.01)
        #self.optimiser=torch.optim.Adam(self.parameters())

        self.counter=0
        self.progress=[]

        pass

    def forward(self,inputs):
        return self.model(inputs)

    def train(self,inputs,targets):
        outputs=self.forward(inputs)
        loss=self.loss_function(outputs,targets)

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        self.counter+=1
        if (self.counter%10==0):
            self.progress.append(loss.item())
            pass
        if (self.counter%10000==0):
            print("counter=",self.counter)
            pass
    
    def plot_progress(self):
        df=pandas.DataFrame(self.progress,columns=['loss'])
        df.plot(ylim=(0,1.0),figsize=(16,8),alpha=0.1,marker='.',grid=True,yticks=(0,0.25,0.5))
        pass

Next, a simple three-layer neural network is built with PyTorch: an initialiser, a forward function, a training function and a loss-plotting function.
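
Before training, a quick shape check of my own (not in the book) to confirm the network maps a 784-value input to 10 outputs:

# throwaway instance just to check input and output sizes
net=Classifier()
print(net.forward(torch.rand(784)).shape)
# expected: torch.Size([10])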

from torch.utils.data import Dataset

class MnistDataset(Dataset):

    def __init__(self,csv_file):
        self.data_df=pandas.read_csv(csv_file,header=None)
        pass

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self,index):
        label=self.data_df.iloc[index,0]
        target=torch.zeros((10))
        target[label]=1.0

        image_values=torch.FloatTensor(self.data_df.iloc[index,1:].values)/255.0

        return label,image_values,target
    
    def plot_image(self,index):
        arr=self.data_df.iloc[index,1:].values.reshape(28,28)
        plt.title("label="+str(self.data_df.iloc[index,0]))
        plt.imshow(arr,interpolation='none',cmap='Blues')
        pass

mnist_dataset=MnistDataset('./mnist_train.csv')
mnist_dataset.plot_image(9)

The data is wrapped in a Dataset.
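
The training loop that follows feeds the network one sample at a time. As a side note of my own, the same Dataset also works with a PyTorch DataLoader if mini-batches are wanted, roughly like this (the train function above would need small changes, e.g. batched targets, before it could use them):

from torch.utils.data import DataLoader

# hypothetical batching sketch, not used in the book's training loop
loader=DataLoader(mnist_dataset,batch_size=32,shuffle=True)
for label,image_values,target in loader:
    print(label.shape,image_values.shape,target.shape)
    # torch.Size([32]) torch.Size([32, 784]) torch.Size([32, 10])
    break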

C=Classifier()
epochs=3
for i in range(epochs):
    print('training epoch',i+1,"of",epochs)
    for label,image_data_tensor,target_tensor in mnist_dataset:
        C.train(image_data_tensor,target_tensor)
        pass
    pass
C.plot_progress()

Train for 3 epochs.

mnist_test_dataset=MnistDataset('./mnist_test.csv')
record=19
mnist_test_dataset.plot_image(record)
image_data=mnist_test_dataset[record][1]

output=C.forward(image_data)

pandas.DataFrame(output.detach().numpy()).plot(kind='bar',legend=False,ylim=(0,1))

score=0
items=0

for label,image_data_tensor,target_tensor in mnist_test_dataset:
    answer=C.forward(image_data_tensor).detach().numpy()
    if (answer.argmax()==label):
        score+=1
        pass
    items+=1
    pass

print(score,items,score/items)

Test and evaluate on the test set.

1.3 Refinements

The refinements cover three areas: the loss function, the optimiser and the activation function. See the comments in the previous section for where each line replaces the original code.

self.loss_function=nn.BCELoss()
self.optimiser=torch.optim.Adam(self.parameters())
self.model=nn.Sequential(nn.Linear(784,200),nn.LeakyReLU(0.02),nn.Linear(200,10),nn.LeakyReLU(0.02))
#caution: nn.BCELoss expects outputs in [0,1], so it should be paired with a Sigmoid on the output layer rather than LeakyReLU

1.4 CUDA Basics

In short, numpy computations are faster than plain Python, and putting the work on a GPU allows parallel computation; PyTorch makes it straightforward to move calculations onto the GPU.
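
A minimal sketch of what moving work onto the GPU looks like, assuming a CUDA device is available (my own example, not the book's code):

import torch

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

x=torch.rand(1000,1000,device=device)   # tensor created directly on the chosen device
y=x@x                                   # the matrix multiply runs on that device
print(y.device)

# a model is moved the same way, e.g. model.to(device); its inputs must live on the same device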

2. A First GAN

2.1 The Idea of a GAN


2.2 Generating the 1010 Pattern

The discriminator:

import torch
import torch.nn as nn

import pandas
import matplotlib.pyplot as plt
import random

# def generate_real():
#     real_data=torch.FloatTensor([1,0,1,0])
#     return real_data
def generate_real():
    real_data=torch.FloatTensor([random.uniform(0.8,1.0),random.uniform(0.0,0.2),random.uniform(0.8,1.0),random.uniform(0.0,0.2)])
    return real_data

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(nn.Linear(4,3),nn.Sigmoid(),nn.Linear(3,1),nn.Sigmoid())
        self.loss_function=nn.MSELoss()
        self.optimiser=torch.optim.SGD(self.parameters(),lr=0.01)
        self.counter=0
        self.progress=[]
        pass

    def forward(self,inputs):
        return self.model(inputs)
    
    def train(self,inputs,targets):
        outputs=self.forward(inputs)
        loss=self.loss_function(outputs,targets)

        self.counter+=1
        if (self.counter%10==0):
            self.progress.append(loss.item())
            pass
        if (self.counter%10000==0):
            print("counter=",self.counter)
            pass

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        df=pandas.DataFrame(self.progress,columns=['loss'])
        df.plot(ylim=(0,1.0),figsize=(16,8),alpha=0.1,marker='.',grid=True,yticks=(0,0.25,0.5))
        pass

def generate_random(size):
    random_data=torch.rand(size)
    return random_data

D=Discriminator()

for i in range(10000):
    D.train(generate_real(),torch.FloatTensor([1.0]))
    D.train(generate_random(4),torch.FloatTensor([0.0]))
    pass
D.plot_progress()

print(D.forward(generate_real()).item())
print(D.forward(generate_random(4)).item())

First we build a discriminator. It is written much like the earlier handwritten-digit classifier, which makes sense, since a discriminator is also just a classifier. After training it, the first of the last two lines should print a float close to 1 and the second a float close to 0 (although my first run produced values around 0.5).
The generator:

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(nn.Linear(1,3),nn.Sigmoid(),nn.Linear(3,4),nn.Sigmoid())
        self.optimiser=torch.optim.SGD(self.parameters(),lr=0.01)

        self.counter=0
        self.progress=[]

        pass

    def forward(self,inputs):
        return self.model(inputs)

    def train(self,D,inputs,targets):
        g_output=self.forward(inputs)

        d_output=D.forward(g_output)

        loss=D.loss_function(d_output,targets)
        self.counter+=1
        if (self.counter%10==0):
            self.progress.append(loss.item())
            pass

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        pass

G=Generator()
G.forward(torch.FloatTensor([0.5]))

The output here is a tensor with four values, but they are random rather than following the 1010 pattern, because the generator has not been trained yet (e.g. tensor([0.6035, 0.5250, 0.5049, 0.3661])).
Training the GAN:

for i in range(10000):

    D.train(generate_real(),torch.FloatTensor([1.0]))

    D.train(G.forward(torch.FloatTensor([0.5])).detach(),torch.FloatTensor([0.0]))

    G.train(D,torch.FloatTensor([0.5]),torch.FloatTensor([1.0]))

    pass

D.plot_progress()
G.plot_progress()
G.forward(torch.FloatTensor([0.5]))

After this training, the network generates a tensor that follows the 1010 pattern (e.g. tensor([0.9791, 0.0173, 0.9794, 0.0200])).
To visualise how the generator's outputs evolve, the book records snapshots during training and plots them afterwards:

import numpy

# add this before the training loop:
image_list=[]
# and this inside the loop, to take a snapshot every 1000 iterations:
#     if (i%1000==0):
#         image_list.append(G.forward(torch.FloatTensor([0.5])).detach().numpy())

# after training, plot the snapshots (each column shows the four output values at one snapshot)
plt.figure(figsize=(16,8))
plt.imshow(numpy.array(image_list).T,interpolation='none',cmap='Blues')

2.3 Generating Handwritten Digits

import torch
import torch.nn as nn
from torch.utils.data import Dataset

import pandas, numpy, random
import matplotlib.pyplot as plt


class MnistDataset(Dataset):

    def __init__(self,csv_file):
        self.data_df=pandas.read_csv(csv_file,header=None)
        pass

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self,index):
        label=self.data_df.iloc[index,0]
        target=torch.zeros((10))
        target[label]=1.0

        image_values=torch.FloatTensor(self.data_df.iloc[index,1:].values)/255.0

        return label,image_values,target
    
    def plot_image(self,index):
        arr=self.data_df.iloc[index,1:].values.reshape(28,28)
        plt.title("label="+str(self.data_df.iloc[index,0]))
        plt.imshow(arr,interpolation='none',cmap='Blues')
        pass

mnist_dataset=MnistDataset('E:/GAN/mnist_train.csv')
mnist_dataset.plot_image(9)

First the data is loaded, just as in the earlier handwritten-digit classifier.

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(nn.Linear(784,200),nn.Sigmoid(),nn.Linear(200,1),nn.Sigmoid())
        self.loss_function=nn.MSELoss()
        self.optimiser=torch.optim.SGD(self.parameters(),lr=0.01)
        self.counter=0
        self.progress=[]
        pass

    def forward(self,inputs):
        return self.model(inputs)
    
    def train(self,inputs,targets):
        outputs=self.forward(inputs)
        loss=self.loss_function(outputs,targets)

        self.counter+=1
        if (self.counter%10==0):
            self.progress.append(loss.item())
            pass
        if (self.counter%10000==0):
            print("counter=",self.counter)
            pass

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        df=pandas.DataFrame(self.progress,columns=['loss'])
        df.plot(ylim=(0,1.0),figsize=(16,8),alpha=0.1,marker='.',grid=True,yticks=(0,0.25,0.5))
        pass

D=Discriminator()

Next the discriminator is built. It is essentially the same as the one for the 1010 pattern; only the layer sizes change.

def generate_random(size):
    random_data=torch.rand(size)
    return random_data

for label, image_data_tensor, target_tensor in mnist_dataset:
    D.train(image_data_tensor,torch.FloatTensor([1.0]))
    D.train(generate_random(784),torch.FloatTensor([0.0]))
    pass

for i in range(4):
    image_data_tensor=mnist_dataset[random.randint(0,59999)][1]
    print(D.forward(image_data_tensor).item())
    pass

for i in range(4):
    print(D.forward(generate_random(784)).item())
    pass

Testing the discriminator: the first four numbers should be close to 1 (e.g. 0.995710015296936, 0.9972994923591614, 0.9938621520996094, 0.997092604637146), and the last four close to 0 (e.g. 0.005710733123123646, 0.006089874543249607, 0.004295888356864452, 0.0055423928424716). This shows the discriminator is working.

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(nn.Linear(1,200),nn.Sigmoid(),nn.Linear(200,784),nn.Sigmoid())
        self.optimiser=torch.optim.SGD(self.parameters(),lr=0.01)

        self.counter=0
        self.progress=[]

        pass

    def forward(self,inputs):
        return self.model(inputs)

    def train(self,D,inputs,targets):
        g_output=self.forward(inputs)

        d_output=D.forward(g_output)

        loss=D.loss_function(d_output,targets)
        self.counter+=1
        if (self.counter%10==0):
            self.progress.append(loss.item())
            pass

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        pass

    def plot_progress(self):
        df=pandas.DataFrame(self.progress,columns=['loss'])
        df.plot(ylim=(0,1.0),figsize=(16,8),alpha=0.1,marker='.',grid=True,yticks=(0,0.25,0.5))
        pass

Build the generator; again it is similar to the 1010 version, with only the network sizes changed.

G=Generator()
output=G.forward(generate_random(1))
img=output.detach().numpy().reshape(28,28)
plt.imshow(img,interpolation='none',cmap='Blues')

Here we check the generator's output; the output tensor has 784 values.

for label, image_data_tensor, target_tensor in mnist_dataset:
    D.train(image_data_tensor,torch.FloatTensor([1.0]))

    D.train(G.forward(generate_random(1)).detach(),torch.FloatTensor([0.0]))

    G.train(D,generate_random(1),torch.FloatTensor([1.0]))
    pass

Train the GAN; this takes noticeably longer.

f,axarr=plt.subplots(2,3,figsize=(16,8))
for i in range(2):
    for j in range(3):
        output=G.forward(generate_random(1))
        img=output.detach().numpy().reshape(28,28)
        axarr[i,j].imshow(img,interpolation='none',cmap='Blues')
        pass
    pass

Unfortunately, the images generated here all look like the same digit. The cause is mode collapse: put simply, the generator fails to produce diverse outputs and instead keeps producing a single kind of output. There is still no good theoretical explanation for mode collapse.
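
A crude check of my own for mode collapse (not from the book): draw several samples from the generator and compare them; if all pairwise distances are close to zero, the generator is effectively producing one output.

# near-zero pairwise distances mean the samples are nearly identical
samples=[G.forward(generate_random(1)).detach() for _ in range(5)]
for i in range(len(samples)):
    for j in range(i+1,len(samples)):
        print(i,j,torch.dist(samples[i],samples[j]).item())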

#improvement: the loss function
self.loss_function=nn.BCELoss()

#improvement: activation function plus layer normalisation (discriminator)
self.model=nn.Sequential(nn.Linear(784,200),nn.LeakyReLU(0.02),nn.LayerNorm(200),nn.Linear(200,1),nn.Sigmoid())
#improvement: activation function plus layer normalisation (generator)
self.model=nn.Sequential(nn.Linear(1,200),nn.LeakyReLU(0.02),nn.LayerNorm(200),nn.Linear(200,784),nn.Sigmoid())

#improvement: optimiser (same for discriminator and generator)
self.optimiser=torch.optim.Adam(self.parameters(),lr=0.0001)

#give the generator a 100-value seed as input
self.model=nn.Sequential(nn.Linear(100,200),nn.LeakyReLU(0.02),nn.LayerNorm(200),nn.Linear(200,784),nn.Sigmoid())

#up to this point, mode collapse is still not solved

#improvement: the random seed functions and the training loop
def generate_random_image(size):
    random_data=torch.rand(size)
    return random_data

def generate_random_seed(size):
    random_data=torch.randn(size)
    return random_data

for label, image_data_tensor, target_tensor in mnist_dataset:
    D.train(image_data_tensor,torch.FloatTensor([1.0]))
    D.train(G.forward(generate_random_seed(100)).detach(),torch.FloatTensor([0.0]))

    G.train(D,generate_random_seed(100),torch.FloatTensor([1.0]))
    pass

#at this point the first signs of beating mode collapse finally appear

#improved loss-plotting function (same for discriminator and generator)
def plot_progress(self):
    df=pandas.DataFrame(self.progress,columns=['loss'])
    df.plot(ylim=(0),figsize=(16,8),alpha=0.1,marker='.',grid=True,yticks=(0,0.25,0.5,1.0,5.0))
    pass

This part applies the refinements to the handwritten-digit GAN and resolves the mode collapse problem. The seed experiments that follow explore the generator's seed space further.
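
As a side note on the seed change above (my own illustration): torch.rand draws uniformly from [0, 1), while torch.randn draws from a standard normal distribution.

u=torch.rand(10000)    # uniform on [0, 1)
n=torch.randn(10000)   # standard normal, mean 0, std 1
print(u.mean().item(),u.std().item())   # roughly 0.5 and 0.29
print(n.mean().item(),n.std().item())   # roughly 0.0 and 1.0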

seed1=generate_random_seed(100)
out1=G.forward(seed1)
img1=out1.detach().numpy().reshape(28,28)
plt.imshow(img1,interpolation='none',cmap='Blues')

seed2=generate_random_seed(100)
out2=G.forward(seed2)
img2=out2.detach().numpy().reshape(28,28)
plt.imshow(img2,interpolation='none',cmap='Blues')

count=0
f,axarr=plt.subplots(3,4,figsize=(16,8))
for i in range(3):
    for j in range(4):
        # 12 interpolation steps between seed1 and seed2, matching the /11 step size
        seed=seed1+(seed2-seed1)/11*count
        output=G.forward(seed)
        img=output.detach().numpy().reshape(28,28)
        axarr[i,j].imshow(img,interpolation='none',cmap='Blues')
        count=count+1
        pass
    pass

seed3=seed1+seed2
out3=G.forward(seed3)
img3=out3.detach().numpy().reshape(28,28)
plt.imshow(img3,interpolation='none',cmap='Blues')

seed4=seed1-seed2
out4=G.forward(seed4)
img4=out4.detach().numpy().reshape(28,28)
plt.imshow(img4,interpolation='none',cmap='Blues')

The seed experiments show that smoothly interpolating between two generator seeds produces smoothly interpolated images. Adding two seeds appears to correspond to an additive combination of image features, whereas subtracting seeds produces images that do not follow any intuitive rule.

2.4 Generating Face Images


3. Convolutional GANs and Conditional GANs

3.1卷积


3.2 Conditional GAN

This section again uses the handwritten-digit data and modifies the earlier handwritten-digit GAN code; I will mainly point out what changes.

import torch
import torch.nn as nn
from torch.utils.data import Dataset

import pandas, numpy, random
import matplotlib.pyplot as plt

First, the usual library imports.

class MnistDataset(Dataset):
    
    def __init__(self, csv_file):
        self.data_df = pandas.read_csv(csv_file, header=None)
        pass
    
    def __len__(self):
        return len(self.data_df)
    
    def __getitem__(self, index):
        # image target (label)
        label = self.data_df.iloc[index,0]
        target = torch.zeros((10))
        target[label] = 1.0
        
        # image data, normalised from 0-255 to 0-1
        image_values = torch.FloatTensor(self.data_df.iloc[index,1:].values) / 255.0
        
        # return label, image data tensor and target tensor
        return label, image_values, target
    
    def plot_image(self, index):
        img = self.data_df.iloc[index,1:].values.reshape(28,28)
        plt.title("label = " + str(self.data_df.iloc[index,0]))
        plt.imshow(img, interpolation='none', cmap='Blues')
        pass
    
    pass
mnist_dataset = MnistDataset('./mnist_train.csv')
mnist_dataset.plot_image(17)

Loading the data and viewing an example, the same as before.

def generate_random_image(size):
    random_data = torch.rand(size)
    return random_data


def generate_random_seed(size):
    random_data = torch.randn(size)
    return random_data

# size here must only be an integer
def generate_random_one_hot(size):
    label_tensor = torch.zeros((size))
    random_idx = random.randint(0,size-1)
    label_tensor[random_idx] = 1.0
    return label_tensor

These are the helper functions the networks will use. The new one is the last function, generate_random_one_hot, which supplies the conditional GAN's label to the network.
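
For example, a quick check of my own shows the kind of tensor it returns: a length-10 vector with a single 1.0 in a random position.

print(generate_random_one_hot(10))
# e.g. tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])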

class Discriminator(nn.Module):
    
    def __init__(self):
        # initialise parent pytorch class
        super().__init__()
        
        # define neural network layers
        self.model = nn.Sequential(
            nn.Linear(784+10, 200),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(200),

            nn.Linear(200, 1),
            nn.Sigmoid()
        )
        
        # create loss function
        self.loss_function = nn.BCELoss()

        # create optimiser, simple stochastic gradient descent
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # counter and accumulator for progress
        self.counter = 0;
        self.progress = []

        pass
    
    
    def forward(self, image_tensor, label_tensor):
        # combine seed and label
        inputs = torch.cat((image_tensor, label_tensor))
        return self.model(inputs)
    
    
    def train(self, inputs, label_tensor, targets):
        # calculate the output of the network
        outputs = self.forward(inputs, label_tensor)
        
        # calculate loss
        loss = self.loss_function(outputs, targets)

        # increase counter and accumulate error every 10
        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass
        if (self.counter % 10000 == 0):
            print("counter = ", self.counter)
            pass

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        pass
    
    
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
        pass
    
    pass

Compared with the earlier discriminator, the first layer's input size becomes 784+10, because the condition, a size-10 one-hot tensor, is appended to the image; the forward and train functions change only in taking the extra label_tensor argument.
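
To make the 784+10 input concrete, a small check of my own on what torch.cat produces here:

# the conditional discriminator sees the image and its one-hot label as a single 794-value vector
img=torch.rand(784)
lbl=generate_random_one_hot(10)
print(torch.cat((img,lbl)).shape)   # torch.Size([794])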

%%time
# test discriminator can separate real data from random noise

D = Discriminator()

for label, image_data_tensor, label_tensor in mnist_dataset:
    # real data
    D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))
    # fake data
    D.train(generate_random_image(784), generate_random_one_hot(10), torch.FloatTensor([0.0]))
    pass
    
D.plot_progress()

for i in range(4):
  label, image_data_tensor, label_tensor = mnist_dataset[random.randint(0,59999)]
  print( D.forward( image_data_tensor, label_tensor ).item() )
  pass

for i in range(4):
  print( D.forward( generate_random_image(784), generate_random_one_hot(10) ).item() )
  pass

Train the discriminator, plot its loss curve, and spot-check whether it can tell real data from random noise.

class Generator(nn.Module):
    
    def __init__(self):
        # initialise parent pytorch class
        super().__init__()
        
        # define neural network layers
        self.model = nn.Sequential(
            nn.Linear(100+10, 200),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(200),

            nn.Linear(200, 784),
            nn.Sigmoid()
        )
        
        # create optimiser, simple stochastic gradient descent
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # counter and accumulator for progress
        self.counter = 0;
        self.progress = []
        
        pass
    
    
    def forward(self, seed_tensor, label_tensor):        
        # combine seed and label
        inputs = torch.cat((seed_tensor, label_tensor))
        return self.model(inputs)


    def train(self, D, inputs, label_tensor, targets):
        # calculate the output of the network
        g_output = self.forward(inputs, label_tensor)
        
        # pass onto Discriminator
        d_output = D.forward(g_output, label_tensor)
        
        # calculate error
        loss = D.loss_function(d_output, targets)

        # increase counter and accumulate error every 10
        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        pass
    
    def plot_images(self, label):
        label_tensor = torch.zeros((10))
        label_tensor[label] = 1.0
        # plot a 3 column, 2 row array of sample images
        f, axarr = plt.subplots(2,3, figsize=(16,8))
        for i in range(2):
            for j in range(3):
                axarr[i,j].imshow(G.forward(generate_random_seed(100), label_tensor).detach().cpu().numpy().reshape(28,28), interpolation='none', cmap='Blues')
                pass
            pass
        pass
    
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
        pass
    
    pass

G = Generator()

output = G.forward(generate_random_seed(100), generate_random_one_hot(10))

img = output.detach().numpy().reshape(28,28)

plt.imshow(img, interpolation='none', cmap='Blues')

Build the generator. As with the discriminator, the condition is added to the first layer's input, so the input size becomes 100+10, and the forward and train functions take the extra label_tensor. A random example checks that the generator runs.

D = Discriminator()
G = Generator()

%%time 

# train Discriminator and Generator

epochs = 12

for epoch in range(epochs):
  print ("epoch = ", epoch + 1)

  # train Discriminator and Generator

  for label, image_data_tensor, label_tensor in mnist_dataset:
    # train discriminator on true
    D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))

    # random 1-hot label for generator
    random_label = generate_random_one_hot(10)
    
    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(generate_random_seed(100), random_label).detach(), random_label, torch.FloatTensor([0.0]))
    
    # different random 1-hot label for generator
    random_label = generate_random_one_hot(10)

    # train generator
    G.train(D, generate_random_seed(100), random_label, torch.FloatTensor([1.0]))

    pass
    
  pass

D.plot_progress()
G.plot_progress()

Instantiate and train the discriminator and generator; after 12 epochs, plot the loss curves of both.

G.plot_images(9)

Finally, supply a condition and check whether the generator can produce a handwritten digit of the requested class.
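
For a single image of a chosen digit, a small sketch of my own (rather than the book's plot_images helper):

digit=3                          # any class label from 0 to 9
label_tensor=torch.zeros((10))
label_tensor[digit]=1.0
output=G.forward(generate_random_seed(100),label_tensor)
plt.imshow(output.detach().numpy().reshape(28,28),interpolation='none',cmap='Blues')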

3.3 Closing Remarks



Summary

The author's full Jupyter notebook code and results are available at https://github.com/makeyourownneuralnetwork/gan. The notes above cover only the parts I have tested and the points worth watching out for; my hope is to put GANs to use in some practical work. I have not yet tested the remaining chapters, since I currently have no need for the colour-image material, so the face-image generation part is untested, but I will keep working through it and update these notes.
