import os
import numpy as np
import tensorflow as tf
from scipy import io
from tensorflow.examples.tutorials.mnist import input_data
# 1. Hyperparameters
learning_rate = 0.001
epochs = 10  # NOTE(review): declared but not referenced by train() below
batch_size = 128
test_valid_size = 512  # number of samples used for validation/testing
n_classes = 10
keep_probab = 0.75  # NOTE(review): unused — the graph uses placeholder_with_default(0.75)
def conv2d_block(input_tensor, filter_w, filter_b, stride=1):
    """Convolution, then bias add, then ReLU6 activation.

    :param input_tensor: 4-D input batch (NHWC layout implied by the strides).
    :param filter_w: convolution kernel variable.
    :param filter_b: bias variable, one value per output channel.
    :param stride: spatial stride, applied to both height and width.
    :return: activated feature map ('SAME' padding preserves spatial size
        when stride == 1).
    """
    strides = [1, stride, stride, 1]
    out = tf.nn.conv2d(
        input=input_tensor, filter=filter_w, strides=strides, padding='SAME'
    )
    out = tf.nn.bias_add(out, filter_b)
    return tf.nn.relu6(out)
def maxpool(input_tensor, k=2):
    """Max-pool with a k x k window and matching stride ('SAME' padding).

    :param input_tensor: 4-D feature map (NHWC).
    :param k: pooling window size and stride; k=2 halves height and width.
    :return: pooled tensor.
    """
    window = [1, k, k, 1]
    return tf.nn.max_pool(
        value=input_tensor, ksize=window, strides=window, padding='SAME'
    )
def model(input_tensor, keep_prob, pre_trained_weights=None):
    """Build the CNN forward graph: 2 x (conv + pool) -> FC -> logits.

    :param input_tensor: placeholder for input images; the layer shape
        comments below assume [N, 28, 28, 1] (MNIST) — the fc1 weight shape
        7*7*64 only works for 28x28 inputs.
    :param keep_prob: placeholder for the dropout keep probability.
    :param pre_trained_weights: optional dict of numpy arrays keyed by the
        original TF variable names ('w_conv1:0' ... 'b_logits:0'), e.g. as
        produced by scipy.io.loadmat on a file written by store_weights.
        When given, the conv layers are restored frozen (trainable=False)
        and the FC/logits layers keep training (transfer-learning style).
    :return: (logits, prediction) — raw class scores of shape
        [N, n_classes] and the per-sample argmax class index.
    """
    """
    'w_conv1:0', 'w_conv2:0', 'w_fc1:0', 'w_logits:0',
    'b_conv1:0', 'b_conv2:0', 'b_fc1:0', 'b_logits:0']
    """
    if pre_trained_weights:
        W = pre_trained_weights
        # Conv weights restored and frozen; FC/logits weights restored but
        # still trainable.
        weights = {
            'conv1': tf.get_variable('w_conv1', dtype=tf.float32,
                                     initializer=W['w_conv1:0'], trainable=False),
            'conv2': tf.get_variable('w_conv2', dtype=tf.float32,
                                     initializer=W['w_conv2:0'], trainable=False),
            'fc1': tf.get_variable('w_fc1', dtype=tf.float32,
                                   initializer=W['w_fc1:0'], trainable=True),
            'logits': tf.get_variable('w_logits', dtype=tf.float32,
                                      initializer=W['w_logits:0'], trainable=True),
        }
        # Bias vectors come back from .mat as 2-D arrays, hence the
        # reshape(..., -1) back to 1-D.
        # NOTE(review): fc1/logits biases are re-initialized to zeros here
        # instead of being loaded from W — confirm this is intentional
        # (weights restored but biases reset for the trainable layers).
        biases = {
            'conv1': tf.get_variable('b_conv1', dtype=tf.float32,
                                     initializer=np.reshape(W['b_conv1:0'], -1), trainable=False),
            'conv2': tf.get_variable('b_conv2', dtype=tf.float32,
                                     initializer=np.reshape(W['b_conv2:0'], -1), trainable=False),
            'fc1': tf.get_variable('b_fc1', shape=[1024], dtype=tf.float32,
                                   initializer=tf.zeros_initializer()),
            'logits': tf.get_variable('b_logits', shape=[n_classes], dtype=tf.float32,
                                      initializer=tf.zeros_initializer()),
        }
    else:
        # Fresh variables: truncated-normal weights, zero biases.
        weights = {
            'conv1': tf.get_variable('w_conv1', shape=[5, 5, 1, 32], dtype=tf.float32,
                                     initializer=tf.truncated_normal_initializer(stddev=0.1)),
            'conv2': tf.get_variable('w_conv2', shape=[5, 5, 32, 64], dtype=tf.float32,
                                     initializer=tf.truncated_normal_initializer(stddev=0.1)),
            'fc1': tf.get_variable('w_fc1', shape=[7 * 7 * 64, 1024], dtype=tf.float32,
                                   initializer=tf.truncated_normal_initializer(stddev=0.1)),
            'logits': tf.get_variable('w_logits', shape=[1024, n_classes], dtype=tf.float32,
                                      initializer=tf.truncated_normal_initializer(stddev=0.1)),
        }
        biases = {
            'conv1': tf.get_variable('b_conv1', shape=[32], dtype=tf.float32,
                                     initializer=tf.zeros_initializer()),
            'conv2': tf.get_variable('b_conv2', shape=[64], dtype=tf.float32,
                                     initializer=tf.zeros_initializer()),
            'fc1': tf.get_variable('b_fc1', shape=[1024], dtype=tf.float32,
                                   initializer=tf.zeros_initializer()),
            'logits': tf.get_variable('b_logits', shape=[n_classes], dtype=tf.float32,
                                      initializer=tf.zeros_initializer()),
        }
    # 1. Conv 1: [N, 28, 28, 1] -> [N, 28, 28, 32]
    conv1 = conv2d_block(
        input_tensor=input_tensor, filter_w=weights['conv1'], filter_b=biases['conv1']
    )
    # 2. Pool 1: [N, 28, 28, 32] -> [N, 14, 14, 32]
    pool1 = maxpool(conv1, k=2)
    # 3. Conv 2: [N, 14, 14, 32] -> [N, 14, 14, 64]
    conv2 = conv2d_block(
        input_tensor=pool1, filter_w=weights['conv2'], filter_b=biases['conv2']
    )
    conv2 = tf.nn.dropout(conv2, keep_prob=keep_prob)
    # 4. Pool 2: [N, 14, 14, 64] -> [N, 7, 7, 64]
    pool2 = maxpool(conv2, k=2)
    # 5. Flatten: [N, 7, 7, 64] -> [N, 7*7*64]
    # (static-shape Dimension arithmetic; requires fixed H/W/C at graph time)
    x_shape = pool2.get_shape()
    flatten_shape = x_shape[1] * x_shape[2] * x_shape[3]
    flatted = tf.reshape(pool2, shape=[-1, flatten_shape])
    # 6. FC1 fully-connected layer with ReLU6 + dropout.
    fc1 = tf.nn.relu6(tf.matmul(flatted, weights['fc1']) + biases['fc1'])
    fc1 = tf.nn.dropout(fc1, keep_prob=keep_prob)
    # 7. Logits layer (no activation; softmax is applied in the loss).
    logits = tf.add(tf.matmul(fc1, weights['logits']), biases['logits'])
    with tf.variable_scope('prediction'):
        prediction = tf.argmax(logits, axis=1)
    return logits, prediction
def create_dir_path(path):
    """Create directory *path* (including parents) if it does not exist.

    BUG FIX: the original ``os.path.exists`` check followed by
    ``os.makedirs`` had a check-then-create race — another process creating
    the directory in between made ``makedirs`` raise. ``exist_ok=True``
    makes the create step itself tolerant, and the function idempotent.

    :param path: directory path to ensure exists.
    """
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)
        print('create file path:{}'.format(path))
def store_weights(sess, save_path):
    """Dump all trainable variables of the current graph to a MATLAB .mat file.

    Each variable's TF name (e.g. 'w_conv1:0') becomes its key in the .mat
    dict, so model() can later restore by the same names.

    :param sess: active tf.Session with initialized variables.
    :param save_path: destination .mat file path.
    """
    # 1. Collect the variables to persist (trainable ones only; switch to
    #    tf.global_variables() to capture every variable).
    var_list = tf.trainable_variables()
    # 2. Fetch all their values in a single run call.
    var_values = sess.run(var_list)
    # 3. Map variable name -> numpy value.
    mdict = {var.name: value for var, value in zip(var_list, var_values)}
    # 4. Persist in MATLAB format.
    io.savemat(save_path, mdict)
    print('Saved Vars to files:{}'.format(save_path))
def train():
    """Build the MNIST CNN graph and train until the train-batch accuracy
    exceeds 0.99, checkpointing weights to .mat every 100 steps.

    Resumes from the newest saved .mat weight file in the checkpoint
    directory when one is present; otherwise trains from scratch.
    """
    # Checkpoint directory — also where we look for weights to resume from.
    checkpoint_dir = './model/mnist/matlab/ai20'
    create_dir_path(checkpoint_dir)

    graph = tf.Graph()
    with graph.as_default():
        # 1. Placeholders.
        x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x')
        y = tf.placeholder(tf.float32, [None, 10], name='y')
        # NOTE(review): keep_prob is never fed in the loop below, so dropout
        # stays at the 0.75 default even while accuracy is measured on the
        # training batch — confirm that is intended.
        keep_prob = tf.placeholder_with_default(0.75, shape=None, name='keep_prob')

        # 2. Build the model, restoring from the newest .mat dump if any.
        # BUG FIX: the original indexed os.listdir(...)[0] unconditionally.
        # With an empty directory (the normal first run right after
        # create_dir_path) the `if files:` branch was skipped entirely, and
        # with a first entry that was not a regular .mat file neither branch
        # bound `logits`/`prediction` -> NameError further down.
        mat_files = sorted(
            name for name in os.listdir(checkpoint_dir)
            if name.endswith('.mat')
            and os.path.isfile(os.path.join(checkpoint_dir, name))
        )
        if mat_files:
            # Names sort so the last entry is the highest-accuracy dump.
            mdict = io.loadmat(os.path.join(checkpoint_dir, mat_files[-1]))
            logits, prediction = model(x, keep_prob, pre_trained_weights=mdict)
            print('Load old model continue to train!')
        else:
            logits, prediction = model(x, keep_prob)
            print('No old model, train from scratch!')

        # 3. Softmax cross-entropy loss.
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits, labels=y
        ))
        # Optimizer.
        optimizer = tf.train.AdamOptimizer(learning_rate)
        train_opt = optimizer.minimize(loss)
        # Train-batch accuracy.
        correct_pred = tf.equal(tf.argmax(y, axis=1), prediction)
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        mnist = input_data.read_data_sets(
            '../datas/mnist', one_hot=True, reshape=False
        )
        step = 1
        while True:
            # One optimization step on a fresh mini-batch.
            batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
            feed = {x: batch_x, y: batch_y}
            _, train_loss, train_acc = sess.run([train_opt, loss, accuracy], feed)
            print('Step:{} - Train Loss:{:.5f} - Train acc:{:.5f}'.format(
                step, train_loss, train_acc
            ))
            # Persist weights periodically so a later run can actually
            # resume (this was commented out in the original, which made
            # the restore path above dead code).
            if step % 100 == 0:
                file_name = 'model_{:.3f}.mat'.format(train_acc)
                store_weights(sess, save_path=os.path.join(checkpoint_dir, file_name))
            step += 1
            # Stop once the train-batch accuracy is good enough.
            if train_acc > 0.99:
                break
# Script entry point: run training when executed directly.
if __name__ == '__main__':
    train()
# --- Blog-post metadata accidentally pasted into the source (kept as a
# comment so the file stays valid Python):
# 08_04: Saving model parameters as .mat on the MNIST handwritten-digit dataset
# Latest recommended article published 2022-07-17 17:40:01