import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
"对数据做归一化处理"
# Preprocessing applied to every MNIST image, built from two named steps.
# Step 1: convert the image to a tensor laid out as (channel, height, width)
# with pixel values scaled into [0, 1].
_to_tensor = transforms.ToTensor()
# Step 2: output = (input - mean) / std. With mean = std = 0.5 this maps
# [0, 1] onto [-1, 1].
_normalize = transforms.Normalize(mean=0.5, std=0.5)
transform = transforms.Compose([_to_tensor, _normalize])
"实例 是指一个类(Class)的具体对象,这个实例具有类定义的属性和方法,可以被操作和使用"
# Folder that must already hold the MNIST files: downloading is disabled below.
_mnist_root = r'C:\Users\Administrator\Desktop\图像生成代码\GAN网络简单应用\data'
train_ds = torchvision.datasets.MNIST(
    root=_mnist_root,
    train=True,
    transform=transform,
    download=False,
)
# Shuffled mini-batches of 64 training images.
dataloader = DataLoader(dataset=train_ds, batch_size=64, shuffle=True)
"random_noise[batch_size,100]>>gen_img[batch_size,1,28,28]"
class generator(nn.Module):
    """Fully connected generator: noise vectors -> 1x28x28 images.

    Maps a batch of noise vectors of shape ``[batch_size, noise_dim]`` to
    images of shape ``[batch_size, 1, 28, 28]``. The final ``Tanh`` keeps
    every output value inside ``[-1, 1]``.

    Args:
        noise_dim: Length of each input noise vector. Defaults to 100, the
            value that used to be hard-coded.
    """

    def __init__(self, noise_dim: int = 100):
        super().__init__()
        self.noise_dim = noise_dim
        # The attribute name ``main`` and the layer order are unchanged, so
        # state_dict keys ("main.0.weight", ...) stay the same.
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Tanh(),  # tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
        )

    def forward(self, x):
        """Generate images from noise.

        Args:
            x: Tensor whose last dimension is ``noise_dim``.

        Returns:
            Tensor of shape ``[batch_size, 1, 28, 28]``.
        """
        img = self.main(x)
        # 784 = 1 * 28 * 28: unflatten each row into a single-channel image.
        return img.view(-1, 1, 28, 28)
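# The discriminator takes a (1, 28, 28) image and outputs a binary-classification probability; a sigmoid squashes the output into (0, 1).
# BCELoss computes the binary cross-entropy loss.
# nn.LeakyReLU: f(x) = x for x > 0 and f(x) = a * x for x < 0, where a is a small slope (PyTorch default: 0.01).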
"[batch_size,784]>>[batch_size,1], 每一个value属于0~1之间"
class Discriminator(nn.Module):
    """Fully connected discriminator: images -> probability of being real.

    Each image is flattened to 784 values and passed through three linear
    layers. The final ``Sigmoid`` produces one value in ``[0, 1]`` per image,
    which suits ``BCELoss``.
    """

    def __init__(self):
        # Initialise nn.Module so that sub-modules are registered properly.
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # sigmoid(x) = 1 / (1 + exp(-x))
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor whose total element count is a multiple of 784,
                e.g. ``[batch_size, 1, 28, 28]``.

        Returns:
            Tensor of shape ``[batch_size, 1]`` with values in ``[0, 1]``.
        """
        # reshape (not view) so non-contiguous inputs, such as permuted
        # tensors, are accepted instead of raising a RuntimeError.
        x = x.reshape(-1, 784)
        return self.main(x)
"[b,100]>>[b,1,28,28]>>[b,28,28],并且显示图"
def gen_img_plot(model, test_input):
    """Run ``model`` on ``test_input`` and display the results in a grid.

    Args:
        model: Callable returning a tensor when given ``test_input``. The
            result is assumed to have shape ``[n, 1, H, W]`` with values in
            ``[-1, 1]`` (the rescaling below relies on that range).
        test_input: Batch passed unchanged to ``model``.
    """
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        generated = model(test_input)
    # Remove only the channel axis. A bare np.squeeze would also remove the
    # batch axis when n == 1, and the loop below would then walk over image
    # rows instead of images.
    prediction = np.squeeze(generated.detach().cpu().numpy(), axis=1)

    num_images = prediction.shape[0]
    cols = 4
    # At least the original 4 rows; more when there are over 16 images, so
    # plt.subplot never receives an out-of-range index.
    rows = max(4, -(-num_images // cols))

    plt.figure(figsize=(16, 16))
    for i in range(num_images):
        plt.subplot(rows, cols, i + 1)
        # Map pixel values from [-1, 1] to [0, 1] before drawing.
        plt.imshow((prediction[i] + 1) / 2)
        plt.axis('off')  # hide the axes
    plt.show()
# Use the GPU when CUDA is available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fixed batch of 16 noise vectors; the same batch is plotted after every
# epoch so the pictures are comparable across epochs.
test_input = torch.randn(16, 100, device=device)

# Networks are created after the noise batch, keeping the order in which
# random numbers are consumed.
gen = generator().to(device)
dis = Discriminator().to(device)

# One Adam optimizer per network, both with learning rate 1e-4.
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)

# Binary cross-entropy on the discriminator's sigmoid output.
loss_fn = nn.BCELoss()

# Per-epoch average losses, filled in by the training loop.
D_loss, G_loss = [], []
def loss_show(D_loss, G_loss):
    """Plot discriminator and generator losses on one figure.

    Both curves are drawn against the index 0..len-1, and every point is
    annotated with its value in the colour of its own curve.

    Args:
        D_loss: Sequence of discriminator loss values.
        G_loss: Sequence of generator loss values; must be the same length
            as ``D_loss``.

    Raises:
        ValueError: If the two sequences differ in length.
    """
    # Validate before opening a figure, and raise instead of calling exit()
    # so a bad call does not terminate the whole interpreter.
    if len(D_loss) != len(G_loss):
        raise ValueError("Lengths of D_loss and G_loss are not equal.")
    steps = range(len(D_loss))

    plt.figure(figsize=(8, 8))
    plt.plot(steps, D_loss, label='Discriminator Loss', color='red')
    plt.plot(steps, G_loss, label='Generator Loss', color='blue')
    plt.legend(['Discriminator Loss', 'Generator Loss'])
    plt.xlabel('step', fontsize=14)
    plt.ylabel('loss value', fontsize=14)
    # Annotate each point with its value. Generator labels are blue to match
    # the generator curve (they were previously drawn in red).
    for x, y_d, y_g in zip(steps, D_loss, G_loss):
        plt.text(x, y_d, f'({y_d:.2f})', fontsize=8, color='red', ha='right', va='bottom')
        plt.text(x, y_g, f'({y_g:.3f})', fontsize=8, color='blue', ha='right', va='bottom')
    plt.title('Discriminator and Generator Loss Over Steps')
    plt.show()
# Training loop: 20 epochs. For every batch the discriminator is updated
# first, then the generator.
count = len(dataloader)  # batches per epoch; does not change between epochs
for epoch in range(20):
    # Running sums of the per-batch losses, kept as 0-d tensors on `device`.
    d_epoch_loss = torch.zeros((), device=device)
    g_epoch_loss = torch.zeros((), device=device)

    for img, _ in dataloader:
        img = img.to(device)
        batch_size = img.size(0)
        random_noise = torch.randn([batch_size, 100], device=device)

        # ---- Discriminator update ----
        d_optim.zero_grad()
        # Real images should be classified as real (target 1).
        real_output = dis(img)  # [batch_size, 1]
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))
        d_real_loss.backward()

        gen_img = gen(random_noise)  # [batch_size, 100] -> [batch_size, 1, 28, 28]
        # detach(): this pass trains only the discriminator, so no gradient
        # should flow back into the generator.
        fake_output = dis(gen_img.detach())
        # Generated images should be classified as fake (target 0).
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
        d_fake_loss.backward()
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()

        # ---- Generator update ----
        g_optim.zero_grad()
        # Score the same generated images with the just-updated
        # discriminator, this time keeping the graph back to the generator.
        fake_output = dis(gen_img)
        # The generator wants its images to be classified as real (target 1).
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
        g_loss.backward()
        g_optim.step()

        # Accumulate detached values so the sums carry no autograd history.
        d_epoch_loss += d_loss.detach()
        g_epoch_loss += g_loss.detach()

    # Average over the number of batches and record the epoch's result.
    d_epoch_loss /= count
    g_epoch_loss /= count
    D_loss.append(d_epoch_loss.cpu().numpy())
    G_loss.append(g_epoch_loss.cpu().numpy())
    print(f'Epoch: {epoch},D_loss: {D_loss[epoch].item():.3f},G_loss: {G_loss[epoch].item():.3f}')
    # Show what the generator produces for the fixed noise batch.
    gen_img_plot(gen, test_input)

# After all epochs, plot the per-epoch loss curves.
loss_show(D_loss, G_loss)