数据集为Mnist手写数字数据集
数据集创建和加载
对图片进行矩阵转换同时正则化
# Preprocessing: PIL image -> tensor in [0, 1], then Normalize(0.5, 0.5)
# maps it to [-1, 1] so real images match the generator's tanh output range.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
# download=True makes the script runnable on a machine without the data;
# it is a no-op when ./mnist already contains the dataset.
train_ds = torchvision.datasets.MNIST("mnist", train=True, transform=transform, download=True)
dataloader = DataLoader(train_ds, batch_size=32, shuffle=True)
定义生成器
使用长度为100的noise作为输入,
首先通过linear1 将长度扩大为256 * 7 * 7的向量,同时reshape成[-1, 256, 7, 7]
通过反卷积1,缩小channels值=>[-1, 128, 7, 7]
通过反卷积2,缩小channel值=>[-1, 64, 14, 14]
通过反卷积3,缩小channel值=>[-1, 1, 28, 28]
得到28 * 28的向量矩阵,就和数据集图片像素大小一样
每一层之后添加BatchNorm层进行正则化
对于反卷积以及卷积参数的设定可以参考
卷积和反卷积核计算公式
class Generate(nn.Module):
    """DCGAN generator: maps a length-100 noise vector to a 1x28x28 image.

    Pipeline: Linear(100 -> 256*7*7), reshape to (-1, 256, 7, 7), then three
    transposed convolutions upsampling to (-1, 1, 28, 28).  BatchNorm follows
    each ReLU (NOTE(review): classic DCGAN puts BN before the activation;
    kept as-is to preserve the original design).  The final tanh puts outputs
    in [-1, 1], matching the Normalize(0.5, 0.5) preprocessing of real images.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(100, 256 * 7 * 7)
        self.bn1 = nn.BatchNorm1d(256 * 7 * 7)
        # (256, 7, 7) -> (128, 7, 7): spatial size unchanged (k=3, s=1, p=1)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # (128, 7, 7) -> (64, 14, 14): doubles H and W (k=4, s=2, p=1)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        # (64, 14, 14) -> (1, 28, 28): doubles H and W again
        self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """x: (batch, 100) noise -> (batch, 1, 28, 28) tensor in [-1, 1]."""
        x = F.relu(self.linear1(x))
        x = self.bn1(x)
        x = torch.reshape(x, (-1, 256, 7, 7))
        x = F.relu(self.deconv1(x))
        x = self.bn2(x)
        x = F.relu(self.deconv2(x))
        x = self.bn3(x)
        # torch.tanh replaces the deprecated F.tanh
        x = torch.tanh(self.deconv3(x))
        return x
定义判别器
目的是生成[1]向量进行判别
输入为[-1, 1, 28, 28]的像素矩阵
首先通过卷积层扩大channels值缩小h,w
卷积层1:=>[64, 13, 13]
卷积层2:=>[128, 6, 6]
最后通过线性层
激活函数为sigmoid函数
同时,加入dropout函数 随机丢弃一部分神经元 防止判别器过强
class Discriminator(nn.Module):
    """DCGAN discriminator: (batch, 1, 28, 28) image -> (batch, 1) real-probability.

    Two strided convolutions shrink H/W while growing channels, followed by a
    linear layer and sigmoid.  Dropout after each conv keeps the discriminator
    from overpowering the generator.
    """

    def __init__(self) -> None:
        super().__init__()
        # (1, 28, 28) -> (64, 13, 13): floor((28 - 3) / 2) + 1 = 13
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2)
        # (64, 13, 13) -> (128, 6, 6)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128 * 6 * 6, 1)

    def forward(self, x):
        """x: (batch, 1, 28, 28) -> (batch, 1) sigmoid scores in [0, 1]."""
        # BUG FIX: F.dropout2d defaults to training=True, which kept dropout
        # active even in eval mode; tie it to self.training instead.
        x = F.dropout2d(F.leaky_relu(self.conv1(x)), training=self.training)
        x = F.dropout2d(F.leaky_relu(self.conv2(x)), training=self.training)
        x = self.bn(x)
        # flatten directly instead of allocating a new nn.Flatten per call
        x = torch.flatten(x, 1)
        x = torch.sigmoid(self.fc(x))
        return x
创建模型对象,生成器的学习速率比判别器要大,防止判别器过强
采用BCELoss的损失函数,虽然和CrossEntropyLoss都属于交叉熵损失函数,BCE只用于二分类问题,而CrossEntropyLoss可以用于二分类,也可以用于多分类
# Instantiate both networks and their optimizers.
# The generator's learning rate (3e-4) is 10x the discriminator's (3e-5),
# to keep the discriminator from overpowering the generator early on.
gen = Generate()
dis = Discriminator()
gen_optim = optim.Adam(gen.parameters(), lr = 3e-4)
dis_optim = optim.Adam(dis.parameters(), lr = 3e-5)
# BCELoss: binary cross-entropy, matching the discriminator's sigmoid output
loss = nn.BCELoss()
# tensorboard writer; logs are written under ./DCGAN_log
write = SummaryWriter('DCGAN_log')
# fixed noise batch (16 samples) reused every epoch so the generated
# sample grids are visually comparable across epochs
test_input = torch.randn(16, 100)
画图函数
def generate_and_save_images(model, epoch, test_input):
    """Render a 4x4 grid of generated images and save it to disk.

    model: generator mapping (N, 100) noise to (N, 1, 28, 28) in [-1, 1].
    epoch: used only in the output file name.
    test_input: fixed noise batch (16 samples here, filling the 4x4 grid).
    """
    # no_grad: inference only — skip building the autograd graph
    with torch.no_grad():
        predictions = np.squeeze(model(test_input).cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # map tanh output [-1, 1] back to [0, 1] for display
        plt.imshow((predictions[i] + 1) / 2, cmap='gray')
        plt.axis('off')
    # the original's {format(epoch)} was a redundant str() call producing
    # the exact same file name
    plt.savefig(f'./train_mnist/image_at_epoch_{epoch}.png')
    plt.show()
    # close the figure so 100 epochs of plotting don't accumulate memory
    plt.close(fig)
训练数据集代码
判别器优化
img_dis 是训练集图片通过判别器得到的数据向量 [32, 1](DataLoader 的 batch_size 设置的是 32;注意下面噪声 z 的批次写死为 64,与之并不一致)
通过与1进行损失计算;
z_gen随机噪声先通过生成器生成28 * 28的图片数据
z_dis是z_gen通过dis判别器得到的数据,与0进行损失计算
# --- discriminator step (excerpt of the training loop below) ---
dis_loss = 0
dis_optim.zero_grad()
# unpack one batch; the labels are unused for GAN training
img,_ = item
# real images: target is all-ones ("real")
img_dis = dis(img)
img_dis_loss = loss(img_dis, torch.ones_like(img_dis))
img_dis_loss.backward()
# noise -> generator -> fake image batch
z = torch.randn((64, 100))
z_gen = gen(z)
# detach so this backward pass does not update the generator's weights
z_dis = dis(z_gen.detach())
# fake images: target is all-zeros ("fake")
z_dis_loss = loss(z_dis, torch.zeros_like(z_dis))
z_dis_loss.backward()
dis_optim.step()
# total discriminator loss = fake-branch loss + real-branch loss
dis_loss = z_dis_loss + img_dis_loss
生成器优化
通过之前z生成的图片,直接与1做损失计算即可
# --- generator step (excerpt of the training loop below) ---
gen_optim.zero_grad()
# run the already-generated fake batch through the discriminator again,
# this time without detach so gradients flow back to the generator
z_out = dis(z_gen)
# the generator wants the discriminator to answer "real" (all-ones target)
z_gen_loss = loss(z_out, torch.ones_like(z_out))
z_gen_loss.backward()
gen_optim.step()
train_step += 1
训练代码,通过tensorboard观察损失变化
# Training loop; per-step losses are streamed to tensorboard via `write`.
# Counter lives OUTSIDE the epoch loop: the original reset train_step to 0
# each epoch, so every epoch's curves overwrote the previous one's steps.
global_step = 0
for epoch in range(100):
    # BUG FIX: the loader is named `dataloader` (see its definition above);
    # the original iterated over the undefined name `real_dataload`.
    for item in dataloader:
        # ---- discriminator step: real -> 1, fake -> 0 ----
        dis_optim.zero_grad()
        img, _ = item
        img_dis = dis(img)
        img_dis_loss = loss(img_dis, torch.ones_like(img_dis))
        img_dis_loss.backward()
        # noise batch sized to the actual image batch (the hard-coded 64
        # disagreed with the dataloader's batch_size=32)
        z = torch.randn((img.size(0), 100))
        z_gen = gen(z)
        # detach: this backward pass must not touch the generator
        z_dis = dis(z_gen.detach())
        z_dis_loss = loss(z_dis, torch.zeros_like(z_dis))
        z_dis_loss.backward()
        dis_optim.step()
        dis_loss = z_dis_loss + img_dis_loss
        # ---- generator step: fake output should be judged real (-> 1) ----
        gen_optim.zero_grad()
        z_out = dis(z_gen)
        z_gen_loss = loss(z_out, torch.ones_like(z_out))
        z_gen_loss.backward()
        gen_optim.step()
        global_step += 1
        write.add_scalar("分类器", dis_loss, global_step)
        write.add_scalar("生产器", z_gen_loss, global_step)
    generate_and_save_images(gen, epoch, test_input)
    # save every 20 epochs; the original `epoch > 20 and epoch % 20 == 0`
    # skipped epoch 20 and saved for the first time only at epoch 40
    if epoch > 0 and epoch % 20 == 0:
        torch.save(gen, f"gen_method_{epoch}.pth")
        torch.save(dis, f"dis_method_{epoch}.pth")
write.close()
以下为训练结果:
可以观察到通过卷积网络之后,提取特征效果明显
epoch1:
epoch 100:
epoch 200:
epoch 300:
损失变化: