1.导入所需库
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
2. 训练集
# mini_batch size
# Mini-batch size for the DataLoader.
mb_size = 64
# Convert PIL images to tensors (PyTorch's expected input format).
# BUG FIX: renamed from `transforms` — assigning to that name shadowed the
# imported torchvision.transforms module.
transform_pipeline = transforms.Compose([transforms.ToTensor()])
# Training set (MNIST); download=False assumes the data already exists under ./NewData.
trainset = torchvision.datasets.MNIST(root='./NewData', download=False, train=True,
                                      transform=transform_pipeline)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=mb_size)
Notes
- torchvision.transforms是pytorch中的图像预处理包。
一般用Compose()把多个步骤整合到一起,例如:
transforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),
])
此外,常用的transforms中的函数:
Resize:把给定的图片resize到given size
Normalize:normalizes a tensor image with mean and standard deviation
ToTensor:converts a PIL image (H * W * C) in range [0,255] to a
torch.Tensor (C * H * W) in the range [0.0,1.0]
ToPILImage:converts a tensor to a PIL image
参考:https://blog.csdn.net/ftimes/article/details/105202795
3.可视化
参考:https://blog.csdn.net/xiongchengluo1129/article/details/79078478
# Grab one batch to inspect its shape and derive the model dimensions.
data_iter = iter(trainloader)
# BUG FIX: the `.next()` iterator method was removed in Python 3;
# use the builtin next() instead.
images, labels = next(data_iter)
# Flatten each 1x28x28 image into a 784-dim vector.
test = images.view(images.size(0), -1)
print(test.size())

# Model dimensions and learning rate.
z_dim = 100          # latent (noise) vector size fed to the generator
x_dim = test.size(1) # flattened image size (784 for MNIST)
h_dim = 128          # hidden-layer width
lr = 0.003
def imshow(img):
    """Display a batch of image tensors as a single tiled grid figure."""
    # Tile the whole batch into one image.
    grid = torchvision.utils.make_grid(img)
    # matplotlib expects channels-last (H, W, C), so transpose after
    # converting to numpy.
    grid_np = grid.numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(np.transpose(grid_np, (1, 2, 0)))
    # Hide axis tick marks for a cleaner picture.
    plt.xticks([])
    plt.yticks([])
    plt.show()


imshow(images)
输出:
Notes:
-
将多维度的tensor展平成一维,x.view(x.size(0), -1)就实现的这个功能。
所以我们一个batch里面的64张图,图片的大小是28 * 28,输出的size为64 * 784。 -
make_grid的作用是将若干幅图像拼成一幅图像。其中padding的作用就是子图像与子图像之间的pad有多宽。
-
plt.figure()语法
figure(num=None, figsize=None, dpi=None, facecolor=None)
edgecolor=None, frameon=True)
num:图像编号或名称,数字为编号 ,字符串为名称
figsize:指定figure的宽和高,单位为英寸;
dpi: 指定绘图对象的分辨率,即每英寸多少个像素,缺省值为80 1英寸等于2.5cm,A4纸是 21*30cm的纸张
facecolor:背景颜色
edgecolor:边框颜色
frameon:是否显示边框
- np.transpose(img,(1,2,0))将图片的格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),这样plt.show()就可以显示图片了。
- plt.xticks()用法参考:https://blog.csdn.net/Tenderness___/article/details/82972845
4. 初始化weight和bias
def init_weights(m):
    """Initialize nn.Linear layers: Xavier-uniform weights, zero biases.

    Intended to be used via ``module.apply(init_weights)``; any module
    that is not an nn.Linear is left untouched.
    """
    if type(m) == nn.Linear:
        # BUG FIX: nn.init.xavier_uniform is deprecated; the supported
        # in-place variant is xavier_uniform_ (the "complete code" section
        # of this file already uses it).
        nn.init.xavier_uniform_(m.weight)
        # Biases start at zero.
        m.bias.data.fill_(0)
参考:https://blog.csdn.net/dss_dssssd/article/details/83959474
pytorch官方教程中的例子:
5. Generator and Discriminator
class Generate(nn.Module):
    """Generator: maps z_dim noise vectors to x_dim image vectors in [0, 1]."""

    def __init__(self):
        super(Generate, self).__init__()
        # Two-layer MLP; Sigmoid keeps outputs in [0, 1], matching the
        # ToTensor-scaled real images.
        layers = [
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, x_dim),
            nn.Sigmoid(),
        ]
        self.predict = nn.Sequential(*layers)
        # Xavier weights / zero biases on every Linear layer.
        self.predict.apply(init_weights)

    def forward(self, input):
        # Single pass of the noise batch through the MLP.
        return self.predict(input)
class Dis(nn.Module):
    """Discriminator: maps x_dim image vectors to a real/fake probability."""

    def __init__(self):
        super(Dis, self).__init__()
        # Two-layer MLP ending in Sigmoid -> P(input is a real image).
        body = nn.Sequential(
            nn.Linear(x_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid(),
        )
        # Xavier weights / zero biases on every Linear layer.
        body.apply(init_weights)
        self.predict = body

    def forward(self, input):
        return self.predict(input)
# Instantiate the generator and discriminator networks.
G=Generate()
D=Dis()
6.Optimizer
# Separate Adam optimizers so G and D can be updated independently.
G_solver=optim.Adam(G.parameters(),lr=lr)
D_solver=optim.Adam(D.parameters(),lr=lr)
7.Training
for epoch in range(2):
    G_loss_run = 0.0
    D_loss_run = 0.0
    for i, data in enumerate(trainloader):
        # data holds (images, labels); labels are unused when training a GAN.
        X, label = data
        mb_size = X.size(0)  # last batch may be smaller than 64
        X = X.view(X.size(0), -1)  # flatten to (batch, 784)

        # Target labels: 1 = real, 0 = fake.
        one_labels = torch.ones(mb_size, 1)
        zero_labels = torch.zeros(mb_size, 1)

        # ---- Discriminator step ----
        z = torch.randn(mb_size, z_dim)
        G_samples = G(z)
        # detach() keeps D's backward pass out of G's graph, replacing the
        # old retain_graph=True workaround (and avoiding useless G gradients).
        D_fake = D(G_samples.detach())
        D_real = D(X)

        D_fake_loss = F.binary_cross_entropy(D_fake, zero_labels)
        D_real_loss = F.binary_cross_entropy(D_real, one_labels)
        D_loss = D_fake_loss + D_real_loss

        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # ---- Generator step ----
        # BUG FIX: use torch.randn (not torch.rand) so the noise matches the
        # distribution used above, and score the *fresh* samples — the
        # original scored the stale `G_samples` from the D step.
        z = torch.randn(mb_size, z_dim)
        G_samples = G(z)
        D_fake = D(G_samples)
        # G wants D to classify its samples as real (label 1).
        G_loss = F.binary_cross_entropy(D_fake, one_labels)

        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

        # BUG FIX: accumulate losses — the originals were never updated,
        # so the printed per-epoch averages were always 0.
        G_loss_run += G_loss.item()
        D_loss_run += D_loss.item()

    print('Epoch: {}, G_loss: {}. D_loss:{}'.format(epoch, G_loss_run/(i+1), D_loss_run/(i+1)))
    # Visualize a batch of generated digits after each epoch.
    samples = G(z).detach()
    samples = samples.view(mb_size, 1, 28, 28)
    imshow(samples)
完整代码:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
# Fix the RNG seed so results are reproducible.
torch.manual_seed(0)
# Mini-batch size for the DataLoader.
mb_size = 64
# Convert PIL images to tensors (PyTorch's expected input format).
# BUG FIX: renamed from `transforms` — assigning to that name shadowed the
# imported torchvision.transforms module.
mnist_transform = transforms.Compose([transforms.ToTensor()])
# Training set (MNIST); download=False assumes the data already exists under ./NewData.
trainset = torchvision.datasets.MNIST(root='./NewData', download=False, train=True,
                                      transform=mnist_transform)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=mb_size)
#可视化
#define an iterator
# Pull one batch to inspect its shape and derive the model dimensions.
dataiter = iter(trainloader)
# BUG FIX: the `.next()` iterator method was removed in Python 3;
# use the builtin next() instead.
imgs, labels = next(dataiter)
test = imgs.view(imgs.size(0), -1)
print(test.size())

h_dim = 128  # number of hidden neurons in our hidden layer
Z_dim = 100  # dimension of the input noise for generator
lr = 1e-3    # learning rate
X_dim = imgs.view(imgs.size(0), -1).size(1)  # flattened image size: 784 = 28*28
print(X_dim)
def imshow(img):
    """Render a batch of image tensors as one tiled matplotlib figure."""
    grid = torchvision.utils.make_grid(img)
    plt.figure(figsize=(8, 8))
    # matplotlib expects channels-last (H, W, C).
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    plt.xticks([])  # hide tick marks
    plt.yticks([])
    plt.show()


imshow(imgs)
def xavier_init(m):
    """Xavier initialization: uniform Xavier weights and zero bias for nn.Linear layers.

    Meant for use with ``module.apply``; non-Linear modules pass through untouched.
    """
    if type(m) != nn.Linear:
        return
    nn.init.xavier_uniform_(m.weight)
    m.bias.data.fill_(0)
class Gen(nn.Module):
    """Generator network: noise vector (Z_dim) -> flattened image (X_dim) in [0, 1]."""

    def __init__(self):
        super().__init__()
        # Two-layer MLP; Sigmoid keeps pixels in [0, 1] like ToTensor output.
        net = nn.Sequential(
            nn.Linear(Z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, X_dim),
            nn.Sigmoid(),
        )
        # Xavier weights / zero biases on every Linear layer.
        net.apply(xavier_init)
        self.model = net

    def forward(self, input):
        # One MLP pass over the noise batch.
        return self.model(input)
class Dis(nn.Module):
    """Discriminator network: flattened image (X_dim) -> probability it is real."""

    def __init__(self):
        super().__init__()
        # Two-layer MLP ending in Sigmoid -> P(real).
        layers = [
            nn.Linear(X_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)
        # Xavier weights / zero biases on every Linear layer.
        self.model.apply(xavier_init)

    def forward(self, input):
        return self.model(input)
# Sanity-check the discriminator architecture before training.
# BUG FIX: the original ran `test = Dis(); print(test)` twice (copy-paste
# duplicate) and reassigned `test`, which above holds a flattened image batch.
print(Dis())

# Instantiate the generator and discriminator used for training.
G = Gen()
D = Dis()
# Independent Adam optimizers for the generator and discriminator.
G_solver=optim.Adam(G.parameters(),lr=lr)
D_solver=optim.Adam(D.parameters(),lr=lr)
for epoch in range(20):
    G_loss_run = 0.0
    D_loss_run = 0.0
    for i, data in enumerate(trainloader):
        X, _ = data  # labels are not needed for a GAN
        X = X.view(X.size(0), -1)  # flatten to (batch, 784)
        mb_size = X.size(0)  # last batch may be smaller than 64

        # Target labels for real (1s) and fake (0s) images.
        one_labels = torch.ones(mb_size, 1)
        zero_labels = torch.zeros(mb_size, 1)

        # ---- Discriminator step ----
        # Standard-normal noise vector for each image in the batch.
        z = torch.randn(mb_size, Z_dim)
        D_real = D(X)
        # IMPROVEMENT: detach() keeps D's backward pass out of G's graph —
        # the original backpropagated into G's parameters here only for
        # G_solver.zero_grad() to discard those gradients later.
        D_fake = D(G(z).detach())

        D_real_loss = F.binary_cross_entropy(D_real, one_labels)
        D_fake_loss = F.binary_cross_entropy(D_fake, zero_labels)
        D_loss = D_fake_loss + D_real_loss

        # Backward propagation for the discriminator only.
        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # ---- Generator step ----
        z = torch.randn(mb_size, Z_dim)
        D_fake = D(G(z))
        # G wants D to classify its samples as real (label 1).
        G_loss = F.binary_cross_entropy(D_fake, one_labels)

        # Backward propagation for the generator.
        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

        G_loss_run += G_loss.item()
        D_loss_run += D_loss.item()

    # Print the average losses for this epoch.
    print('Epoch:{}, G_loss:{}, D_loss:{}'.format(epoch, G_loss_run / (i + 1), D_loss_run / (i + 1)))

    # Plot fake images generated after each epoch by the generator.
    samples = G(z).detach()
    samples = samples.view(samples.size(0), 1, 28, 28)
    imshow(samples)