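1. Download slim and the model you need from https://github.com/tensorflow/models/tree/master/research/slim. This article uses inception_v3.
2. Download the data. This article uses the data from Kaggle's Plant Seedlings Classification competition.
3. Copy slim/datasets/download_and_convert_flowers.py, rename the copy to download_and_convert_plantseeding.py, and modify it as follows:
(1) Comment out _DATA_URL.
(2) Change _NUM_VALIDATION, the number of images in the validation set, to match your actual data.
(3) Change _NUM_SHARDS, the number of TFRecord files per split. Advice found online suggests roughly 1024 images per TFRecord file, so set it approximately on that basis.
(4) Change the path assigned to flower_root, and rename the flower_root variable while you are at it so the code reads better.
(5) Since we use our own dataset, nothing needs to be downloaded, so comment out the dataset_utils.download…… line.
(6) Since no file is downloaded, there is no downloaded file to delete either, so comment out the _clean_up_.... call.
(7) Change output_filename in the _get_dataset_filename function.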
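To make steps (2), (3) and (7) concrete, here is a minimal sketch of how the two constants and the shard file name could be worked out. None of these helpers are part of slim: `count_images`, `suggest_settings` and `shard_filename` are hypothetical names, the 10% validation fraction and the `plantseeding` file prefix are assumptions, and the 1024-images-per-shard figure is just the rule of thumb mentioned above.

```python
import os

# Hypothetical helpers for illustration only; they are not part of slim.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def count_images(dataset_root):
  """Counts image files inside the class subdirectories of dataset_root."""
  total = 0
  for name in sorted(os.listdir(dataset_root)):
    class_dir = os.path.join(dataset_root, name)
    if not os.path.isdir(class_dir):
      continue  # Files directly under dataset_root are not class folders.
    total += sum(1 for f in os.listdir(class_dir)
                 if f.lower().endswith(IMAGE_EXTENSIONS))
  return total


def suggest_settings(num_images, validation_fraction=0.1,
                     images_per_shard=1024):
  """Returns candidate (_NUM_VALIDATION, _NUM_SHARDS) values.

  Assumption: the validation set is a fixed fraction of all images, and
  shards are sized against the training split at about images_per_shard.
  """
  num_validation = int(num_images * validation_fraction)
  num_train = num_images - num_validation
  num_shards = max(1, int(round(num_train / float(images_per_shard))))
  return num_validation, num_shards


def shard_filename(dataset_dir, split_name, shard_id, num_shards,
                   prefix='plantseeding'):
  """Builds a shard path in the same pattern the flowers script uses.

  The 'plantseeding' prefix is an assumed name; use whatever your dataset
  definition file expects.
  """
  output_filename = '%s_%s_%05d-of-%05d.tfrecord' % (
      prefix, split_name, shard_id, num_shards)
  return os.path.join(dataset_dir, output_filename)
```

Calling `suggest_settings(count_images(path_to_your_class_folders))` gives starting values that you can then round to taste; the exact numbers are not critical as long as _NUM_VALIDATION is smaller than the total image count.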
#-*- encoding:utf-8 -*-
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Downloads and converts Flowers data to TFRecords of TF-Example protos.
This module downloads the Flowers data, uncompresses it, reads the files
that make up the Flowers data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contain a single image and label.
The script should take about a minute to run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import random
import sys
import tensorflow as tf
from datasets import dataset_utils
# The URL where the Flowers data can be downloaded.
#_DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
# The number of images in the validation set.
_NUM_VALIDATION = 470  # Changed: number of validation images in our dataset.
# Seed for repeatability.
_RANDOM_SEED = 0
# The number of shards per dataset split.
_NUM_SHARDS = 4  # Changed: number of TFRecord files per split.
class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image
def _get_filenames_and_classes(dataset_dir):
  """Returns a list of filenames and inferred class names.
  Args:
    dataset_dir: A directory containing a set of subdirectories representing
      class names. Each subdirectory should contain PNG or JPG encoded images.
  Returns:
    A list of image file paths, relative to `dataset_dir` and the list of
    subdirectories, representing class names.
  """
  # Old variable name, kept for reference:
  # flower_root = os.path.join(dataset_dir, 'train_convert')
  # Changed: path to our own dataset, and the variable renamed to dataset_root.
  dataset_root = os.path.join(dataset_dir, 'train_convert')
  directories = []
  class_names = []
  for filename in os.listdir(dataset_root):
    path = os.path.join(dataset_root, filename)
    if os.path.isdir(path):
      directories.append(path)
clas