Hands-On TensorFlow, Google's Deep Learning Framework: Transfer Learning

Transfer learning imports a model that has already been trained and reuses it, which greatly reduces the amount of training needed for a new task.
The example below shows how to turn raw image data into the input format the model expects.
The first step is data preprocessing:

#-*- coding: utf-8 -*-
import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

INPUT_DATA = "E:/train-data/Flower_Photo/flower_photos"
OUTPUT_FILE = 'E:/train-data/Pro_Flower_Photo'
# Percentage of images used for validation and for testing
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10
# Read the images and split them into training/validation/testing sets
def create_image_lists(sess, testing_percentage, validation_percentage):
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]  # walk all sub-directories under INPUT_DATA
    
    is_root_dir = True

    training_images = []
    training_labels = []
    testing_images = []
    testing_labels = []
    validation_images = []
    validation_labels = []
    current_label = 0

    for sub_dir in sub_dirs:
        if is_root_dir:
            is_root_dir = False
            continue

        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        file_list = []
        dir_name = os.path.basename(sub_dir)  # name of the current sub-directory (one flower category)
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)  # join directory and wildcard into one search path
            file_list.extend(glob.glob(file_glob))  # glob expands the wildcard into a list of image files
        if not file_list: continue

        for file_name in file_list:
            # Read and decode the image, then resize it to 299x299 so it can be
            # fed to the Inception-v3 model
            image_raw_data = gfile.FastGFile(file_name, 'rb').read()  # read the raw image bytes
            image = tf.image.decode_jpeg(image_raw_data)
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(
                    image, dtype=tf.float32
                )
            image = tf.image.resize_images(image,[299,299])
            image_value = sess.run(image)
            # Randomly assign the image to the training, validation, or testing set
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (testing_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)
        current_label+=1

    # Shuffle the training data; reusing the same random state keeps the
    # images and labels aligned after shuffling
    state = np.random.get_state()
    np.random.shuffle(training_images)
    np.random.set_state(state)
    np.random.shuffle(training_labels)
    return np.asarray([training_images, training_labels,
                       validation_images, validation_labels,
                       testing_images, testing_labels])

def main():
    with tf.Session() as sess:
        processed_data = create_image_lists(
            sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE
        )
        # Save the processed data in NumPy format (np.save appends the .npy suffix)
        np.save(OUTPUT_FILE, processed_data)

if __name__ =='__main__':
    main()
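
Once this script has run, the saved file can be sanity-checked with a few lines such as the following. This is a minimal sketch: the file name is OUTPUT_FILE plus the .npy suffix that np.save appends, and allow_pickle=True is needed on newer NumPy versions because the saved array holds Python lists.

import numpy as np

# np.save appends ".npy" to the OUTPUT_FILE name used above
data = np.load('E:/train-data/Pro_Flower_Photo.npy', allow_pickle=True)

training_images, training_labels = data[0], data[1]
validation_images, validation_labels = data[2], data[3]
testing_images, testing_labels = data[4], data[5]

print('training:', len(training_images),
      'validation:', len(validation_images),
      'testing:', len(testing_images))
print('image shape:', np.asarray(training_images[0]).shape)  # expected: (299, 299, 3)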

Note on np.asarray: when the argument is already an ndarray, c = np.asarray(a) returns the same array without making a copy, so modifying a also changes c.
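
A quick illustration of that behavior:

import numpy as np

a = np.zeros(3)
c = np.asarray(a)        # a is already an ndarray, so no copy is made
a[0] = 1.0
print(c[0])              # 1.0 -- the change made through a is visible through c
print(c is a)            # True
print(np.array(a) is a)  # False: np.array copies by default
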
Once the data has been processed, transfer learning can be carried out on top of the downloaded pre-trained model.

import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
import tensorflow.contrib.slim as slim

import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3

# Output of the preprocessing script above (np.save appends the .npy suffix)
INPUT_DATA = 'E:/train-data/Pro_Flower_Photo.npy'
# Path where the model will be saved after training
TRAIN_FILE = "E:/model/inception-v3"
# The downloaded pre-trained model (Inception-v3 checkpoint)
CKPT_FILE = "E:/model/inception-v3"

LEARNING_RATE = 0.0001
STEPS = 300
BATCH = 32
N_CLASSES = 5
# Parameters that should not be loaded from the checkpoint: the final fully
# connected layers, because they will be retrained for the new task.
# The strings below are prefixes of the parameter names.
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits, InceptionV3/AuxLogits'
# Parameters of the layers that need to be trained
TRAINABLE_SCOPES = 'InceptionV3/Logits, InceptionV3/AuxLogits'
# Get the list of variables whose values should be loaded from the pre-trained model
def get_tuned_variables():
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')]
    variables_to_restore = []

    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
                
        if not excluded:
            variables_to_restore.append(var)
    return variables_to_restore
# Get the list of variables that need to be trained. tf.get_collection matches the
# scope string as a prefix (regular expression) against variable names.
def get_trainable_variables():
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(',')]
    variables_to_train = []
    for scope in scopes:
        variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train

def main():
    # Load the preprocessed data; it is the 6-element array saved by the
    # preprocessing script above, in the order it was saved there.
    # allow_pickle=True is required on newer NumPy versions because the array holds Python lists.
    processed_data = np.load(INPUT_DATA, allow_pickle=True)
    training_images = processed_data[0]
    n_training_example = len(training_images)
    training_labels = processed_data[1]
    validation_images = processed_data[2]
    validation_labels = processed_data[3]
    testing_images = processed_data[4]
    testing_labels = processed_data[5]

    # Input placeholders in the format expected by Inception-v3
    images = tf.placeholder(tf.float32, [None, 299, 299, 3], name='input_images')
    labels = tf.placeholder(tf.int64, [None], name='labels')

    # Define the Inception-v3 model structure
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        logits, _ = inception_v3.inception_v3(
            images, num_classes=N_CLASSES
        )

    # Cross-entropy loss on the new 5-class flower problem
    trainable_variables = get_trainable_variables()
    tf.losses.softmax_cross_entropy(
        tf.one_hot(labels, N_CLASSES), logits, weights=1.0
    )
    # Only the variables of the final layers (TRAINABLE_SCOPES) are optimized;
    # the rest of the network keeps the pre-trained parameters
    train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(
        tf.losses.get_total_loss(), var_list=trainable_variables)
    # Compute the accuracy of the predictions
    with tf.name_scope('evaluation'):
        correction_rate = tf.equal(tf.argmax(logits, 1), labels)
        evaluation_step = tf.reduce_mean(tf.cast(correction_rate, tf.float32))

    # Function that loads the pre-trained parameters from the checkpoint
    load_fn = slim.assign_from_checkpoint_fn(
        CKPT_FILE,
        get_tuned_variables(),
        ignore_missing_vars=True
    )
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

        print('Loading tuned variables from %s '% CKPT_FILE)
        load_fn(sess)  # load the pre-trained parameters into the session

        saver = tf.train.Saver()

        start = 0
        end = BATCH
        for i in range(STEPS):
            # Train one batch; only the variables selected by get_trainable_variables() are updated
            sess.run(train_step, feed_dict={
                images: training_images[start:end],
                labels: training_labels[start:end]})

            # Periodically report accuracy on the validation set
            if i % 30 == 0 or i + 1 == STEPS:
                validation_accuracy = sess.run(evaluation_step, feed_dict={
                    images: validation_images, labels: validation_labels})
                print('Step %d: Validation accuracy = %.1f%%' % (i, validation_accuracy * 100.0))

            # Move the batch window forward, wrapping around at the end of the training data
            start = end
            if start == n_training_example:
                start = 0
            end = start + BATCH
            if end > n_training_example:
                end = n_training_example

        # Accuracy on the held-out test data
        test_accuracy = sess.run(evaluation_step, feed_dict={
            images: testing_images, labels: testing_labels})
        print('Final test accuracy = %.1f%%' % (test_accuracy * 100.0))

        # Save the fine-tuned model
        saver.save(sess, TRAIN_FILE)

if __name__ == '__main__':
    main()
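
The fine-tuned checkpoint saved under TRAIN_FILE can then be restored for inference. The following is a minimal, standalone sketch that classifies a single image; the image path is hypothetical, TRAIN_FILE and N_CLASSES mirror the constants defined above, and the preprocessing follows the data-preparation script (decode, convert to float32, resize to 299x299).

import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3
from tensorflow.python.platform import gfile

TRAIN_FILE = "E:/model/inception-v3"          # checkpoint written by the training script above
IMAGE_FILE = "E:/train-data/some_flower.jpg"  # hypothetical example image
N_CLASSES = 5

def classify_image():
    # Same preprocessing as in the data-preparation script
    image_raw_data = gfile.FastGFile(IMAGE_FILE, 'rb').read()
    image = tf.image.decode_jpeg(image_raw_data, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.resize_images(image, [299, 299])
    image = tf.expand_dims(image, 0)   # add the batch dimension: [1, 299, 299, 3]

    # Rebuild the same network structure and restore the fine-tuned parameters
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        logits, _ = inception_v3.inception_v3(image, num_classes=N_CLASSES, is_training=False)
    prediction = tf.argmax(logits, 1)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, TRAIN_FILE)
        print('Predicted class index:', sess.run(prediction)[0])

if __name__ == '__main__':
    classify_image()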