Training your own data with the mobilenet_v2 model bundled with TensorFlow Slim

 

Contents

Introduction:

1. Dataset preparation

2. Slim modifications and training

Training:

3. Model export

Export with the official Bazel tools:

Export with the tensorflow module


 

Introduction:

This article records how to convert classification samples into TFRecord format, how to train a classification model with TensorFlow's slim module, and how to freeze and export the trained model.

 

Environment:

  • python 3.5
  • tensorflow-gpu 1.10
  • models-master: download the entire models project.

 

1. Dataset preparation

Sample preparation:

For this article I prepared samples of the ten Heavenly Stems, one class for each of 甲, 乙, 丙, 丁, 戊, 己, 庚, 辛, 壬 and 癸 (the classes are for demonstration only). The sample images are .jpg files; their file names and sizes are unrestricted, but each image must be stored in the folder of its class.

The sample folders are organized by class:
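A rough sketch of the layout (the folder and file names below are only illustrative; what matters is one sub-folder per class):

E:\样本_天干
├── 甲
│   ├── 0001.jpg
│   └── 0002.jpg
├── 乙
│   ├── 0001.jpg
│   └── ...
├── ...
└── 癸
    └── 0001.jpg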

In the <install dir>\models-master\research\slim\datasets\ folder, create a new file named convert_mydataset.py with the following contents:

#coding=utf-8
 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import math
import os
import random
import sys
 
import tensorflow as tf
 
from datasets import dataset_utils
 
 
# The number of images in the validation set.
_NUM_VALIDATION = 350
 
# Seed for repeatability.
_RANDOM_SEED = 0
 
# The number of shards per dataset split.
_NUM_SHARDS = 5

subname = ['train.txt', 'validation.txt', 'labels.txt']
 
class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""
 
  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
 
  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]
 
  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image
 
 
def _get_filenames_and_classes(dataset_dir):
  """Returns a list of filenames and inferred class names.
  Args:
    dataset_dir: A directory containing a set of subdirectories representing
      class names. Each subdirectory should contain PNG or JPG encoded images.
  Returns:
    A list of image file paths, relative to `dataset_dir` and the list of
    subdirectories, representing class names.
  """
 
  # Change this to your own dataset folder. Note: this helper is left over from the
  # flowers script and is not actually called by run() below.
  flower_root = os.path.join(dataset_dir, 'fruit_photos')
  directories = []
  class_names = []
  for filename in os.listdir(flower_root):
    path = os.path.join(flower_root, filename)
    if os.path.isdir(path):
      directories.append(path)
      class_names.append(filename)
 
  photo_filenames = []
  for directory in directories:
    for filename in os.listdir(directory):
      path = os.path.join(directory, filename)
      photo_filenames.append(path)
 
  return photo_filenames, sorted(class_names)
 
 
def _get_dataset_filename(dataset_dir, split_name, shard_id):
  # Output file prefix changed from 'flowers' to 'mydataset'
  output_filename = 'mydataset_%s_%05d-of-%05d.tfrecord' % (
      split_name, shard_id, _NUM_SHARDS)
  return os.path.join(dataset_dir, output_filename)
 
 
def _convert_dataset(txtname, split_name, dataset_dir):
    """Converts the images listed in a txt file to a TFRecord dataset.
    Args:
      txtname: Path to the txt file listing '<relative image path> <label>' per line.
      split_name: The name of the split, either 'train' or 'validation'.
      dataset_dir: The directory where the converted datasets are stored.
    """
    # Load the list file; only one label per image is used here.
    images_list, labels_list = load_labels_file(txtname, 1)
    num_per_shard = int(math.ceil(len(images_list) / float(_NUM_SHARDS)))
    
    with tf.Graph().as_default():
        image_reader = ImageReader()
        
        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                #record filename
                output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
                
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id+1) * num_per_shard, len(images_list))
                    
                    for i in range(start_ndx, end_ndx):
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i+1, len(images_list), shard_id))
                        sys.stdout.flush()
                        
                        filename = os.path.join(dataset_dir, images_list[i])
                        # Read the raw JPEG bytes:
                        image_data = tf.gfile.FastGFile(filename, 'rb').read()
                        height, width = image_reader.read_image_dims(sess, image_data)
                        
                        class_id = labels_list[i]
                        example = dataset_utils.image_to_tfexample(image_data, b'jpg', height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())
    sys.stdout.write('\n')
    sys.stdout.flush()
 
 
def load_labels_file(filename, labels_num=1, shuffle=False):
    '''
    Loads an image-list txt file. Each line describes one image, space-separated:
    <image path> <label 1> <label 2> ..., e.g. test_image/1.jpg 0 2
    :param filename: path to the txt file
    :param labels_num: number of labels per image
    :param shuffle: whether to shuffle the lines
    :return: images -> list
    :return: labels -> list
    '''
    images=[]
    labels=[]
    with open(filename) as f:
        lines_list=f.readlines()
        if shuffle:
            random.shuffle(lines_list)
 
        for lines in lines_list:
            line=lines.rstrip().split(' ')
            label=[]
            for i in range(labels_num):
                label.append(int(line[i+1]))
            images.append(line[0])
            labels.append(label)
    return images,labels


def make_train_val_label_txt(ori_dir, frate = 0.8):
    
    # Before generating, delete the output files if they already exist
    for filename in subname:
        filedir = os.path.join(ori_dir, filename)
        if os.path.exists(filedir):
            os.remove(filedir)
        
    # Find the class sub-folders
    directories = []
    class_names = []
    path_list = os.listdir(ori_dir)
    path_list.sort()
    for filename in path_list:
        path = os.path.join(ori_dir, filename)
        if os.path.isdir(path):
            directories.append(path)
            class_names.append(filename)
    print("class_names=\n", class_names) 
    
    traindir = os.path.join(ori_dir, subname[0])
    valdir = os.path.join(ori_dir, subname[1])
    labeldir = os.path.join(ori_dir, subname[2])
    
    with open(labeldir,'a+') as f:
        for i, classname in enumerate(class_names):        
             f.write('%s\n' % (classname))
    
    for i, directory in enumerate(directories):
        filenames = []
        for filename in os.listdir(directory):
            filenames.append(filename)
        random.shuffle(filenames)
        left = round(len(filenames) *frate+0.5) 
        
        trainname = filenames[:left]
        trainname.sort()        
        with open(traindir,'a+') as f:
            for name in trainname:
                f.write('%s/%s %d\n'%(class_names[i], name, i))
                
        valname = filenames[left:]
        valname.sort()
        with open(valdir,'a+') as f:
            for name in valname:
                f.write('%s/%s %d\n'%(class_names[i], name, i)) 

def run(dataset_dir):
    """Runs the conversion operation.
    Args:
      dataset_dir: The directory where the dataset (class sub-folders) is stored.
    """
    
    # Generate train.txt, validation.txt and labels.txt
    make_train_val_label_txt(dataset_dir)
    
    
    # Get the train and val txt fullname:
    train_txt = os.path.join(dataset_dir, subname[0])
    val_txt = os.path.join(dataset_dir, subname[1])
 
    # convert the training and validation sets.
    _convert_dataset(train_txt, 'train', dataset_dir)
    _convert_dataset(val_txt, 'validation',dataset_dir)
    
    print('\nFinished converting the mydataset dataset!')

Open the download_and_convert_data.py file in the slim folder and add:

from datasets import convert_mydataset

Then, inside the main(_) function, add (the sketch after this snippet shows where it fits):

  elif FLAGS.dataset_name == 'mydataset':
    convert_mydataset.run(FLAGS.dataset_dir)
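For orientation, the dataset dispatch in main(_) of download_and_convert_data.py is a plain if/elif chain keyed on FLAGS.dataset_name, and the new branch simply slots in next to the existing ones. A rough sketch (the surrounding branches are paraphrased from the slim script and may differ slightly in your copy):

def main(_):
  if not FLAGS.dataset_name:
    raise ValueError('You must supply the dataset name with --dataset_name')
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  if FLAGS.dataset_name == 'cifar10':
    download_and_convert_cifar10.run(FLAGS.dataset_dir)
  elif FLAGS.dataset_name == 'flowers':
    download_and_convert_flowers.run(FLAGS.dataset_dir)
  elif FLAGS.dataset_name == 'mnist':
    download_and_convert_mnist.run(FLAGS.dataset_dir)
  elif FLAGS.dataset_name == 'mydataset':
    # new branch for our dataset
    convert_mydataset.run(FLAGS.dataset_dir)
  else:
    raise ValueError(
        'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)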

Then run from the command line:

python download_and_convert_data.py --dataset_name=mydataset --dataset_dir="E:\样本_天干"

If there are no errors, a series of .tfrecord files, along with train.txt, validation.txt and labels.txt, will be generated in the sample directory.
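Before moving on, you can sanity-check the generated TFRecords by counting the records in each split; the two totals are exactly the numbers needed for SPLITS_TO_SIZES in the next step. A minimal sketch using tf.python_io.tf_record_iterator from TensorFlow 1.x (the directory path is this article's example):

#coding=utf-8
import glob
import os

import tensorflow as tf

dataset_dir = r'E:\样本_天干'  # the directory passed to download_and_convert_data.py

for split in ['train', 'validation']:
    pattern = os.path.join(dataset_dir, 'mydataset_%s_*.tfrecord' % split)
    count = 0
    for shard in glob.glob(pattern):
        # Each iteration yields one serialized tf.train.Example.
        count += sum(1 for _ in tf.python_io.tf_record_iterator(shard))
    print('%s: %d images' % (split, count))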

 

2. Slim modifications and training

  • In the <install dir>\models-master\research\slim\datasets\ folder, copy flowers.py, rename the copy to mydataset.py, and make the following changes (the combined result is sketched after this item):

Change  _FILE_PATTERN = 'flowers_%s_*.tfrecord'  to  _FILE_PATTERN = 'mydataset_%s_*.tfrecord'.

Change  SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}  to  SPLITS_TO_SIZES = {'train': 431, 'validation': 102}. Here 'train' is the number of training images and 'validation' the number of validation images; they must match your own dataset. With the method in this article they can be read off as the number of lines in the train.txt and validation.txt files generated in step 1.

Change  _NUM_CLASSES = 5  to  _NUM_CLASSES = 10, the actual number of sample classes.
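Taken together, the top of mydataset.py should end up looking roughly like this (431/102 are this article's counts; substitute your own):

_FILE_PATTERN = 'mydataset_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 431, 'validation': 102}

_NUM_CLASSES = 10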

 

  • In the <install dir>\models-master\research\slim\datasets\ folder, open dataset_factory.py and make the following changes:

Add:

from datasets import mydataset

and add to the datasets_map dictionary:

'mydataset': mydataset,
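After these two edits, the relevant part of dataset_factory.py looks roughly like this (the other entries already exist in the stock file):

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import mydataset  # newly added

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'mydataset': mydataset,  # newly added
}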

 

Training:

Because file paths used during training must not contain Chinese characters, copy all of the .tfrecord files generated in step 1 into the e:\HeavenlyStems directory. The training command is:

python3 train_image_classifier.py \
  --train_dir=e:\log \
  --dataset_dir=e:\HeavenlyStems \
  --dataset_name=mydataset \
  --dataset_split_name=train \
  --model_name="mobilenet_v2_140" \
  --checkpoint_path=e:\mobilenet_v2_1.4_224\mobilenet_v2_1.4_224.ckpt \
  --checkpoint_exclude_scopes=MobilenetV2/Logits,MobilenetV2/Predictions,MobilenetV2/predics \
  --trainable_scopes=MobilenetV2/Logits,MobilenetV2/Predictions,MobilenetV2/predics \
  --max_number_of_steps=20000 \
  --preprocessing_name="inception_v2" \
  --learning_rate=0.045 \
  --label_smoothing=0.1 \
  --moving_average_decay=0.9999 \
  --batch_size=32 \
  --num_clones=1 \
  --learning_rate_decay_factor=0.98 \
  --num_epochs_per_decay=2.5

Here --train_dir is the directory where training output (checkpoints and event logs) is saved, --dataset_dir is the directory holding the sample .tfrecord files, --dataset_name is the dataset name registered above, --model_name selects the network, and --max_number_of_steps is the number of training iterations.
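While training runs, the loss and learning-rate summaries that train_image_classifier.py writes into --train_dir can be watched with TensorBoard:

tensorboard --logdir=e:\log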

3. Model export

Two methods for exporting the model are provided:

Export with the official Bazel tools:

This takes two steps.

Step 1: Exporting the inference graph

python export_inference_graph.py \
  --alsologtostderr \
  --dataset_dir=e:\HeavenlyStems \
  --dataset_name=mydataset \
  --model_name=mobilenet_v2_140 \
  --image_size=224 \
  --output_file=e:\log\mobilenet_v2_244.pb

Step 2: Freezing the exported graph

You first need to download the TensorFlow source code and install the matching version of Bazel, then run the following command inside the TensorFlow source folder:

bazel build tensorflow/python/tools:freeze_graph

Compilation takes a while. Once it succeeds, run the following from the build directory:

bazel-bin/tensorflow/python/tools/freeze_graph \
   --input_graph=e:\log\mobilenet_v2_244.pb \
   --input_checkpoint=e:\log\model.ckpt-20000 \
   --input_binary=True \
   --output_graph=e:\log\mobilenet_v2_1.0_224_frozen.pb \
   --output_node_names=MobilenetV2/Predictions/Reshape_1

Export with the tensorflow module

Run the following command:

python3 -m tensorflow.python.tools.freeze_graph \
  --input_graph e:\log\graph.pbtxt \
  --input_checkpoint e:\log\model.ckpt-40856 \
  --input_binary false \
  --output_graph e:\log\mobilenet_v2_frozen.pb \
  --output_node_names MobilenetV2/Predictions/Reshape_1
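Whichever method you use, the result is a single frozen .pb file. A minimal inference sketch for checking it (not part of the original article; the input tensor name 'input' and the 224x224 size correspond to the graph produced by export_inference_graph.py in the first method, and the [-1, 1] scaling below only approximates the Inception-style preprocessing used during training, so scores may differ slightly from training-time evaluation):

#coding=utf-8
import numpy as np
import tensorflow as tf

FROZEN_PB = r'e:\log\mobilenet_v2_1.0_224_frozen.pb'
IMAGE_PATH = r'e:\HeavenlyStems\test.jpg'  # hypothetical test image

# Load the frozen GraphDef and import it into a fresh graph.
graph_def = tf.GraphDef()
with tf.gfile.GFile(FROZEN_PB, 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

    with tf.Session(graph=graph) as sess:
        # Decode, resize and scale the image to [-1, 1].
        img = tf.image.decode_jpeg(tf.read_file(IMAGE_PATH), channels=3)
        img = tf.image.resize_images(img, [224, 224])
        img = tf.cast(img, tf.float32) / 127.5 - 1.0
        img = sess.run(tf.expand_dims(img, 0))

        probs = sess.run('MobilenetV2/Predictions/Reshape_1:0',
                         feed_dict={'input:0': img})
        print('predicted class id:', np.argmax(probs[0]))

The predicted class id maps back to a class name through the labels.txt generated in step 1 (the i-th line, counting from zero, is the name of class id i).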

Reference:

TensorFlow deep learning practice notes (1): training your own data with the models that come with tensorflow slim

 
