DCGAN网络的基本实现(Mnist数字集)

数据集为Mnist手写数字数据集

数据集创建和加载

对图片进行矩阵转换同时正则化

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])
train_ds = torchvision.datasets.MNIST("mnist", train=True, transform=transform)
dataloader = DataLoader(train_ds,batch_size=32, shuffle=True)

定义生成器

使用长度为100的noise作为输入,
首先通过linear1 将长度扩大为256 * 7 * 7的向量,同时reshape成[-1, 256, 7, 7]
通过反卷积1,缩小channels值=>[-1, 128, 7, 7]
通过反卷积2,缩小channel值=>[-1, 64, 14, 14]
通过反卷积3,缩小channel值=>[-1, 1, 28, 28]
得到28 * 28的向量矩阵,就和数据集图片像素大小一样
每一层之后添加BatchNorm层进行正则化
对于反卷积以及卷积参数的设定可以参考
卷积和反卷积核计算公式

class Generate(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(100,  256 * 7 * 7 ) 
        self.bn1 = nn.BatchNorm1d(256 * 7 * 7)
        # (128, 7, 7) 
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # (64, 14, 14) 
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        # (1, 28, 28)
        self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1)
    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.bn1(x)
        x = torch.reshape(x, (-1, 256, 7, 7))
        x = F.relu(self.deconv1(x))
        x = self.bn2(x)
        x = F.relu(self.deconv2(x))
        x = self.bn3(x)
        x = F.tanh(self.deconv3(x))
        return x     

定义判别器

目的是生成[1]向量进行判别
输入为[-1, 1, 28, 28]的像素矩阵
首先通过卷积层扩大channels值缩小h,w
卷积层1:=>[64, 13, 13]
卷积层2:=>[128, 6, 6]
最后通过线性层
激活函数为sigmoid函数
同时,加入dropout函数 随机丢弃一部分神经元 防止判别器过强

class Discriminator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # [64, 13, 13] (28-3+1)/2 = 13
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2)
        # [128, 6, 6]
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128 * 6 * 6,1)
    def forward(self, x):
        x = F.dropout2d(F.leaky_relu(self.conv1(x)))
        x = F.dropout2d(F.leaky_relu(self.conv2(x)))
        x = self.bn(x)
        fla = nn.Flatten()
        x = fla(x)
        x = torch.sigmoid(self.fc(x))
        return x

创建模型对象,生成器的学习速率比判别器要大,放置判别器过强
采用BCELoss的损失函数,虽然和CrossEntropyLoss都属于交叉熵损失函数,BCE只用于二分类问题,而CrossEntropyLoss可以用于二分类,也可以用于多分类

gen = Generate()
dis = Discriminator()
gen_optim = optim.Adam(gen.parameters(), lr = 3e-4)
dis_optim = optim.Adam(dis.parameters(), lr = 3e-5)
loss = nn.BCELoss()
write = SummaryWriter('DCGAN_log')
test_input = torch.randn(16, 100)

画图函数

def generate_and_save_images(model, epoch, test_input):
    predictions = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow((predictions[i] + 1)/2, cmap='gray')
        plt.axis('off')
    plt.savefig(f'./train_mnist/image_at_epoch_{format(epoch)}.png')
    plt.show()

训练数据集代码

判别器优化

img_dis 是训练集通过分类器得到数据向量[64, 1] 批次大小batch_size设置的是64
通过与1进行损失计算;
z_gen随机噪声先通过生成器生成28 * 28的图片数据
z_dis是z_gen通过dis判别器得到的数据,与0进行损失计算

dis_loss = 0

        dis_optim.zero_grad()
        img,_ = item
        img_dis = dis(img)
        img_dis_loss = loss(img_dis, torch.ones_like(img_dis))
        img_dis_loss.backward()

        z = torch.randn((64, 100))
        z_gen = gen(z)
        z_dis = dis(z_gen.detach())
        z_dis_loss = loss(z_dis, torch.zeros_like(z_dis))
        z_dis_loss.backward()
        dis_optim.step()

        dis_loss = z_dis_loss + img_dis_loss

生成器优化

通过之前z生成的图片,直接与1做损失计算即可

 gen_optim.zero_grad()
        z_out = dis(z_gen)
        z_gen_loss = loss(z_out, torch.ones_like(z_out))
        z_gen_loss.backward()
        gen_optim.step()
        train_step += 1

训练代码,通过tensorboard观察损失变化

for epoch in range(100):

    train_step = 0

    for item in real_dataload:
        dis_loss = 0

        dis_optim.zero_grad()
        img,_ = item
        img_dis = dis(img)
        img_dis_loss = loss(img_dis, torch.ones_like(img_dis))
        img_dis_loss.backward()

        z = torch.randn((64, 100))
        z_gen = gen(z)
        z_dis = dis(z_gen.detach())
        z_dis_loss = loss(z_dis, torch.zeros_like(z_dis))
        z_dis_loss.backward()
        dis_optim.step()

        dis_loss = z_dis_loss + img_dis_loss

        gen_optim.zero_grad()
        z_out = dis(z_gen)
        z_gen_loss = loss(z_out, torch.ones_like(z_out))
        z_gen_loss.backward()
        gen_optim.step()
        train_step += 1
        write.add_scalar("分类器", dis_loss, train_step)
        write.add_scalar("生产器", z_gen_loss, train_step)
    generate_and_save_images(gen, epoch, test_input)

    if epoch > 20 and epoch % 20 == 0:
        torch.save(gen, f"gen_method_{epoch}.pth")
        torch.save(dis, f"dis_method_{epoch}.pth")
write.close()

以下为训练结果:

可以观察到通过卷积网络之后,提取特征效果明显
epoch1:在这里插入图片描述

epoch 100:在这里插入图片描述

epoch200:在这里插入图片描述epoch300:
在这里插入图片描述损失变化:
在这里插入图片描述

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

鲨鱼狂飙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值