原文链接: tensorboard 自编码分类网络和vgg19 网络结构可视化
上一篇: Python 使用http服务器 在局域网内分享文件
下一篇: mobilenet_v2_1.4_224 flowers 数据集分类网络
可视化网络结构,只需要执行下面代码即可
summary_writer = tf.summary.FileWriter('./log/', sess.graph)
tf.train.Saver 用来保存和加载模型参数
saver = tf.train.Saver()
saver.save(sess, './log/save')
saver.restore(sess, CKPT_PATH)
会在指定目录下生成event文件,按照时间排序,会显示最新的event中的网络结构
在event文件夹中执行下面代码,如果event文件有更新会自动重新加载,只需要刷新浏览器即可
tensorboard.exe --logdir=.
自编码分类网络可视化
可以看到第一个loss是自编码loss
第二个是分类网络的loss
import tensorflow as tf
import numpy as np
import tensorflow.contrib.slim as slim
import cv2 as cv
TRAIN_PATH = "D:/data/Digit Recognizer/train.csv"
TEST_PATH = "D:/data/Digit Recognizer/test.csv"
learning_rate = .0001
TRAIN_STEP = 30000
# 大样本
BATCH_SIZE = 512
TEST_SIZE = 512
SHOW_STEP = 100
# 增大特征数据量
IMAGE_SIZE = 32
in_x = tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, 1))
in_y = tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, 1))
in_y2 = tf.placeholder(tf.float32, (None, 10))
kernel = (5, 5)
# 自编码网络
# 浅一点的网络效果稍微好点。。。
# 层数比较多的话提取的信息时需要使用大样本128
# 导致后面全连接分类效果比较差,误差跳动很厉害
# relu 函数比relu6的效果好一点。。。
with slim.arg_scope(
[slim.conv2d, slim.conv2d_transpose],
activation_fn=tf.nn.relu,
# activation_fn=tf.nn.relu6,
):
with tf.variable_scope("encode"):
print(in_x.shape)
net = slim.conv2d(in_x, 32, kernel, stride=2)
print(net.shape)
net = slim.conv2d(net, 16, kernel, stride=2)
print(net.shape)
core = slim.conv2d(net, 8, kernel, stride=2)
print(core.shape)
with tf.variable_scope("decode"):
net = slim.conv2d_transpose(core, 16, kernel, stride=2)
print(net.shape)
net = slim.conv2d_transpose(net, 32, kernel, stride=2)
print(net.shape)
out_y = slim.conv2d_transpose(net, 1, kernel, stride=2)
print(out_y.shape)
# 自编码loss和train
with tf.variable_scope('loss1'):
loss_code = tf.reduce_mean((out_y - in_y) ** 2)
with tf.variable_scope('train1'):
train_code = tf.train.AdamOptimizer(learning_rate).minimize(loss_code)
# 分类网络
with tf.variable_scope("class"):
with slim.arg_scope(
[slim.fully_connected],
activation_fn=tf.nn.relu,
):
net2 = slim.flatten(core)
print(net2.shape)
net2 = slim.fully_connected(net2, 128)
print(net2.shape)
net2 = slim.fully_connected(net2, 64)
print(net2.shape)
net2 = slim.fully_connected(net2, 32)
print(net2.shape)
net2 = slim.fully_connected(net2, 10)
print(net2.shape)
# 交叉熵效果更好
# 均方误差0.6953125
with tf.variable_scope('loss2'):
loss2 = tf.reduce_mean((net2 - in_y2) ** 2)
# 交叉熵可以达到.98
# loss2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=in_y2, logits=net2))
with tf.variable_scope('train2'):
train2 = tf.train.AdamOptimizer(learning_rate).minimize(loss2)
with tf.variable_scope('predict'):
predict = tf.equal(tf.argmax(net2, 1), tf.argmax(in_y2, 1))
with tf.variable_scope('accuracy'):
accuracy = tf.reduce_mean(tf.cast(predict, tf.float32))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
summary_writer = tf.summary.FileWriter('./log/', sess.graph)
展开节点查看具体网络结构
vgg19 网络可视化
下载slim模块,加载网络,可视化网络结构
from nets.vgg import vgg_19
import tensorflow as tf
in_x = tf.placeholder(tf.float32, (None, 224, 224, 3))
logits, end_points = vgg_19(in_x)
with tf.Session()as sess:
sess.run(tf.global_variables_initializer())
summary_writer = tf.summary.FileWriter('./log/', sess.graph)
可以看到各个节点的输出