import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch import optim
import os
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets,transforms #导入库以及工具包
from torchvision.utils import save_image
# Permit more than one OpenMP runtime to be loaded in this process.
# NOTE(review): presumably a workaround for the "duplicate OpenMP library"
# abort seen with some torch/numpy/matplotlib installs -- confirm it is needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Hyper-parameters.
batch_size = 64
learning_rate = 0.001
num_epoches = 20

# Pre-processing applied to every image: convert it to a tensor, then
# normalize with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
data_tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST train / test splits; downloaded into ./data when not already present.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=data_tf,
    download=True,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=data_tf,
    download=True,
)

# Mini-batch iterators: the training split is reshuffled every epoch, the
# test split keeps its original order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
class autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder reduces a 784-value vector to a 3-value code; the decoder
    expands that code back to 784 values and passes them through a final
    Tanh, so every output value lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder: 784 -> 256 -> 64 -> 12 -> 3, ReLU after every layer but the last.
        encoder_layers = [
            nn.Linear(28 * 28, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 12),
            nn.ReLU(inplace=True),
            nn.Linear(12, 3),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: 3 -> 12 -> 64 -> 128 -> 784, closed by Tanh.
        decoder_layers = [
            nn.Linear(3, 12),
            nn.ReLU(inplace=True),
            nn.Linear(12, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 28 * 28),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` (last dimension 784) and return its reconstruction."""
        code = self.encoder(x)
        return self.decoder(code)
# Build the autoencoder that the training loop optimizes.
model = autoencoder()

# Reconstruction loss: mean squared error between the output and the input.
criterion = nn.MSELoss()

# Adam over all model parameters.
# NOTE(review): the module-level `learning_rate` (0.001) is not used here;
# the optimizer runs with 0.0003 -- confirm which value is intended.
optimizer = optim.Adam(params=model.parameters(), lr=3e-4)
def train(epoch, output_dir=r'E:\Uface'):
    """Train the autoencoder for one epoch and save a grid of reconstructions.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``train_loader``. After the epoch, the mean batch loss is printed and
    the reconstructions of the last batch are written as a PNG.

    Args:
        epoch: Epoch index. It is printed unchanged, and ``epoch + 1`` is
            used in the saved file name ``fake_images-<epoch + 1>.png``.
        output_dir: Directory that receives the PNG; created when missing.
            Defaults to the location that used to be hard-coded.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    reconstructions = None

    # Labels are ignored: the network is trained to reproduce its own input.
    for inputs, _ in train_loader:
        real_imgs = torch.flatten(inputs, start_dim=1)

        reconstructions = model(real_imgs)
        loss = criterion(reconstructions, real_imgs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Nothing to average or save; avoids a division by zero / NameError.
        print('Epoch {}: train_loader produced no batches'.format(epoch))
        return

    print('Epoch {}, loss: {:.6f}'.format(epoch, total_loss / num_batches))

    # Save the reconstructions of the last batch for this epoch.
    fake_images = reconstructions.detach().view(-1, 1, 28, 28)
    # The Tanh output (like the inputs normalized with mean 0.5 / std 0.5)
    # lies in [-1, 1], while save_image expects [0, 1]; without rescaling,
    # every negative value would be clipped to black.
    fake_images = ((fake_images + 1) / 2).clamp(0, 1)

    os.makedirs(output_dir, exist_ok=True)
    save_image(
        fake_images,
        os.path.join(output_dir, 'fake_images-{}.png'.format(epoch + 1)),
    )
# Run the full training schedule. The epoch count comes from the
# module-level `num_epoches` setting (20) instead of a duplicated literal.
for epoch in range(num_epoches):
    train(epoch)
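# Note: plain stochastic gradient descent (SGD) was used as the optimizer at
# first, but the resulting images contained no recognizable digits and were
# dominated by noise. Switching to Adam produced noticeably better results.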