GAN网络简单应用——MNIST数据集

 import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
"对数据做归一化处理"
transform = transforms.Compose([
    transforms.ToTensor(),                    #0-1  channel,high,wide
    transforms.Normalize(0.5,0.5)   #归一化处理,MNIST图片灰度级在0~255,先每个灰度值除以255,像素值范围缩放到 [0, 1] 区间,在output = (input - mean) / std
])
"实例 是指一个类(Class)的具体对象,这个实例具有类定义的属性和方法,可以被操作和使用"
train_ds= torchvision.datasets.MNIST(r'C:\Users\Administrator\Desktop\图像生成代码\GAN网络简单应用\data',train=True,transform=transform,download=False)
dataloader = DataLoader(train_ds,batch_size=64,shuffle=True)

"random_noise[batch_size,100]>>gen_img[batch_size,1,28,28]"
class generator(nn.Module):
    def __init__(self):
        super(generator,self).__init__()
        self.main = nn.Sequential(
                                    nn.Linear(100,256),
                                    nn.ReLU(),
                                    nn.Linear(256,512),
                                    nn.ReLU(),
                                    nn.Linear(512,784),
                                    nn.Tanh()    #[exp(x)-exp(-x)]/[exp(x)+exp(-x)]
        )

    def forward(self,x):
        img = self.main(x)
        img = img.view(-1,1,28,28)
        return img

#"输入为(1,28,28)的图片,输出为二分类的概率值,输出使用sigmoid激活0-1"
#"BCEloss计算交叉熵损失"
#"nn.LeakyReLU  f(x):x>0输出0,如果x<0,输出a*x  a表示一个很小的斜率 ,比如0,001"

"[batch_size,784]>>[batch_size,1], 每一个value属于0~1之间"
class Discriminator(nn.Module):
    def __init__(self):
        # 调用父类(nn.Module)的构造函数
        super(Discriminator,self).__init__()

        self.main = nn.Sequential(
                                nn.Linear(784, 512),
                                nn.LeakyReLU(),
                                nn.Linear(512,256),
                                nn.LeakyReLU(),
                                nn.Linear(256,1),
                                nn.Sigmoid()   #1/[1+exp(-x)]
        )

    def forward(self,x):
        x=x.view(-1,784)
        x= self.main(x)
        return x

"[b,100]>>[b,1,28,28]>>[b,28,28],并且显示图"
def gen_img_plot(model, test_input):
    # 使用生成器模型获取预测结果,并将其转换为NumPy数组
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())

    # 创建一个绘图窗口,大小为(16, 16)
    plt.figure(figsize=(16, 16))

    # 循环遍历每个生成的图像,并在子图中显示
    for i in range(prediction.shape[0]):
        plt.subplot(4, 4, i + 1)
        # 将图像的像素值范围从[-1, 1]转换为[0, 1],并绘制图像
        plt.imshow((prediction[i] + 1) / 2)
        plt.axis('off')  # 关闭坐标轴显示

    # 显示绘图结果
    plt.show()


# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Fixed noise batch (16 samples) used to visualize generator progress each epoch.
test_input = torch.randn([16, 100], device=device)


# Instantiate both networks on the chosen device.
gen = generator().to(device)
dis = Discriminator().to(device)
# Independent Adam optimizers for discriminator and generator, lr = 1e-4 each.
d_optim = torch.optim.Adam(dis.parameters(), 0.0001)
g_optim = torch.optim.Adam(gen.parameters(), 0.0001)
# Binary cross-entropy, paired with the discriminator's sigmoid output.
loss_fn = torch.nn.BCELoss()
# Per-epoch average losses, appended to by the training loop below.
D_loss = []
G_loss = []

def loss_show(D_loss, G_loss):
    """Plot discriminator and generator loss curves over training steps.

    Args:
        D_loss: sequence of discriminator losses, one value per epoch/step.
        G_loss: sequence of generator losses; must be the same length as D_loss.

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    if len(D_loss) != len(G_loss):
        # Raise instead of exit(): exit() would kill the whole interpreter and
        # gives the caller no chance to recover or log the problem.
        raise ValueError("Warning: Lengths of D_loss and G_loss are not equal.")
    step = len(D_loss)

    plt.figure(figsize=(8, 8))
    plt.plot(range(0, step), D_loss, label='Discriminator Loss', color='red')
    plt.plot(range(0, step), G_loss, label='Generator Loss', color='blue')
    plt.legend(['Discriminator Loss', 'Generator Loss'])
    plt.xlabel('step', fontsize=14)
    plt.ylabel('loss value', fontsize=14)
    # Annotate every point with its value, matching each curve's line color
    # (the generator labels were previously drawn in red, unlike its blue curve).
    for i, (x, y_d, y_g) in enumerate(zip(range(0, step), D_loss, G_loss)):
        plt.text(x, y_d, f'({y_d:.2f})', fontsize=8, color='red', ha='right', va='bottom')
        plt.text(x, y_g, f'({y_g:.3f})', fontsize=8, color='blue', ha='right', va='bottom')
    plt.title('Discriminator and Generator Loss Over Steps')
    plt.show()

# Training loop: alternate one discriminator update and one generator update per batch.
for epoch in range(20):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)  # number of batches per epoch (60000/64 -> 938 batches)

    # Iterate over every batch; labels are unused (GAN training is unsupervised).
    for step, (img, _) in enumerate(dataloader):


        img = img.to(device)
        batch_size = img.size(0)
        random_noise = torch.randn([batch_size, 100], device=device)

        # --- Discriminator update ---
        d_optim.zero_grad()  # clear gradients accumulated by the previous step
        real_output = dis(img)  # discriminator scores for real images, shape [batch_size, 1]
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # real images should be scored as real (target = 1)
        d_real_loss.backward()

        gen_img = gen(random_noise)  # random_noise [batch_size, 100] -> gen_img [batch_size, 1, 28, 28]
        fake_output = dis(gen_img.detach())  # detach() blocks gradients from flowing into the generator during the D update
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # fake images should be scored as fake (target = 0)
        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss
        d_optim.step()

        # --- Generator update ---
        g_optim.zero_grad()
        fake_output = dis(gen_img)  # re-score the fakes with the just-updated discriminator; no detach, so gradients reach the generator
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # generator wants its images classified as real (target = 1)

        g_loss.backward()
        g_optim.step()

        # Accumulate epoch totals (no_grad: bookkeeping only, no autograd needed)
        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss

    with torch.no_grad():
        # Average the totals over the number of batches, then record as numpy scalars.
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss.detach().cpu().numpy())
        G_loss.append(g_epoch_loss.detach().cpu().numpy())
        print(f'Epoch: {epoch},D_loss: {D_loss[epoch].item():.3f},G_loss: {G_loss[epoch].item():.3f}')
        # Visualize the generator's output on the fixed noise batch after each epoch.
        gen_img_plot(gen, test_input)

loss_show(D_loss,G_loss)

 

  • 9
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值