# 08_03 基于matlab_VGG的17花_迁移学习



import os
import numpy as np
import tensorflow as tf
from scipy import io
from tflearn.datasets import oxflower17

"""
基于数据集17类花,基于预训练VGG19模型进行迁移学习。

# mat 是指 MATLAB 的数据文件格式,这里用于存储预训练权重
权重地址: http://www.vlfeat.org/matconvnet/pretrained/#pretrained-models
"""

# Command-line flags that can be overridden when the script is launched.
# The help strings below state the actual default values; the batch_size and
# learning_rate descriptions previously advertised 2 and 0.001 respectively.
tf.app.flags.DEFINE_string(
    'checkpoint_dir', './model/oxflower17', '模型持久化文件路径,默认为:./model/oxflower17'
)
tf.app.flags.DEFINE_integer(
    'batch_size', 4, '批量的大小,默认为:4'
)
tf.app.flags.DEFINE_float(
    'learning_rate', 0.0001, '学习率,默认为:0.0001'
)
tf.app.flags.DEFINE_bool(
    'is_train', True, '给定模型是否训练操作,True表示训练,False表示预测'
)
# Shared handle through which the rest of the file reads the flag values.
FLAGS = tf.app.flags.FLAGS


def create_dir_path(path):
    """
    Ensure that the directory ``path`` exists, creating missing parents as needed.

    :param path: directory path to create.
    :return: None
    :raises OSError: if the directory cannot be created (for example when
        ``path`` already exists but is not a directory).
    """
    # exist_ok=True removes the check-then-create race of the
    # ``if not exists: makedirs`` pattern: a directory created concurrently
    # between the check and the call no longer raises FileExistsError.
    os.makedirs(path, exist_ok=True)


def get_weights_biases(vgg_layers, i):
    """
    Create frozen variables initialised from the pretrained parameters of one layer.

    :param vgg_layers: indexable collection of pretrained layers; the parameters
        of layer ``i`` are read from ``vgg_layers[i][0][0][2][0]``
        (presumably the ``layers`` array of the loaded .mat file - confirm against caller).
    :param i: index of the layer whose parameters are used.
    :return: ``(weight, bias)`` - two non-trainable float32 ``tf.Variable`` objects;
        the bias values are flattened to 1-D before the variable is created.
    """
    layer_params = vgg_layers[i][0][0][2][0]
    kernel_values = layer_params[0]
    bias_values = layer_params[1]

    # trainable=False keeps the pretrained values fixed during optimisation.
    kernel_var = tf.Variable(
        initial_value=kernel_values, dtype=tf.float32, trainable=False
    )
    bias_var = tf.Variable(
        initial_value=np.reshape(bias_values, -1), dtype=tf.float32, trainable=False
    )
    return kernel_var, bias_var

def build_net(ntype, input, weights_bias=None):
    """
    Build a single layer: convolution + bias + ReLU, or 2x2 max pooling.

    :param ntype: ``'conv'`` selects the convolution branch; any other value
        selects max pooling.
    :param input: input tensor of the layer.
    :param weights_bias: pair ``(kernel, bias)`` used by the convolution branch;
        not used for pooling.
    :return: output tensor of the layer.
    """
    if ntype != 'conv':
        # Pooling: 2x2 window, stride 2, 'SAME' padding.
        return tf.nn.max_pool(
            value=input, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME'
        )

    # Convolution with stride 1 and 'SAME' padding, then bias and ReLU.
    conv_out = tf.nn.conv2d(input, weights_bias[0], strides=[1, 1, 1, 1], padding='SAME')
    biased = tf.nn.bias_add(conv_out, weights_bias[1])
    return tf.nn.relu(biased)


def model_vgg(input_x, vgg_raw_weights, num_classes=17):
    """
    Assemble the network: frozen pretrained conv/pool stack followed by a new
    fully connected classifier.

    :param input_x: input image tensor fed to the first convolution.
    :param vgg_raw_weights: mapping with a ``'layers'`` entry holding the
        pretrained parameters (``vgg_raw_weights['layers'][0]`` is indexed per layer).
    :param num_classes: number of output classes of the final layer.
    :return: ``(logits, prediction)`` - the raw class scores and their argmax
        along axis 1.
    """
    vgg_layers = vgg_raw_weights['layers'][0]

    # (scope name, index into vgg_layers); ``None`` marks a pooling layer.
    # Each pooling layer uses stride 2, so it halves the spatial size
    # (224 -> 112 -> 56 -> 28 -> 14, assuming a 224x224 input).
    layer_plan = (
        ('Conv1_1', 0), ('Conv1_2', 2), ('Pooling1', None),
        ('Conv2_1', 5), ('Conv2_2', 7), ('Pooling2', None),
        ('Conv3_1', 10), ('Conv3_2', 12), ('Conv3_3', 14), ('Conv3_4', 16),
        ('Pooling3', None),
        ('Conv4_1', 19), ('Conv4_2', 21), ('Conv4_3', 23), ('Conv4_4', 25),
        ('Pooling4', None),
    )

    with tf.variable_scope('Network'):
        # Every layer output is stored in this dict under its scope name.
        net = {}
        with tf.variable_scope('input'):
            net['input'] = input_x

        previous = net['input']
        for scope_name, layer_idx in layer_plan:
            with tf.variable_scope(scope_name):
                if layer_idx is None:
                    net[scope_name] = build_net('pool', previous)
                else:
                    net[scope_name] = build_net(
                        'conv', previous, get_weights_biases(vgg_layers, layer_idx)
                    )
            previous = net[scope_name]

        # Newly built fully connected head used for classification.
        with tf.variable_scope('FC', initializer=tf.truncated_normal_initializer(stddev=0.1)):
            features = net['Pooling4']
            with tf.variable_scope('fc1'):
                dims = features.get_shape()
                flat_size = dims[1] * dims[2] * dims[3]
                features = tf.reshape(features, shape=[-1, flat_size])
                fc1_w = tf.get_variable('w', shape=[flat_size, 500])
                fc1_b = tf.get_variable('b', shape=[500], initializer=tf.zeros_initializer())
                hidden = tf.matmul(features, fc1_w) + fc1_b
                hidden = tf.nn.relu(hidden)

            with tf.variable_scope('logits'):
                out_w = tf.get_variable('w', shape=[500, num_classes])
                out_b = tf.get_variable('b', shape=[num_classes], initializer=tf.zeros_initializer())
                logits = tf.matmul(hidden, out_w) + out_b

            with tf.variable_scope('predictions'):
                prediction = tf.argmax(logits, axis=1)
    return logits, prediction

def create_loss(labels, logits):
    """
    Build the loss: mean softmax cross-entropy between labels and logits.

    :param labels: target tensor passed as ``labels`` to the cross-entropy op.
    :param logits: unnormalised class scores.
    :return: scalar tensor holding the mean cross-entropy.
    """
    with tf.name_scope('loss'):
        per_example = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits, labels=labels
        )
        mean_loss = tf.reduce_mean(per_example)
        return mean_loss


def create_optimizer(loss, lr=0.001):
    """
    Build the training op that minimises ``loss`` with the Adam optimizer.

    :param loss: scalar loss tensor to minimise.
    :param lr: learning rate handed to ``tf.train.AdamOptimizer``.
    :return: the op returned by ``optimizer.minimize``.
    """
    # Use a dedicated name scope. This function previously reused 'loss', the
    # scope name already taken by create_loss, which grouped the optimizer ops
    # under a misleading loss-like scope in the graph visualisation.
    # NOTE(review): op names created in this scope change accordingly - confirm
    # nothing restores older checkpoints by these names.
    with tf.name_scope('optimizer'):
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        train_opt = optimizer.minimize(loss=loss)
        return train_opt

def create_accuracy(labels, predictions):
    """
    Build the accuracy metric.

    :param labels: label tensor; the true class is taken as its argmax along axis 1.
    :param predictions: predicted class indices compared against the true classes.
    :return: scalar tensor, the fraction of positions where both agree.
    """
    with tf.name_scope('accuracy'):
        true_classes = tf.argmax(labels, axis=1)
        # 1.0 where prediction and label agree, 0.0 elsewhere; the mean is the accuracy.
        hits = tf.cast(tf.equal(true_classes, predictions), tf.float32)
        return tf.reduce_mean(hits)


def train():
    """
    Train the classifier head on the 17-flower data set and save a checkpoint.

    Builds the graph (placeholders, model, loss, optimizer, accuracy), loads
    the data, then runs epochs of shuffled mini-batches until the mean training
    accuracy of an epoch exceeds 0.93, at which point the model is saved to
    ``FLAGS.checkpoint_dir`` and training stops.

    :return: None
    :raises ValueError: if ``FLAGS.batch_size`` is larger than the number of
        samples, so that no complete batch can be formed.
    """
    # Make sure the checkpoint directory exists.
    create_dir_path(FLAGS.checkpoint_dir)

    graph = tf.Graph()
    with graph.as_default():
        # 1. Placeholders for the images and the one-hot labels.
        x = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name='x')
        y = tf.placeholder(tf.float32, shape=[None, 17], name='y')

        # 2. Load the pretrained weights from the .mat file.
        data_path = '../datas/vgg/imagenet-vgg-verydeep-19.mat'
        mdict = io.loadmat(data_path)

        # 3. Model.
        logits, prediction = model_vgg(input_x=x, vgg_raw_weights=mdict)
        # 4. Loss.
        loss = create_loss(labels=y, logits=logits)
        # 5. Optimizer.
        train_opt = create_optimizer(loss, FLAGS.learning_rate)
        # 6. Accuracy.
        accuracy = create_accuracy(y, prediction)

        saver = tf.train.Saver(max_to_keep=1)

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())

        # Load the training data.
        X, Y = oxflower17.load_data(
            dirname='../datas/17flower', resize_pics=(224, 224), shuffle=True, one_hot=True
        )
        total_sample = X.shape[0]
        total_batch = total_sample // FLAGS.batch_size
        if total_batch == 0:
            # Without this guard the epoch loop below would spin forever,
            # because no batch would ever run and the stop test never passes.
            raise ValueError(
                'batch_size ({}) is larger than the number of samples ({})'.format(
                    FLAGS.batch_size, total_sample
                )
            )

        # Write the graph for visualisation; with the default checkpoint_dir
        # this is the same './model/oxflower17/graph' location as before.
        log_dir = os.path.join(FLAGS.checkpoint_dir, 'graph')
        writer = tf.summary.FileWriter(logdir=log_dir, graph=sess.graph)
        try:
            step = 1
            while True:
                epoch_acc_sum = 0.0
                # Fresh sample order for every epoch.
                random_index = np.random.permutation(total_sample)
                for batch in range(total_batch):
                    start_idx = batch * FLAGS.batch_size
                    end_idx = start_idx + FLAGS.batch_size
                    batch_index = random_index[start_idx: end_idx]

                    batch_x = X[batch_index]
                    batch_y = Y[batch_index]
                    # One optimisation step.
                    feed = {x: batch_x, y: batch_y}
                    _, train_loss, train_acc = sess.run([train_opt, loss, accuracy], feed)
                    print('Step:{} - Train loss:{} - train acc:{}'.format(step, train_loss, train_acc))
                    epoch_acc_sum += train_acc
                    step += 1

                # Decide on the mean accuracy of the whole epoch. The accuracy
                # of the last mini-batch alone (the previous criterion) is far
                # too noisy with a handful of samples per batch.
                epoch_acc = epoch_acc_sum / total_batch
                if epoch_acc > 0.93:
                    files = 'model.ckpt'
                    save_files = os.path.join(FLAGS.checkpoint_dir, files)
                    saver.save(sess, save_path=save_files, global_step=step)
                    break
        finally:
            # Flush and release the event file even if training is interrupted.
            writer.close()

def predicts():
    """
    Placeholder for the prediction mode; currently performs no work.

    :return: None
    """
    return None


def main(_):
    """
    Dispatch to training or prediction according to ``FLAGS.is_train``.

    :param _: ignored positional argument.
    :return: None
    """
    if not FLAGS.is_train:
        predicts()
        return
    train()


# Script entry point: only runs when the file is executed directly.
# tf.app.run() presumably parses the flags defined above and then invokes
# main() - confirm against the TensorFlow 1.x documentation.
if __name__ == '__main__':
    tf.app.run()













  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值