inceptionv3迁移训练

1:准备图片数据,一份训练数据,一份测试数据。结构如下:
在这里插入图片描述
下载retrain.py程序( https://github.com/tensorflow/hub )在example文件夹下的image-train里面,如果上述链接下载下来的retrain.py训练时候报无法连接的错误,改换使用下面的retrain.py(具体内部改了什么我还没搞清楚)。
https://github.com/zxq201988/deeplearning-code

将下载下来的retrain.py存放到 D:\TensorFlow\retrain\ 下
3:下载inception-v3模型
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
将压缩包存放到 D:\TensorFlow\inception_model 文件夹下,不必解压。
创建批处理命令文件retrain.bat。内容如下:
python E:/TensorFlow/retrain/retrain.py ^ #retrain.py 文件的路径
–bottleneck_dir bottleneck ^ #bottleneck 文件夹的路径 ,默认和 retrain.py 同一个文件夹
–how_many_training_steps 200 ^ #迭代 200 次
–model_dir E:/Tensorflow/inception_model/ ^ #inception-v3 模型的压缩包路径
–output_graph output_graph.pb ^ #输出的模型文件名
–output_labels output_labels.txt ^ #输出的标签
–image_dir E:\TensorFlow\retrain\data\train #自己的训练数据集存放路径
pause

在 D:\TensorFlow\retrain\ 下新建一个名叫 bottleneck 的文件夹,用于存放批处理之后各个图片的.txt文件。
最终目录结构如下图所示:
在这里插入图片描述
准备工作完成,接下来运行retrain.bat文件,就会在命令行训练模型,
训练完成之后就可以使用测试数据测试自己的模型质量了。下面是测试代码,只需要在代码中更改测试数据所在路径即可,在python环境中运行

# coding: utf-8
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt
 
 
lines = tf.gfile.GFile('retrain/output_labels.txt').readlines()
uid_to_human = {}
#一行一行读取数据
for uid,line in enumerate(lines) :
    #去掉换行符
    line=line.strip('\n')
    uid_to_human[uid] = line
 
def id_to_string(node_id):
    if node_id not in uid_to_human:
        return ''
    return uid_to_human[node_id]
 
 
#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile('retrain/output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
 
 
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    #遍历目录
    for root,dirs,files in os.walk('data/test/'):  #测试图片存放位置
        for file in files:
            #载入图片
            image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
            predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
            predictions = np.squeeze(predictions)#把结果转为1维数据
 
            #打印图片路径及名称
            image_path = os.path.join(root,file)
            print(image_path)
            #显示图片
            img=Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()
 
            #排序
            top_k = predictions.argsort()[::-1]
            print(top_k)
            for node_id in top_k:     
                #获取分类名称
                human_string = id_to_string(node_id)
                #获取该分类的置信度
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()
 
 
 

借鉴:https://blog.csdn.net/weixin_38663832/article/details/80555341

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值