使用GitHub上的模型
1. 在https://github.com/tensorflow/models下载model-master压缩包
-
本案例要使用slim文件夹中的train_image_classifier.py(可以将slim文件夹单独copy出来)
-
下载并制作自己的数据集放到创建的H:\Self-Study\PyCharm\test\slim\images目录中,这里做了五种分类,每类中有800个数据,共4000,其中按8:2分为train和test
-
将自己的数据集制作成tfrecord文件。编写generate_tfrecord.py脚本文件,代码如下(高亮部分分别为:指定数据集路径、生成标签文件名、生成tfrecord文件当时犯错的地方):
import tensorflow as tf
import os
import random
import math
import sys
# 验证集数量
_NUM_TEST = 800
# 随机种子
_RANDOM_SEED = 0
# 数据块
_NUM_SHARDS = 5
# 数据集路径
DATASET_DIR = "H:/Self-Study/PyCharm/test/slim/images/"
# 标签文件名字
LABELS_FILENAME = "H:/Self-Study/PyCharm/test/slim/images/labels.txt"
# 定义tfrecord文件的路径+名字
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename) # 获取dataser_dir当前目录