import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
# image_size: side length (pixels) of the resized / center-cropped training
#             images; this file also reuses it as the base channel width of G and D.
# batch_size: number of images per DataLoader batch.
image_size, batch_size = 64, 128
def to_img(x, size=64, channels=3):
    """Map generator output from [-1, 1] back to [0, 1] and reshape into an image batch.

    Args:
        x: tensor of generated values (the generator ends in Tanh, so values are
           expected in [-1, 1]); its element count must be a multiple of
           ``channels * size * size``.
        size: height and width of each image (default 64, the previously
              hard-coded value).
        channels: number of image channels (default 3).

    Returns:
        Tensor of shape ``(-1, channels, size, size)`` with values clamped to [0, 1].
    """
    # Inverse of Normalize((0.5,)*3, (0.5,)*3): [-1, 1] -> [0, 1].
    out = (0.5 * (x + 1)).clamp(0, 1)
    # reshape (unlike view) also accepts non-contiguous inputs.
    return out.reshape(-1, channels, size, size)
"""
数据加载
ImageFolder: 通用数据加载器
dset.ImageFolder(root="root folder path", [transform, target_transform])
transforms.Compose:构造一个转换列表
transforms.Resize:调整图像大小,支持4种插值
transforms.CenterCrop:以图片中心切割一个正方形
transforms.ToTensor:图片转tensor
transforms.Normalize: class torchvision.transforms.Normalize(mean, std) 归一化,前者是方差,后者是均值
"""
# Root folder of the training images. It can be overridden with the CELEBA_ROOT
# environment variable; the default is the original hard-coded path.
data_path = os.path.abspath(os.environ.get("CELEBA_ROOT", "E:/dataset/CelebA/Img"))
# NOTE: ImageFolder expects the layout <root>/<sub_dir>/<image file>, i.e. the
# images must sit inside at least one sub-directory of data_path, otherwise
# loading raises an error.
# Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp
dataset = datasets.ImageFolder(
    root=data_path,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),  # pixels -> float tensor in [0, 1]
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], matching the generator's Tanh range.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
"""
生成器
nn.ConvTranspose2d(in_channels,
out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True)
二维转置卷积操作:转置卷积是一类上采样操作,类似插值
"""
class Generator(nn.Module):
    """DCGAN generator: maps latent noise to a ``nc``-channel 64x64 image in [-1, 1].

    The transposed convolutions upsample the spatial size 1 -> 4 -> 8 -> 16 -> 32 -> 64,
    so the output is 64x64 regardless of ``ngf``.

    Args:
        nz: number of latent channels of the input noise (default 100).
        ngf: base feature-map width; defaults to the module-level ``image_size``
             (the previous behaviour).
        nc: number of output image channels (default 3).
    """

    def __init__(self, nz=100, ngf=None, nc=3):
        super(Generator, self).__init__()
        if ngf is None:
            ngf = image_size
        self.gen = nn.Sequential(
            # (N, nz, 1, 1) -> (N, ngf*8, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (N, ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (N, ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, nc, 64, 64); Tanh matches the [-1, 1] normalised data range.
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from noise ``x`` of shape (N, nz, 1, 1)."""
        return self.gen(x)
"""
判别器模型
"""
class Discriminator(nn.Module):
    """DCGAN discriminator: maps a ``nc``x64x64 image to a probability of being real.

    Strided convolutions downsample 64 -> 32 -> 16 -> 8 -> 4 -> 1, so the output
    has shape (N, 1, 1, 1) with values in [0, 1] (Sigmoid).

    Args:
        ndf: base feature-map width; defaults to the module-level ``image_size``
             (the previous behaviour).
        nc: number of input image channels (default 3).
    """

    def __init__(self, ndf=None, nc=3):
        super(Discriminator, self).__init__()
        if ndf is None:
            ndf = image_size
        self.dis = nn.Sequential(
            # (N, nc, 64, 64) -> (N, ndf, 32, 32); no BatchNorm on the input layer.
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*2, 16, 16)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*4, 8, 8)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*8, 4, 4)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 1, 1, 1) probability
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return real/fake probabilities of shape (N, 1, 1, 1) for images ``x``."""
        return self.dis(x)
"""
实例化网络
"""
# Optimizer hyper-parameters. optim_betas equals Adam's default (0.9, 0.999).
# NOTE(review): the DCGAN paper (Radford et al.) reports more stable training
# with beta1 = 0.5 and lr = 2e-4 — consider trying those.
d_learning_rate = 3e-4
g_learning_rate = 3e-4
optim_betas = (0.9, 0.999)
# Binary cross-entropy on probabilities (the discriminator ends in Sigmoid).
criterion = nn.BCELoss()
G = Generator()
D = Discriminator()
if torch.cuda.is_available():
    print("use cuda")
    D = D.cuda()
    G = G.cuda()
# Each optimizer uses its own network's learning rate and the declared betas.
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
num_epochs = 100  # number of passes over the dataset
"""
开始训练
"""
# Train on whatever device the models were placed on, so CPU-only machines work
# too (unconditional .cuda() calls would crash there).
device = next(D.parameters()).device
sample_dir = "./img/dcgan"
# save_image does not create missing directories.
os.makedirs(sample_dir, exist_ok=True)
for epoch in range(num_epochs):
    # ImageFolder batches are (images, class_indices); the class indices are unused.
    for index, (img_data, _) in enumerate(dataloader):
        train_batch_size = img_data.size(0)
        img_data = img_data.to(device)
        # BCELoss needs float targets with the same shape as its input, so the
        # (N, 1, 1, 1) discriminator outputs are flattened to (N,) below.
        real_label = torch.ones(train_batch_size, device=device)
        fake_label = torch.zeros(train_batch_size, device=device)

        # Step 1: train the discriminator on one real batch and one fake batch.
        d_real_decision = D(img_data).view(-1)
        d_real_error = criterion(d_real_decision, real_label)
        noise = torch.randn(train_batch_size, 100, 1, 1, device=device)
        fake_imgs = G(noise)
        # detach(): the discriminator update must not backpropagate into G.
        d_fake_decision = D(fake_imgs.detach()).view(-1)
        d_fake_error = criterion(d_fake_decision, fake_label)
        d_loss = d_fake_error + d_real_error
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Step 2: train the generator so the (updated) discriminator labels its
        # fakes as real. Reuses the fake batch from step 1 instead of running G again.
        g_fake_decision = D(fake_imgs).view(-1)
        g_fake_error = criterion(g_fake_decision, real_label)
        g_optimizer.zero_grad()
        g_fake_error.backward()
        g_optimizer.step()

        if (index + 1) % 200 == 0 or index == 0:
            print("Epoch[{}/{}], Step[{}/{}], d_loss:{:.6f}, g_loss:{:.6f}".format(
                epoch, num_epochs, index + 1, len(dataloader),
                d_loss.item(), g_fake_error.item()))
            # These are generated (fake) samples, mapped back to [0, 1] for saving.
            fake_images = to_img(fake_imgs.detach().cpu())
            save_image(fake_images, os.path.join(sample_dir, "test.png"))
    print("Epoch[{}],d_loss:{:.6f}".format(epoch, d_loss.item()))