PyTorch 高级篇(2):变分自编码器(Variational Auto-Encoder)——自编码降维可视化

变分自编码器 学习资料

自编码器有以下几个作用:

数据去噪(去噪编码器)

可视化降维

生成数据(与GAN各有千秋)

PyTorch 实现

预处理

2

3

4

5

6

7

# Standard library
import os

# PyTorch and torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

1

2

# Device configuration.
# NOTE(review): the original called torch.cuda.set_device(1) unconditionally,
# which raises on machines without CUDA (or with a single GPU); guard it
# behind the availability check so the script still runs on CPU.
if torch.cuda.is_available():
    torch.cuda.set_device(1)  # select which GPU PyTorch runs on
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

1

2

3

# Create the output folder for sampled/reconstructed images if it is missing.
# exist_ok=True makes this idempotent and race-free, replacing the
# exists()-then-makedirs() pair (which can race between the two calls).
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

1

2

3

4

5

6

7

# Hyper-parameters for the VAE and its training run.
image_size = 784       # 28 * 28 flattened MNIST image
h_dim = 400            # width of the hidden layers
z_dim = 20             # dimensionality of the latent code
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

MNIST 数据集

2

3

4

5

6

7

8

# MNIST training split; downloaded to the data directory on first run.
dataset = torchvision.datasets.MNIST(root='../../../data/minist',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Data loader: shuffled mini-batches of size batch_size.
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

创建 VAE 模型(变分自编码器,Variational Auto-Encoder)

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

# VAE model
class VAE(nn.Module):
    """Variational Auto-Encoder for flattened MNIST images.

    The encoder maps x to the parameters (mu, log_var) of the approximate
    posterior q(z|x); a latent z is drawn via the reparameterization trick;
    the decoder maps z back to pixel space in [0, 1].

    Note: the original source read ``class (nn.Module):`` — the class name
    ``VAE`` was missing (a syntax error), although it is required both by
    ``super(VAE, self)`` below and by the ``VAE()`` instantiation later.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)   # encoder hidden layer
        self.fc2 = nn.Linear(h_dim, z_dim)        # mean vector head
        self.fc3 = nn.Linear(h_dim, z_dim)        # log-variance vector head
        self.fc4 = nn.Linear(z_dim, h_dim)        # decoder hidden layer
        self.fc5 = nn.Linear(h_dim, image_size)   # reconstruction head

    def encode(self, x):
        """Encode x into (mu, log_var) of the latent distribution q(z|x)."""
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        """Sample z = mu + eps * std with eps ~ N(0, I).

        exp(log_var / 2) converts the log-variance into a standard deviation.
        """
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent z back to pixel intensities in [0, 1]."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid replaces the deprecated F.sigmoid (the original run
        # emitted a UserWarning for this).
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        """Full pass: encode -> sample -> decode.

        Returns (x_reconst, mu, log_var) so the caller can build both the
        reconstruction loss and the KL term.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

# Instantiate the VAE and move it to the chosen device.
model = VAE().to(device)

# Adam optimizer over all trainable parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

开始训练

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

# Training loop: optimize reconstruction + KL, then sample images per epoch.
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Flatten the batch of 28x28 images to (batch, 784) on the device.
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Reconstruction loss (summed binary cross-entropy) plus the KL
        # divergence between q(z|x) = N(mu, var) and the prior N(0, I).
        # reduction='sum' replaces the deprecated size_average=False
        # (the original run emitted a UserWarning for it).
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backpropagation and parameter update.
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(data_loader),
                          reconst_loss.item(), kl_div.item()))

    # After each epoch, use the current model for qualitative checks.
    with torch.no_grad():
        # Images generated from random latent vectors.
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))

        # Reconstructions of the last training batch, concatenated side by
        # side with the inputs along the width axis (dim=3).
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))

/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py:1006: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.

warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")

/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py:52: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.

warnings.warn(warning.format(ret))

Epoch[1/15], Step [100/469], Reconst Loss: 9898.7285, KL Div: 3231.0195

Epoch[1/15], Step [200/469], Reconst Loss: 9985.5391, KL Div: 3290.1267

Epoch[1/15], Step [300/469], Reconst Loss: 9800.6211, KL Div: 3201.4980

Epoch[1/15], Step [400/469], Reconst Loss: 9444.1016, KL Div: 3259.1062

Epoch[2/15], Step [100/469], Reconst Loss: 9204.6201, KL Div: 3056.4475

Epoch[2/15], Step [200/469], Reconst Loss: 9729.0078, KL Div: 3206.0845

Epoch[2/15], Step [300/469], Reconst Loss: 9609.4307, KL Div: 3220.1729

Epoch[2/15], Step [400/469], Reconst Loss: 9514.4150, KL Div: 3206.0166

Epoch[3/15], Step [100/469], Reconst Loss: 9042.1270, KL Div: 3145.2937

Epoch[3/15], Step [200/469], Reconst Loss: 9773.1826, KL Div: 3235.4180

Epoch[3/15], Step [300/469], Reconst Loss: 9427.7207, KL Div: 3141.4922

Epoch[3/15], Step [400/469], Reconst Loss: 9658.2725, KL Div: 3235.2390

Epoch[4/15], Step [100/469], Reconst Loss: 9596.0439, KL Div: 3177.3101

Epoch[4/15], Step [200/469], Reconst Loss: 9158.8330, KL Div: 3114.7456

Epoch[4/15], Step [300/469], Reconst Loss: 9519.2754, KL Div: 3100.6924

Epoch[4/15], Step [400/469], Reconst Loss: 9318.7393, KL Div: 3098.9333

Epoch[5/15], Step [100/469], Reconst Loss: 9248.7139, KL Div: 3203.3230

Epoch[5/15], Step [200/469], Reconst Loss: 9914.3438, KL Div: 3244.7737

Epoch[5/15], Step [300/469], Reconst Loss: 9575.4922, KL Div: 3210.8545

Epoch[5/15], Step [400/469], Reconst Loss: 9519.7637, KL Div: 3243.2603

................................

Epoch[11/15], Step [400/469], Reconst Loss: 9872.5010, KL Div: 3267.5239

Epoch[12/15], Step [100/469], Reconst Loss: 9508.9053, KL Div: 3069.8406

Epoch[12/15], Step [200/469], Reconst Loss: 9340.8848, KL Div: 3093.4531

Epoch[12/15], Step [300/469], Reconst Loss: 9537.1279, KL Div: 3208.4387

Epoch[12/15], Step [400/469], Reconst Loss: 9205.0615, KL Div: 3125.3406

Epoch[13/15], Step [100/469], Reconst Loss: 9650.2803, KL Div: 3167.0171

Epoch[13/15], Step [200/469], Reconst Loss: 9609.6025, KL Div: 3179.3223

Epoch[13/15], Step [300/469], Reconst Loss: 9498.6650, KL Div: 3309.2681

Epoch[13/15], Step [400/469], Reconst Loss: 9823.6318, KL Div: 3218.4116

Epoch[14/15], Step [100/469], Reconst Loss: 9167.9990, KL Div: 3097.4619

Epoch[14/15], Step [200/469], Reconst Loss: 9712.9277, KL Div: 3222.7612

Epoch[14/15], Step [300/469], Reconst Loss: 9887.4297, KL Div: 3336.3618

Epoch[14/15], Step [400/469], Reconst Loss: 9485.8965, KL Div: 3180.0781

Epoch[15/15], Step [100/469], Reconst Loss: 9628.2295, KL Div: 3244.9995

Epoch[15/15], Step [200/469], Reconst Loss: 9556.5020, KL Div: 3147.9658

Epoch[15/15], Step [300/469], Reconst Loss: 9569.2588, KL Div: 3193.5071

Epoch[15/15], Step [400/469], Reconst Loss: 9334.9570, KL Div: 3074.2688

结果展示

2

3

# Plotting utilities for displaying the saved sample images.
import matplotlib.pyplot as plt    # plt: display images
import matplotlib.image as mpimg   # mpimg: read image files
import numpy as np

重构图

2

3

4

# Show the reconstruction grid saved during the final training epoch.
# NOTE(review): training above runs num_epochs = 15, so only
# reconst-1.png .. reconst-15.png exist; the original path
# './samples/reconst-55.png' is never produced by this script.
reconsPath = './samples/reconst-15.png'
Image = mpimg.imread(reconsPath)
plt.imshow(Image)   # display the image
plt.axis('off')     # hide the axes
plt.show()

随机生成图

2

3

4

# Show the images sampled from random latent vectors in the final epoch.
# NOTE(review): training above runs num_epochs = 15, so only
# sampled-1.png .. sampled-15.png exist; the original path
# './samples/sampled-107.png' is never produced by this script.
genPath = './samples/sampled-15.png'
Image = mpimg.imread(genPath)
plt.imshow(Image)   # display the image
plt.axis('off')     # hide the axes
plt.show()

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值