import os
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader,sampler,Dataset
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import save_image
import time
import numpy as np
from PIL import Image
from torch import optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
import datetime
import threading
# Model hyper-parameters shared by Generator and Discriminator.
nz = 256   # length of the latent noise vector fed to the generator
ngf = 64   # base channel width of the generator
nc = 3     # number of image channels

# Preprocessing applied by MyData to every training image:
#   CenterCrop(64) -> keep the central 64x64 region,
#   ToTensor()     -> float tensor (values in [0, 1] for 8-bit images),
#   Normalize      -> (v - 0.5) / 0.5 per channel, i.e. [0, 1] -> [-1, 1],
#                     matching the generator's Tanh output range.
# NOTE(review): there is no Resize before the crop, so larger images lose
# their borders and images smaller than 64 px on a side are zero-padded —
# confirm that is intended for this dataset.
transform = transforms.Compose([
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])
class Generator(nn.Module):
    """DCGAN generator.

    Maps a noise batch shaped (N, nz, 1, 1) — as produced by the training
    script — to images shaped (N, nc, 64, 64) with values in [-1, 1] (Tanh).
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Project the nz-dim noise to a (ngf*8) x 4 x 4 feature map.
        stages = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]
        # Three identical upsampling stages, each doubling the spatial size
        # and halving the channels: (ngf*8)x4x4 -> (ngf*4)x8x8
        # -> (ngf*2)x16x16 -> (ngf)x32x32.
        for width in (ngf * 8, ngf * 4, ngf * 2):
            stages.extend([
                nn.ConvTranspose2d(width, width // 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width // 2),
                nn.ReLU(True),
            ])
        # Final upsample to (nc) x 64 x 64, squashed into [-1, 1].
        stages.extend([
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
ndf = 16  # base channel width of the discriminator


class Discriminator(nn.Module):
    """DCGAN discriminator.

    Takes images shaped (N, nc, 64, 64) and returns a 1-D tensor of N
    sigmoid scores in (0, 1), one per image.
    """

    def __init__(self, ngpu=1):
        """
        Args:
            ngpu: number of GPUs that forward() may split a CUDA batch
                across via data_parallel. The default of 1 means no splitting.
        """
        super(Discriminator, self).__init__()
        # forward() reads this for CUDA inputs, so it must always be set.
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input: (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            # Outputs are probabilities, not logits: pair with nn.BCELoss.
            nn.Sigmoid()
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        # Flatten the (N, 1, 1, 1) score map to shape (N,).
        return output.view(-1, 1).squeeze(1)
class MyData(Dataset):
    """Unlabelled image dataset over the regular files of one directory.

    Item ``i`` is the i-th file (in sorted name order), loaded as RGB and
    passed through the module-level ``transform``. No label is returned.
    """

    _DEFAULT_ROOT = '\\\\vdinas.ymtc.local\\perdata\\E908724\\Downloads\\dcgan_anime_avatars-master\\data\\'

    def __init__(self, is_train, root=None):
        """
        Args:
            is_train: stored for API compatibility; it does not change which
                files are used.
            root: directory holding the images. Defaults to the original
                hard-coded network share.
        """
        super(MyData, self).__init__()
        self.is_train = is_train
        self.root = self._DEFAULT_ROOT if root is None else root
        self.path = self.root
        # List the directory once rather than on every __getitem__/__len__,
        # and sort it so the index -> file mapping is stable. Subdirectories
        # are skipped because Image.open cannot read them.
        self.files = sorted(
            name for name in os.listdir(self.path)
            if os.path.isfile(os.path.join(self.path, name))
        )

    def __getitem__(self, item):
        file_path = os.path.join(self.path, self.files[item])
        # Use `with` so the file handle is closed promptly. Force 3 channels
        # so grayscale/RGBA/palette images match the 3-channel Normalize.
        with Image.open(file_path) as src:
            img = src.convert('RGB')
        return transform(img)

    def __len__(self):
        return len(self.files)
def to_img(x):
    """Convert generator output back into savable images.

    Shifts values from the Tanh range [-1, 1] to [0, 1], clamps anything
    outside that range, and reshapes to (N, 3, 64, 64), where N is x's
    leading dimension.
    """
    batch = x.shape[0]
    rescaled = x.add(1.).mul(0.5)
    return rescaled.clamp(0, 1).view(batch, 3, 64, 64)
if __name__ == '__main__':
    # Train a DCGAN on the MyData images and periodically save samples.
    epochs = 10000
    batch_size = 16
    out_dir = './simple_autoencoder'

    train_data = MyData(is_train=True)
    data_set = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, drop_last=True)

    netD = Discriminator()
    netG = Generator()
    # netD already ends in nn.Sigmoid, so it outputs probabilities. BCELoss is
    # the matching criterion; BCEWithLogitsLoss would apply a second sigmoid.
    criterion = nn.BCELoss()
    optimizerD = optim.Adam(netD.parameters(), lr=0.0002)
    optimizerG = optim.Adam(netG.parameters(), lr=0.0002)

    # Never modified, so the snapshots show how the generator's output for
    # the *same* latent vectors changes over training.
    fixed_noise = torch.randn(batch_size, nz, 1, 1)
    os.makedirs(out_dir, exist_ok=True)

    for epoch in range(epochs):
        for index, img in enumerate(data_set):
            n = img.size(0)
            label_true = torch.ones(n)
            label_fake = torch.zeros(n)

            # Discriminator step: push D(real) -> 1 and D(G(z)) -> 0.
            optimizerD.zero_grad()
            d_loss_real = criterion(netD(img), label_true)
            d_loss_real.backward()
            noise = torch.randn(n, nz, 1, 1)
            # detach(): this step must not send gradients into netG.
            fake_images = netD_input = netG(noise).detach()
            d_loss_fake = criterion(netD(netD_input), label_fake)
            d_loss_fake.backward()
            optimizerD.step()
            print(index, ' D ', (d_loss_real + d_loss_fake).item())

            # Generator step: push D(G(z)) -> 1 using fresh noise.
            optimizerG.zero_grad()
            noise = torch.randn(n, nz, 1, 1)
            generated_images = netG(noise)
            g_loss_fake = criterion(netD(generated_images), label_true)
            g_loss_fake.backward()
            optimizerG.step()
            print(index, ' G ', g_loss_fake.item())

            if index % 50 == 0:
                with torch.no_grad():
                    sample = netG(fixed_noise)
                pic = to_img(sample.cpu())
                save_image(pic, './simple_autoencoder/image_{}_{}.png'.format(epoch, index))
# Source: CSDN blog post (author: L+CONV)
# First published 2022-09-29 12:16:04