利用google的inception3重训练自己的图像识别模型(迁移学习)

程序运行环境是tensorflow-cpu1.13.1

1.准备数据集

准备几个类别的图像数据集存放在各自类别路径下
在这里插入图片描述
如上图,将五个类别的图像数据分别存放在各自的文件目录,这里每个类别存放了500张.jpg图像文件,命名是0001.jpg - 0500.jpg

2.下载inception_model

这里需要将谷歌的inception_model文件放置在程序相同的路径下,可以从这里https://download.csdn.net/download/cyj5201314/16581511 下载model
在这里插入图片描述

3.下载retrain.py

可以从这里下载 https://download.csdn.net/download/cyj5201314/16603267

4. 指定训练参数

  • 训练集的路径,我这里是存放在程序当前路径的data目录里,给该参数的default赋值路径字符串即可
    在这里插入图片描述
  • 训练结束后图模型保存路径,我这里直接保存在程序当前路径下
    在这里插入图片描述
  • 输出的标签存储路径,这里同样直接保存在当前程序路径
    在这里插入图片描述
  • 一共训练多少步, 这里指定200步
    在这里插入图片描述
  • 卷积层最终输出的张量保存路径,这里实际上只训练最后一个全连接层,卷积层全部使用谷歌训练好的参数,所以这里相当于计算全部训练图像的卷积层输出,保存到本地
    在这里插入图片描述
  • inception3模型路径,这里将指定第二步下载好的模型路径,这里直接放在程序当前路径
    在这里插入图片描述

5训练模型

直接运行retrain.py即可
计算全部训练集的卷积层输出如下: 这里计算全部输出大概需要半个小时
在这里插入图片描述
训练结束后验证准确率和测试准确率如下: 训练速度很快,因为只训练最后的一个全连接层
在这里插入图片描述
在508个测试图像上的准确率达到100%

6.用模型识别图像

将要识别的图像放在程序当前路径下的images文件目录
在这里插入图片描述
运行predict.py即可


import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt

lines = tf.gfile.GFile('output_labels.txt').readlines()
uid_to_human = {}

# 一行一行读取数据
for uid,line in enumerate(lines) :
    #去掉换行符
    line=line.strip('\n')
    uid_to_human[uid] = line

# 分类编号变成描述
def id_to_string(node_id):
    if node_id not in uid_to_human:
        return ''
    return uid_to_human[node_id]


# 创建一个图来存放训练好的模型
with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    # final_result为输出tensor的名字
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    # 遍历目录
    for root,dirs,files in os.walk('images/'):
        for file in files:
            # 载入图片
            image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
            # 把图像数据传入模型获得模型输出结果
            predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})
            # 把结果转为1维数据
            predictions = np.squeeze(predictions)
            # 打印图片路径及名称
            image_path = os.path.join(root,file)
            print(image_path)
            # 显示图片
            img = Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()

            # 排序
            top_k = predictions.argsort()[::-1]
            for node_id in top_k:     
                # 获取分类名称
                human_string = id_to_string(node_id)
                # 获取该分类的置信度
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
训练和测试集数据可在 https://download.csdn.net/download/cyj5201314/16604481 下载

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值