# 08_04: MNIST handwritten-digit CNN — saving model parameters to a MATLAB .mat file


import os
import numpy as np
import tensorflow as tf
from scipy import io
from tensorflow.examples.tutorials.mnist import input_data


# 1. Hyperparameters.
learning_rate = 0.001  # Adam learning rate
epochs = 10  # NOTE(review): defined but never read by train() below — confirm intended
batch_size = 128  # mini-batch size for training
test_valid_size = 512  # number of samples used for validation/testing
n_classes = 10  # MNIST has 10 digit classes
keep_probab = 0.75  # dropout keep probability (NOTE(review): train() relies on the placeholder default instead)


def conv2d_block(input_tensor, filter_w, filter_b, stride=1):
    """Apply a 2-D convolution, add a per-channel bias, then ReLU6.

    :param input_tensor: 4-D input tensor in NHWC layout.
    :param filter_w: convolution kernel variable.
    :param filter_b: bias variable, one value per output channel.
    :param stride: spatial stride, applied to both height and width.
    :return: the activated feature map ('SAME' padding keeps H and W).
    """
    stride_spec = [1, stride, stride, 1]
    out = tf.nn.conv2d(
        input=input_tensor, filter=filter_w, strides=stride_spec, padding='SAME'
    )
    out = tf.nn.bias_add(out, filter_b)
    return tf.nn.relu6(out)


def maxpool(input_tensor, k=2):
    """Down-sample with non-overlapping k x k max pooling ('SAME' padding).

    :param input_tensor: 4-D feature map in NHWC layout.
    :param k: pooling window size; also used as the stride.
    :return: the pooled tensor.
    """
    window = [1, k, k, 1]
    return tf.nn.max_pool(
        value=input_tensor, ksize=window, strides=window, padding='SAME'
    )


def model(input_tensor, keep_prob, pre_trained_weights=None):
    """
    Build the CNN graph: conv1 -> pool1 -> conv2(+dropout) -> pool2 -> flatten
    -> fc1(+dropout) -> logits.

    :param input_tensor: placeholder for the input images, NHWC, [N, 28, 28, 1].
    :param keep_prob: placeholder holding the dropout keep probability.
    :param pre_trained_weights: optional dict (as loaded by scipy.io.loadmat)
        keyed by variable names such as 'w_conv1:0'. When given, the conv
        weights/biases are initialized from it and frozen (trainable=False),
        the fc/logits weights are warm-started but stay trainable, and the
        fc/logits biases are re-initialized to zeros.
    :return: tuple (logits, prediction): logits is the [N, n_classes]
        pre-softmax output, prediction the argmax class index per sample.
    """
    """
    'w_conv1:0', 'w_conv2:0', 'w_fc1:0', 'w_logits:0', 
    'b_conv1:0', 'b_conv2:0', 'b_fc1:0', 'b_logits:0']
    """
    if pre_trained_weights:
        W = pre_trained_weights
        # Warm start: conv layers are frozen, fc/logits layers keep training.
        weights = {
            'conv1': tf.get_variable('w_conv1', dtype=tf.float32,
                                     initializer=W['w_conv1:0'], trainable=False),
            'conv2': tf.get_variable('w_conv2', dtype=tf.float32,
                                     initializer=W['w_conv2:0'], trainable=False),
            'fc1': tf.get_variable('w_fc1', dtype=tf.float32,
                                   initializer=W['w_fc1:0'], trainable=True),
            'logits': tf.get_variable('w_logits', dtype=tf.float32,
                                      initializer=W['w_logits:0'], trainable=True),
        }
        # loadmat returns arrays as 2-D; reshape(-1) flattens the biases back
        # to rank 1 before using them as initializers.
        # NOTE(review): b_fc1/b_logits are reset to zeros here instead of
        # being restored from the .mat file — confirm this is intended.
        biases = {
            'conv1': tf.get_variable('b_conv1', dtype=tf.float32,
                                     initializer=np.reshape(W['b_conv1:0'], -1), trainable=False),
            'conv2': tf.get_variable('b_conv2', dtype=tf.float32,
                                     initializer=np.reshape(W['b_conv2:0'], -1), trainable=False),
            'fc1': tf.get_variable('b_fc1', shape=[1024], dtype=tf.float32,
                                   initializer=tf.zeros_initializer()),
            'logits': tf.get_variable('b_logits', shape=[n_classes], dtype=tf.float32,
                                      initializer=tf.zeros_initializer()),
        }
    else:
        # Train from scratch: truncated-normal weights, zero biases.
        weights = {
            'conv1': tf.get_variable('w_conv1', shape=[5, 5, 1, 32], dtype=tf.float32,
                                     initializer=tf.truncated_normal_initializer(stddev=0.1)),
            'conv2': tf.get_variable('w_conv2', shape=[5, 5, 32, 64], dtype=tf.float32,
                                     initializer=tf.truncated_normal_initializer(stddev=0.1)),
            'fc1': tf.get_variable('w_fc1', shape=[7 * 7 * 64, 1024], dtype=tf.float32,
                                   initializer=tf.truncated_normal_initializer(stddev=0.1)),
            'logits': tf.get_variable('w_logits', shape=[1024, n_classes], dtype=tf.float32,
                                      initializer=tf.truncated_normal_initializer(stddev=0.1)),
        }
        biases = {
            'conv1': tf.get_variable('b_conv1', shape=[32], dtype=tf.float32,
                                     initializer=tf.zeros_initializer()),
            'conv2': tf.get_variable('b_conv2', shape=[64], dtype=tf.float32,
                                     initializer=tf.zeros_initializer()),
            'fc1': tf.get_variable('b_fc1', shape=[1024], dtype=tf.float32,
                                   initializer=tf.zeros_initializer()),
            'logits': tf.get_variable('b_logits', shape=[n_classes], dtype=tf.float32,
                                      initializer=tf.zeros_initializer()),
        }

    # 1. conv1: [N, 28, 28, 1]  ---> [N, 28, 28, 32]
    conv1 = conv2d_block(
        input_tensor=input_tensor, filter_w=weights['conv1'], filter_b=biases['conv1']
    )

    # 2. pool1: [N, 28, 28, 32] ---> [N, 14, 14, 32]
    pool1 = maxpool(conv1, k=2)

    # 3. conv2: [N, 14, 14, 32] ---> [N, 14, 14, 64]
    conv2 = conv2d_block(
        input_tensor=pool1, filter_w=weights['conv2'], filter_b=biases['conv2']
    )
    conv2 = tf.nn.dropout(conv2, keep_prob=keep_prob)

    # 4. pool2: [N, 14, 14, 64] ---> [N, 7, 7, 64]
    pool2 = maxpool(conv2, k=2)

    # 5. flatten: [N, 7, 7, 64] ---> [N, 7*7*64]
    x_shape = pool2.get_shape()
    flatten_shape = x_shape[1] * x_shape[2] * x_shape[3]
    flatted = tf.reshape(pool2, shape=[-1, flatten_shape])

    # 6. fc1: fully connected layer with ReLU6 and dropout.
    fc1 = tf.nn.relu6(tf.matmul(flatted, weights['fc1']) + biases['fc1'])
    fc1 = tf.nn.dropout(fc1, keep_prob=keep_prob)

    # 7. logits layer (no activation) plus per-sample argmax prediction.
    logits = tf.add(tf.matmul(fc1, weights['logits']), biases['logits'])
    with tf.variable_scope('prediction'):
        prediction = tf.argmax(logits, axis=1)

    return logits, prediction


def create_dir_path(path):
    """Create directory *path* (including parents) if it does not exist.

    Uses ``os.makedirs(..., exist_ok=True)`` so the call cannot crash on the
    race where the directory appears between the existence check and the
    creation (the original exists-then-makedirs pair was TOCTOU-unsafe).

    :param path: directory path to create.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
        print('create file path:{}'.format(path))


def store_weights(sess, save_path):
    """Persist the session's trainable variables to a MATLAB .mat file.

    :param sess: active tf.Session in which the variables are evaluated.
    :param save_path: destination path of the .mat file.
    """
    # 1. Collect the variables to persist (only trainable ones; frozen
    #    warm-start layers are skipped).
    # vars_list = tf.global_variables()
    trainable_vars = tf.trainable_variables()

    # 2. Evaluate all of them in a single run call.
    evaluated = sess.run(trainable_vars)

    # 3. Map each variable's graph name (e.g. 'w_conv1:0') to its value,
    # 4. then save the dict in MATLAB format.
    mdict = {var.name: value for var, value in zip(trainable_vars, evaluated)}
    io.savemat(save_path, mdict)
    print('Saved Vars to files:{}'.format(save_path))


def train():
    """Train the CNN on MNIST until training accuracy exceeds 0.99.

    If a previously saved .mat weight file is found in the checkpoint
    directory, the model is warm-started from it; otherwise training starts
    from scratch.

    Bug fixed: the original assigned ``weight_file`` only inside ``if files:``
    but dereferenced it unconditionally afterwards, so a fresh run (the
    directory is created just above and is empty) raised ``NameError``.
    """
    # Make sure the persistence directory exists.
    checkpoint_dir = './model/mnist/matlab/ai20'
    create_dir_path(checkpoint_dir)

    graph = tf.Graph()
    with graph.as_default():
        # 1. Placeholders.
        x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x')
        y = tf.placeholder(tf.float32, [None, 10], name='y')
        keep_prob = tf.placeholder_with_default(0.75, shape=None, name='keep_prob')

        # 2. Build the model graph, warm-starting when a saved weight file
        #    exists. Guarded lookup avoids the NameError on an empty dir.
        weights_path = './model/mnist/matlab/ai20'
        weight_file = None
        files = os.listdir(weights_path)
        if files:
            candidate = os.path.join(weights_path, files[0])
            if os.path.isfile(candidate):
                weight_file = candidate
        if weight_file:
            mdict = io.loadmat(weight_file)
            logits, prediction = model(x, keep_prob, pre_trained_weights=mdict)
            print('Load old model continue to train!')
        else:
            logits, prediction = model(x, keep_prob)
            print('No old model, train from scratch!')

        # 3. Softmax cross-entropy loss, averaged over the batch.
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits, labels=y
        ))

        # Optimizer.
        optimizer = tf.train.AdamOptimizer(learning_rate)
        train_opt = optimizer.minimize(loss)

        # Accuracy: fraction of argmax predictions matching the one-hot labels.
        correct_pred = tf.equal(tf.argmax(y, axis=1), prediction)
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())

        mnist = input_data.read_data_sets(
            '../datas/mnist', one_hot=True, reshape=False
        )

        step = 1
        while True:
            # One optimization step on a fresh mini-batch.
            batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
            feed = {x: batch_x, y: batch_y}
            _, train_loss, train_acc = sess.run([train_opt, loss, accuracy], feed)
            print('Step:{} - Train Loss:{:.5f} - Train acc:{:.5f}'.format(
                step, train_loss, train_acc
            ))

            # Persistence (disabled; re-enable to write the .mat snapshots
            # that the warm-start branch above can load).
            # if step % 100 == 0:
            #     files = 'model_{:.3f}.mat'.format(train_acc)
            #     save_file = os.path.join(checkpoint_dir, files)
            #     store_weights(sess, save_path=save_file)
            step += 1

            # Stop once training accuracy passes the threshold.
            if train_acc > 0.99:
                break


# Script entry point: run training when executed directly.
if __name__ == '__main__':
    train()
# (End of source. CSDN web-page boilerplate — comments/payment widget text —
# that was accidentally pasted in with the article has been removed; it was
# not part of the program.)