# NOTE: this code targets TensorFlow 1.x; tf.contrib, tf.layers, and the
# tensorflow.examples.tutorials.mnist module were all removed in TensorFlow 2.
import tensorflow as tf
from tflearn.layers.conv import global_avg_pool
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.layers import batch_norm, flatten
from tensorflow.contrib.framework import arg_scope
import numpy as np

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Hyperparameters
growth_k = 12
nb_block = 2  # number of (dense block + transition layer) pairs
init_learning_rate = 1e-4
epsilon = 1e-8  # AdamOptimizer epsilon
dropout_rate = 0.2

# Used by the Momentum optimizer
nesterov_momentum = 0.9
weight_decay = 1e-4

# Labels & batch size
class_num = 10
batch_size = 100

total_epochs = 50
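Here growth_k is the growth rate k from the DenseNet paper: every layer inside a dense block emits k new feature maps, and each layer receives the concatenation of all earlier outputs in the block. As a quick back-of-the-envelope check of how channels accumulate (the stem width and layer count below are illustrative assumptions, not values from the article):

in_channels = 2 * growth_k       # a common choice of stem-convolution width (assumed)
layers_in_block = 4              # hypothetical number of layers in one dense block
out_channels = in_channels + layers_in_block * growth_k  # 24 + 4 * 12 = 72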
def conv_layer(input, filter, kernel, stride=1, layer_name="conv"):
    # A thin wrapper around tf.layers.conv2d with SAME padding.
    with tf.name_scope(layer_name):
        network = tf.layers.conv2d(inputs=input, filters=filter, kernel_size=kernel, strides=stride, padding='SAME')
        return network
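As an illustrative use of this wrapper (not shown in the excerpt): a 7x7, stride-2 stem convolution producing 2 * growth_k feature maps is common DenseNet practice, with input_x standing for an assumed MNIST image placeholder.

input_x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])  # assumed NHWC input
x = conv_layer(input_x, filter=2 * growth_k, kernel=[7, 7], stride=2, layer_name='conv0')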
def Global_Average_Pooling(x, stride=1):
    """Global average pooling via tflearn.

    The same thing without tflearn (the stride value does not matter,
    because the pool window already covers the whole feature map):
        width = np.shape(x)[1]
        height = np.shape(x)[2]
        pool_size = [width, height]
        return tf.layers.average_pooling2d(inputs=x, pool_size=pool_size, strides=stride)
    """
    # Note: tflearn may additionally require h5py and curses to be installed.
    return global_avg_pool(x, name='Global_avg_pooling')
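If you would rather drop the tflearn dependency entirely, a single reduce_mean over the spatial axes is equivalent; a minimal sketch (the helper name is mine, not the article's):

def global_avg_pool_no_tflearn(x):
    # Average over the H and W axes of an NHWC tensor: [N, H, W, C] -> [N, C].
    return tf.reduce_mean(x, axis=[1, 2])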
def Batch_Normalization(x, training, scope):
    with arg_scope([batch_norm],
                   scope=scope,
                   updates_collections=None,
                   decay=0.9,
                   center=True,
                   scale=True,
                   zero_debias_moving_mean=True):
        # tf.cond selects the training branch (which updates the moving
        # statistics) or the inference branch, which reuses the same
        # variables (reuse=True) with is_training switched off by the cond.
        return tf.cond(training,
                       lambda: batch_norm(inputs=x, is_training=training, reuse=None),
                       lambda: batch_norm(inputs=x, is_training=training, reuse=True))
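The excerpt ends here. For orientation, the sketch below shows how these helpers typically compose into a DenseNet bottleneck layer and a dense block; the function names, the training placeholder, and the 4 * growth_k bottleneck width are assumptions based on standard DenseNet-B practice, not code from the article.

def bottleneck_layer(x, training, scope):
    # BN -> ReLU -> 1x1 conv -> BN -> ReLU -> 3x3 conv, each conv followed by dropout.
    with tf.name_scope(scope):
        x = Batch_Normalization(x, training=training, scope=scope + '_batch1')
        x = tf.nn.relu(x)
        x = conv_layer(x, filter=4 * growth_k, kernel=[1, 1], layer_name=scope + '_conv1')
        x = tf.layers.dropout(inputs=x, rate=dropout_rate, training=training)
        x = Batch_Normalization(x, training=training, scope=scope + '_batch2')
        x = tf.nn.relu(x)
        x = conv_layer(x, filter=growth_k, kernel=[3, 3], layer_name=scope + '_conv2')
        x = tf.layers.dropout(inputs=x, rate=dropout_rate, training=training)
        return x

def dense_block(input_x, nb_layers, training, layer_name):
    # Dense connectivity: each bottleneck consumes the channel-wise concatenation
    # of the block input and every earlier bottleneck output in the block.
    with tf.name_scope(layer_name):
        layers_concat = [input_x]
        x = input_x
        for i in range(nb_layers):
            x = bottleneck_layer(x, training, scope=layer_name + '_bottleneck_' + str(i))
            layers_concat.append(x)
            x = tf.concat(layers_concat, axis=3)
        return x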
A TensorFlow implementation of DenseNet, trained on the MNIST dataset
This article details how to implement the DenseNet deep learning model with the TensorFlow framework and train it on the classic MNIST handwritten-digit dataset, using DenseNet's dense connectivity to improve the model's recognition performance.