TensoFlow 实现 VAE(变分自编码)神经网络

本文介绍了使用TensorFlow构建变分自编码器(VAE)神经网络的过程,重点关注网络结构、损失函数和截断随机正态分布的参数对输出效果的影响。
摘要由CSDN通过智能技术生成

TensoFlow 实现变分自编码(VAE)神经网络

1. 网络结构

自上而下,先搭骨架,在实现

1. 主网络:net

          	1. init
        	1. 初始化编码器类
        	2. 初始化解码器类
          	2. forward
        	1. 获取编码器输出:$\delta $ 和 $\mu$(方差)
        	2. 获取一个标准正态分布
        	3. 获取log标准方差`tf.sqrt(tf.exp(方差))`
        	4. 获取解码器的输入(编码器的最后输出):decode_x = $\delta$ x 标准正态分布+ log标准方差
        	5. 获取解码的输出
          	3. backward
        	1. loss(1):编码器的输出-输入的平方差
        	2. loss(2):KL散度
        	3. 总loss:loss(1)+ loss(2)
             	1. 可以在损失前加一个变量,用于对单个损失的提升(重点损失)
          	4. 测试:获取解码器 生成的图片:
        	1. 生成一个标准的正态分布
        	2. 使用解码器做一个前向计算
  	2. 子网络:编码器

   ​	1. 返回$\delta $ 和 $log $方差
   ​	2. 一个输入,
   ​	3. 一次激活
   ​	4. 两个输出,两个权重,无激活
3. 子网络:解码器

   ​	1. 全连接
   ​	2. 一个输入
   ​	3. 一个输出
   ​	4. 无激活函数:因为所需输出不是概率而是图片

知识点:

  1. 方差和标准方差的关系
    1. 方差:无平方
    2. 标准方差:有平方
    3. 如何求损失
      1. 求图片的损失:生成图片和原有图片的损失
      2. 求正太分布的损失:匹配正态分布的损失
      3. 两者相加,优化
    4. 如何设初值
      1. 截断随机正态分布
  2. 对以下内容的更改对神经网络的损失、输出值的改变
    1. 截断随机正态分布的截断
    2. 损失函数
      1. 平方差
      2. KL散度(相对熵)
    3. 优化器的选择
      1. Adm:

实现代码1:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets(r'C:\Users\liev\Desktop\myproject\yin_test\MNIST_DATA_TensorFlow',one_hot=True)

# 编码类
class EncodeNet:
    def __init__(self):
        # 定义输入层到第一隐层的全连接权重和偏值
        self.encode_w = tf.Variable(tf.truncated_normal([784, 100], stddev=0.01))
        self.encode_b = tf.Variable(tf.zeros([100]))
        # 定义第一隐层delta 输出的权重
        self.log_var_w = tf.Variable(tf.truncated_normal([100,128], stddev=0.01))

        # 定义第一隐层mu 输出的权重
        self.mean_w = tf.Variable(tf.truncated_normal([100,128
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值