Tensorflow 实战Google深度学习框架——学习笔记(六)迁移学习

本文介绍了TensorFlow中的迁移学习步骤,包括获取数据集、加载预训练模型、定义新模型的输入和输出、设计前向传播和损失函数,以及使用Inception-v3模型进行特征提取和模型训练。示例代码来源于《TensorFlow实战Google深度学习框架》6.5.2节。
摘要由CSDN通过智能技术生成

TensorFlow迁移学习的步骤

1、获取(你要训练的)数据集

2、获取模型(bp文件加载)

3、获取迁移后的输出张量bottleneck_tensor和输入张量jpeg_data_tensor(这一步同时也会获取jpeg_data_tensor到bottleneck_tensor的计算模型,这里称为model1)

4、根据bottleneck_tensor的维度定义model2的输入层节点bottleneck_input和model2的输出节点ground_truth_input

5、定义model2的前向传播过程(final_training_ops)

6、定义model2的损失函数,向后传播

7、根据jpeg_data_tensor, bottleneck_tensor在model1中训练得到样本的2048维特征train_bottlenecks,再获得样本的的输出结果train_ground_truth,喂入model2中,训练模型

代码如下:
此例程出自《TensorFlow实战Google深度学习框架》6.5.2小节 卷积神经网络迁移学习。
数据集来自http://download.tensorflow.org/example_images/flower_photos.tgz ,及谷歌提供的Inception-v3模型https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip 。 自行下载和解压。
解压后的文件夹包含5个子文件夹,每个子文件夹的名称为一种花的名称,代表了不同的类别。
工程目录

-transfer_learning
    -flower_data  //存放原始图片的文件夹,有5个子文件夹, 每个子文件夹的名称为一种花的名称
        -daisy   //daisy类花图片的文件夹
        -dandelion
        -roses
        -sunflowers
        -tulips
        -LICENSE.txt
    -model   //存放模型的文件夹
        -imagenet_comp_graph_label_strings.txt
        -LICENSE
        -tensorflow_inception_graph.pb   //模型文件
    -tmp
        -bottleneck  //保存模型瓶颈层的特征结果
            -daisy   //daisy类花特征的文件夹
            -dandelion
            -roses
            -sunflowers
            -tulips
    -transfer_flower.py  //所有的程序都在这里了

transfer_flower.py

import glob
import os.path
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# Inception-v3模型瓶颈层的节点个数
BOTTLENECK_TENSOR_SIZE = 2048

# Inception-v3模型中代表瓶颈层结果的张量名称。
# 在训练模型时,可以通过tensor.name来获取张量的名称。
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'

# 图像输入张量所对应的名称。
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'

# 下载的谷歌训练好的Inception-v3模型文件目录
MODEL_DIR = 'model/'

# 下载的谷歌训练好的Inception-v3模型文件名
MODEL_FILE = 'tensorflow_inception_graph.pb'

# 因为一个训练数据会被使用多次,所以可以将原始图像通过Inception-v3模型计算
# 得到的特征向量保存在文件中,免去重复的计算。
# 下面的变量定义了这些文件的存放地址。
CACHE_DIR = 'tmp/bottleneck/'

# 图片数据文件夹。
# 在这个文件夹中每一个子文件夹代表一个需要区分的类别,每个子文件夹中存放了对应类别的图片。
INPUT_DATA = 
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值