一、需要注意的几点:
1、生成器的网络和判别器的网络均不含池化层。
2、判别器的最后一层网络输出使用sigmoid激活,生成器的最后一层网络输出使用tanh激活。
3、生成器和判别器的网络结构呈对称形式,如:生成器第一层的卷积核大小、步长与判别器最后一层的卷积核大小、步长一致,其输入通道、输出通道则分别与判别器最后一层的输出通道、输入通道相对应。
(上图所示的是生成器,判别器的网络刚好对称,从后往前)
4、卷积核使用偶数大小的效果比使用奇数大小的卷积核效果好。
5、使用转置卷积进行上采样。
6、训练时可以每训练两轮生成器再训练一次判别器(原因是判别器的能力优于生成器)。
二、代码部分:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
from torchvision.utils import save_image
import numpy as np
import os
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
class Sampling_data(Dataset):
    """Dataset yielding every image file of one directory as a normalised tensor.

    Each item is passed through ``ToTensor`` and then ``Normalize`` with a
    per-channel mean and std of 0.5 (three channels).
    """

    def __init__(self, img_path):
        """Collect the paths of the files located directly inside ``img_path``.

        Args:
            img_path: Directory that contains the training images.
        """
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs. Sub-directories are
        # skipped because they can never be opened as an image.
        candidates = (
            os.path.join(img_path, name) for name in sorted(os.listdir(img_path))
        )
        self.file_names = [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.file_names)

    def __getitem__(self, item):
        """Load the ``item``-th image and return it as a normalised tensor.

        Args:
            item: Index into the collected file list.

        Returns:
            The transformed image tensor with three channels.
        """
        path = self.file_names[item]
        # The Normalize step is configured for exactly three channels, so
        # grayscale, palette and RGBA files are converted to RGB first.
        # The with-block releases the file handle as soon as the pixel data
        # has been copied by convert().
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)
class Dnet(nn.Module):
    """Discriminator: five convolution stages without pooling, sigmoid output.

    The input is a batch of 3-channel images. Down-sampling is done purely by
    strided convolutions; e.g. a 96x96 input is reduced 32 -> 16 -> 8 -> 4 -> 1,
    giving one probability per image with shape ``(N, 1, 1, 1)``.
    """

    def __init__(self):
        super().__init__()
        # First stage: stride-3 convolution, no BatchNorm on this layer.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Three identical halving stages that double the channel count.
        self.conv2 = self._down_block(64, 128)
        self.conv3 = self._down_block(128, 256)
        self.conv4 = self._down_block(256, 512)
        # Final stage collapses the remaining 4x4 map to a single score.
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    @staticmethod
    def _down_block(in_channels, out_channels):
        """Build a Conv(k=4, s=2, p=1) -> BatchNorm -> LeakyReLU(0.2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2,
                      padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the result."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)
        return x
class Gnet(nn.Module):
    """Generator: five transposed-convolution stages, tanh output.

    The input is a batch of 128-channel tensors; for a ``(N, 128, 1, 1)``
    input the spatial size grows 4 -> 8 -> 16 -> 32 -> 96 and the result is a
    3-channel tensor with values in ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()
        # Projection stage: expands the 1x1 input to a 4x4 map.
        self.conv1 = self._up_block(128, 512, stride=1, padding=0)
        # Three doubling stages that halve the channel count.
        self.conv2 = self._up_block(512, 256, stride=2, padding=1)
        self.conv3 = self._up_block(256, 128, stride=2, padding=1)
        self.conv4 = self._up_block(128, 64, stride=2, padding=1)
        # Output stage: stride-3 up-sampling to 3 channels, squashed by tanh.
        self.conv5 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=3, padding=1,
                               bias=False),
            nn.Tanh(),
        )

    @staticmethod
    def _up_block(in_channels, out_channels, stride, padding):
        """Build a ConvTranspose(k=4) -> BatchNorm -> ReLU stage."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                               stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the result."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)
        return x
if __name__ == '__main__':
    # ---- configuration -------------------------------------------------
    save_params_path = r"params"
    save_img_path = r"./img"
    batchsize = 100
    img_data = r"E:\Learnn\cartoonfaces"  # directory with the cartoon-face images
    num_epoch = 500
    random_num = 128  # channel count of the noise tensor fed to the generator
    save_real_img_path = os.path.join(save_img_path, "real_img")
    save_fake_img_path = os.path.join(save_img_path, "fake_img")
    save_dparam_path = os.path.join(save_params_path, "d_self_net.pth")
    save_gparam_path = os.path.join(save_params_path, "g_self_net.pth")

    # makedirs(exist_ok=True) creates missing parents and tolerates a
    # directory that already exists.
    for path in [save_img_path, save_params_path, save_real_img_path, save_fake_img_path]:
        os.makedirs(path, exist_ok=True)

    # drop_last=True guarantees every batch has exactly `batchsize` samples,
    # which the fixed-size label tensors below rely on.
    data_loader = DataLoader(Sampling_data(img_data), batch_size=batchsize,
                             shuffle=True, num_workers=4, drop_last=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    g_net = Gnet().to(device)
    d_net = Dnet().to(device)
    g_net.train()
    d_net.train()

    # Two networks, two parameter files: resume only when BOTH exist.
    # (`exists(a and b)` would test only `b`, because `and` on two non-empty
    # strings evaluates to the second one.)
    if os.path.exists(save_dparam_path) and os.path.exists(save_gparam_path):
        # map_location lets a checkpoint written on a GPU load on a CPU-only
        # machine (and vice versa).
        d_net.load_state_dict(torch.load(save_dparam_path, map_location=device))
        g_net.load_state_dict(torch.load(save_gparam_path, map_location=device))
        print("两个参数已经加载成功!!!")
    else:
        print("NO Params!!!")

    loss_fn = nn.BCELoss()
    d_optimizer = torch.optim.Adam(d_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    g_optimizer = torch.optim.Adam(g_net.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # The targets never change, so build them once instead of per iteration.
    # Their shape matches the discriminator output (N, 1, 1, 1).
    real_label = torch.ones(batchsize, 1, 1, 1, device=device)
    fake_label = torch.zeros(batchsize, 1, 1, 1, device=device)

    for epoch in range(num_epoch):
        for i, img in enumerate(data_loader):
            real_img = img.to(device)

            # ---- discriminator step: real -> 1, generated -> 0 ----------
            real_out = d_net(real_img)
            d_loss_real = loss_fn(real_out, real_label)

            rand_n = torch.randn(batchsize, random_num, 1, 1, device=device)
            # detach(): the generator is not updated in this step, so there is
            # no need to back-propagate the discriminator loss through it.
            fake_img = g_net(rand_n).detach()
            fake_out = d_net(fake_img)
            d_loss_fake = loss_fn(fake_out, fake_label)

            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ---- generator step: make the discriminator answer 1 -------
            rand_n1 = torch.randn(batchsize, random_num, 1, 1, device=device)
            fake_img = g_net(rand_n1)
            output = d_net(fake_img)
            g_loss = loss_fn(output, real_label)
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            if i % 10 == 0:
                # Mean discriminator score on real / generated batches.
                print(real_out.detach().mean(), fake_out.detach().mean())

                # Undo the (x - 0.5) / 0.5 normalisation: [-1, 1] -> [0, 1].
                fake_imgs = (0.5 * (fake_img.detach().cpu() + 1)).clamp(0, 1)
                real_imgs = (0.5 * (real_img.detach().cpu() + 1)).clamp(0, 1)

                # File names depend on the epoch only, so later saves within
                # the same epoch overwrite earlier ones.
                # NOTE(review): normalize=True/scale_each=True min-max
                # stretches each image on top of the mapping above - confirm
                # that this is intended.
                save_image(fake_imgs,
                           os.path.join(save_fake_img_path, "{}_fake_imgs.jpg".format(epoch + 1)),
                           nrow=10, normalize=True, scale_each=True)
                save_image(real_imgs,
                           os.path.join(save_real_img_path, "{}_real_imgs.jpg".format(epoch + 1)),
                           nrow=10, normalize=True, scale_each=True)

                torch.save(g_net.state_dict(), save_gparam_path)
                torch.save(d_net.state_dict(), save_dparam_path)
三、效果展示:
1)生成的图像:
2)原始图像: