import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# Hyper-parameters shared by the data loader and the generator.
batch_size = 64
noise_dim = 96

# MNIST is stored in (and, if missing, downloaded to) ./mnist under the
# current working directory.
path = os.path.join(os.getcwd(), 'mnist')
mnist_train = dset.MNIST(path, train=True, transform=T.ToTensor(), download=True)
loader_train = DataLoader(mnist_train, shuffle=True, batch_size=batch_size)

# Optional visual check of one training batch (only usable once show_images
# has been defined further down):
# imgs = next(iter(loader_train))[0].view(batch_size, 784).numpy().squeeze()
# show_images(imgs)
def show_images(images):
    """
    Draw a batch of square images on a tightly packed square grid.

    Input:
    - images: array whose first axis indexes the images. Each image is
      flattened first, so its element count must be a perfect square.

    A new matplotlib figure is created but not displayed; the caller is
    expected to call plt.show() (or save the figure) afterwards.
    """
    flat = np.reshape(images, [images.shape[0], -1])  # (num_images, D)
    grid_side = int(np.ceil(np.sqrt(flat.shape[0])))  # cells per grid row/column
    img_side = int(np.ceil(np.sqrt(flat.shape[1])))   # pixels per image edge

    plt.figure(figsize=(grid_side, grid_side))
    layout = gridspec.GridSpec(grid_side, grid_side)
    layout.update(wspace=0.05, hspace=0.05)

    for cell, pixels in enumerate(flat):
        ax = plt.subplot(layout[cell])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(pixels.reshape([img_side, img_side]))
# Generate noise
def sample_noise(batch_size, dim):
    """
    Generate a PyTorch Tensor of uniform random noise.

    Input:
    - batch_size: Integer giving the batch size of noise to generate.
    - dim: Integer giving the dimension of noise to generate.

    Output:
    - A PyTorch Tensor of shape (batch_size, dim) containing uniform
      random noise in the range [-1, 1).
    """
    # torch.rand samples U[0, 1); the affine map 2u - 1 rescales that to
    # U[-1, 1). Subtracting two independent uniforms would instead give a
    # triangular distribution concentrated around 0, not uniform noise.
    return 2 * torch.rand(batch_size, dim) - 1
# Flatten
class Flatten(nn.Module):
    """Collapse a 4-D (N, C, H, W) tensor into a 2-D (N, C*H*W) tensor."""

    def forward(self, x):
        # Unpacking exactly four sizes rejects any input that is not 4-D.
        batch, _channels, _height, _width = x.shape
        # One flat vector of C * H * W values per image.
        return x.view(batch, -1)
# Reshape into the layout expected by the conv layers
class Unflatten(nn.Module):
    """
    Reshape a flat input of shape (N, C*H*W) into a tensor of shape
    (N, C, H, W).

    Input (constructor):
    - N: Batch dimension of the output; the default -1 lets view() infer it.
    - C, H, W: Channel count, height and width of the output.
    """

    def __init__(self, N=-1, C=128, H=7, W=7):
        super().__init__()
        self.N, self.C, self.H, self.W = N, C, H, W

    def forward(self, x):
        target_shape = (self.N, self.C, self.H, self.W)
        return x.view(*target_shape)
class generator(nn.Module):
    """
    Generator network: maps a noise vector to a single-channel 28x28 image.

    Two fully connected stages expand the noise to 7 * 7 * 128 features,
    which are reshaped to (128, 7, 7) and upsampled by two stride-2
    transposed convolutions (7 -> 14 -> 28). The final Tanh bounds the
    output to [-1, 1].

    Input (constructor):
    - noise_dim: Length of each input noise vector.
    """

    def __init__(self, noise_dim=noise_dim):
        super().__init__()
        # Layer order is kept as-is so that the state_dict keys
        # (fc.0, fc.2, ..., conv.0, ...) stay unchanged.
        dense_layers = [
            nn.Linear(noise_dim, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128),
        ]
        self.fc = nn.Sequential(*dense_layers)

        upsample_layers = [
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),
            nn.Tanh(),
        ]
        self.conv = nn.Sequential(*upsample_layers)

    def forward(self, x):
        features = self.fc(x)
        # Reshape to 128 channels of size 7x7 for the transposed convolutions.
        feature_maps = features.view(features.shape[0], 128, 7, 7)
        return self.conv(feature_maps)
# Loss definitions
def generator_loss(scores_fake):
    """
    Least-squares generator loss.

    Input:
    - scores_fake: Tensor of discriminator scores for generated images.

    Output:
    - Scalar tensor: half the mean squared distance of the scores from 1.
    """
    squared_gap = (scores_fake - 1).pow(2)
    return 0.5 * squared_gap.mean()
class discriminator(nn.Module):
    """
    Discriminator network: scores single-channel 28x28 images.

    Two 5x5 convolutions, each followed by LeakyReLU and 2x2 max pooling,
    shrink the input 28 -> 24 -> 12 -> 8 -> 4, leaving 64 * 4 * 4 = 1024
    features that a two-layer fully connected head maps to one unbounded
    score per image.
    """

    def __init__(self):
        super().__init__()
        # Layer order is kept as-is so that the state_dict keys stay unchanged.
        feature_layers = [
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
        ]
        self.conv = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Linear(1024, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        feature_maps = self.conv(x)
        # One flat feature vector per image for the fully connected head.
        flat = feature_maps.view(feature_maps.shape[0], -1)
        return self.fc(flat)
def discriminatro_loss(scores_real, scores_fake):
    """
    Least-squares discriminator loss.

    Input:
    - scores_real: Tensor of discriminator scores for real images.
    - scores_fake: Tensor of discriminator scores for generated images.

    Output:
    - Scalar tensor: half the mean squared distance of the real scores
      from 1, plus half the mean squared fake scores.
    """
    real_term = 0.5 * (scores_real - 1).pow(2).mean()
    fake_term = 0.5 * scores_fake.pow(2).mean()
    return real_term + fake_term
# Optimizer definition
def get_optim(model):
    """
    Build the Adam optimizer used for both networks.

    Input:
    - model: Module whose parameters will be optimized.

    Output:
    - A torch.optim.Adam instance with learning rate 0.001 and
      betas (0.5, 0.999). betas are the coefficients of the running
      averages of the gradient and of its square (library default:
      (0.9, 0.999)); weight_decay keeps its default of 0.
    """
    learning_rate = 0.001
    adam_betas = (0.5, 0.999)
    return torch.optim.Adam(model.parameters(), lr=learning_rate, betas=adam_betas)
def train(G, D, G_loss, D_loss, G_optim, D_optim, batch_size=64, noise_size=96):
    """
    Alternately train a discriminator and a generator on `loader_train`.

    Runs 10 epochs. Each epoch stops after 100 full batches, prints the last
    losses, and displays up to 16 of the most recently generated images.
    After the last epoch the generator weights are saved to 'g.pth'.

    Input:
    - G, D: Generator and discriminator modules; both are expected to be on
      the same device (the device of G's parameters is used for all data).
    - G_loss: Callable taking the fake scores and returning a scalar loss.
    - D_loss: Callable taking (real scores, fake scores) and returning a
      scalar loss.
    - G_optim, D_optim: Optimizers for G and D respectively.
    - batch_size: Expected batch size; batches of any other size are skipped.
    - noise_size: Dimension of the noise vectors fed to G.
    """
    # Follow the generator's device rather than a module-level `device`
    # global, which only exists when this file is executed as a script.
    device = next(G.parameters()).device

    for epoch in range(10):
        print('epoch', epoch)
        i = 1
        fake_images = None
        for x, _ in loader_train:
            # Skip incomplete batches.
            if len(x) != batch_size:
                continue

            # ---- Discriminator step ----
            D_optim.zero_grad()
            # Real data, rescaled by 2 * (x - 0.5), i.e. from [0, 1] to [-1, 1].
            realdata = x.to(device)
            logits_real = D(2 * (realdata - 0.5))
            # Fake data; detached so this step does not backpropagate into G.
            g_fake_seed = sample_noise(batch_size, noise_size).to(device)
            fake_images = G(g_fake_seed).detach()
            logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
            # Backpropagate and update D.
            d_total_error = D_loss(logits_real, logits_fake)
            d_total_error.backward()
            D_optim.step()

            # ---- Generator step ----
            G_optim.zero_grad()
            g_fake_seed = sample_noise(batch_size, noise_size).to(device)
            fake_images = G(g_fake_seed)
            gen_logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
            g_error = G_loss(gen_logits_fake)
            g_error.backward()
            G_optim.step()

            if i % 100 == 0:
                # Note: the 'Iter' field carries the epoch index; the batch
                # counter is printed on the following line.
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(
                    epoch, d_total_error.item(), g_error.item()))
                print(i)
                break
            i += 1

        # Nothing to display if the loader produced no full batch this epoch.
        if fake_images is None:
            continue
        imgs_numpy = fake_images.detach().cpu().numpy()
        show_images(imgs_numpy[0:16])
        plt.show()

    torch.save(G.state_dict(), 'g.pth')
    print('模型已保存')
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- Training (disabled): uncomment to train and write g.pth ----
    # G = generator().to(device)
    # D = discriminator().to(device)
    #
    # D_optim = get_optim(D)
    # G_optim = get_optim(G)
    # train(G,D,generator_loss,discriminatro_loss,G_optim,D_optim)

    # ---- Generation: load the saved generator and write one sample ----
    model = generator().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load('g.pth', map_location=device))
    print(model)
    # NOTE(review): the model is left in its default training mode, so the
    # BatchNorm layers normalize with the statistics of this noise batch;
    # call model.eval() first if the stored running statistics are wanted.
    # The noise has to be on the same device as the model's weights.
    noise = sample_noise(batch_size, noise_dim).to(device)
    with torch.no_grad():  # inference only, no autograd graph needed
        img = model(noise)
    # Save the single channel of the first generated image.
    plt.imsave('a.jpg', img.cpu().numpy()[0][0])
# Results: