SSD-Tensorflow: Training on the KITTI Dataset

Thoughts

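The original author (balancap) provides two SSD versions on GitHub: SDC-Vehicle-Detection and SSD-Tensorflow. In my experiments, the SDC version had not been updated for a long time and targets TensorFlow 0.12. When I ran it, the loss inexplicably dropped to 0, so I eventually gave up on it. I then ported its KITTI-to-TFRecords conversion code to SSD-Tensorflow, and after a few modifications it finally ran. This way there is no need to convert the data to VOC format first: convert KITTI directly to TFRecords and train. I have tested this myself and it works.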

Steps

Put kitti.py, kitti_common.py and kitti_to_tfrecords.py into the datasets directory of SSD-Tensorflow.
My kitti.py:
# Copyright 2015 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""KITTI dataset.
"""
import os

import tensorflow as tf

from datasets import dataset_utils
from datasets.kitti_common import KITTI_LABELS, NUM_CLASSES, KITTI_DONTCARE

slim = tf.contrib.slim

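# Original single-file pattern: FILE_PATTERN = 'kitti_%s.tfrecord'
# Sharded files written by kitti_to_tfrecords.py, e.g. kitti_train_000.tfrecord.
FILE_PATTERN = 'kitti_%s_*.tfrecord'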
ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'shape': 'Shape of the image',
    'object/bbox': 'A list of bounding boxes, one per each object.',
    'object/label': 'A list of labels, one per each object.',
}
SPLITS_TO_SIZES = {
    'train': 7481,
    'test': 7518,
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Gets a dataset tuple with instructions for reading the KITTI dataset.

    Args:
      split_name: A train/test split name.
      dataset_dir: The base directory of the dataset sources.
      file_pattern: The file pattern to use when matching the dataset sources.
        It is assumed that the pattern contains a '%s' string so that the split
        name can be inserted.
      reader: The TensorFlow reader type.

    Returns:
      A `Dataset` namedtuple.

    Raises:
        ValueError: if `split_name` is not a valid train/test split.
    """
    if not file_pattern:
        file_pattern = FILE_PATTERN
    if split_name not in SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)
    file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

    # Allowing None in the signature so that dataset_factory can use the default.
    if reader is None:
        reader = tf.TFRecordReader
    # Features stored in the KITTI TFRecords (see kitti_to_tfrecords.py).
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/channels': tf.FixedLenFeature([1], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'object/label': tf.VarLenFeature(dtype=tf.int64),
        'object/truncated': tf.VarLenFeature(dtype=tf.float32),
        'object/occluded': tf.VarLenFeature(dtype=tf.int64),
        'object/alpha': tf.VarLenFeature(dtype=tf.float32),
        'object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'object/dimensions/height': tf.VarLenFeature(dtype=tf.float32),
        'object/dimensions/width': tf.VarLenFeature(dtype=tf.float32),
        'object/dimensions/length': tf.VarLenFeature(dtype=tf.float32),
        'object/location/x': tf.VarLenFeature(dtype=tf.float32),
        'object/location/y': tf.VarLenFeature(dtype=tf.float32),
        'object/location/z': tf.VarLenFeature(dtype=tf.float32),
        'object/rotation_y': tf.VarLenFeature(dtype=tf.float32),
    }
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'shape': slim.tfexample_decoder.Tensor('image/shape'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
                ['ymin', 'xmin', 'ymax', 'xmax'], 'object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor('object/label'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    labels_to_names = None
    if dataset_utils.has_labels(dataset_dir):
        labels_to_names = dataset_utils.read_label_file(dataset_dir)
    # else:
    #     labels_to_names = create_readable_names_for_imagenet_labels()
    #     dataset_utils.write_label_file(labels_to_names, dataset_dir)

    return slim.dataset.Dataset(
            data_sources=file_pattern,
            reader=reader,
            decoder=decoder,
            num_samples=SPLITS_TO_SIZES[split_name],
            items_to_descriptions=ITEMS_TO_DESCRIPTIONS,
            num_classes=NUM_CLASSES,
            labels_to_names=labels_to_names)
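The FILE_PATTERN change matters because kitti_to_tfrecords.py (below) writes several numbered shards instead of a single file. The small stdlib-only check below uses `fnmatch` as an approximation of the glob matching on `data_sources`, and assumes the shard naming of `_get_output_filename` ('%s_%03d.tfrecord') with 512 samples per file. The helper `shard_names` is hypothetical, written only for illustration:

```python
import fnmatch
import math

FILE_PATTERN = 'kitti_%s_*.tfrecord'  # same pattern as in kitti.py
SAMPLES_PER_FILES = 512               # same value as in kitti_to_tfrecords.py


def shard_names(name, num_samples, per_file=SAMPLES_PER_FILES):
    """Hypothetical helper: shard basenames written for num_samples images."""
    num_files = math.ceil(num_samples / per_file)
    return ['%s_%03d.tfrecord' % (name, idx) for idx in range(num_files)]


names = shard_names('kitti_train', 7481)  # 7481 = SPLITS_TO_SIZES['train']
pattern = FILE_PATTERN % 'train'          # 'kitti_train_*.tfrecord'
matched = [n for n in names if fnmatch.fnmatch(n, pattern)]
print(len(names), len(matched))  # 15 15

# The old single-file pattern matches none of the shards.
print(fnmatch.filter(names, 'kitti_%s.tfrecord' % 'train'))  # []
```

With the old pattern the reader would find no files at all, which is why the sharded pattern is needed once the converter splits its output.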

My kitti_common.py: I replaced the classes with my own, but you can also keep the defaults.
# Copyright 2015 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
'''
KITTI_LABELS = {
    'none': (0, 'Background'),
    'Car': (1, 'Vehicle'),
    'Van': (2, 'Vehicle'),
    'Truck': (3, 'Vehicle'),
    'Cyclist': (4, 'Vehicle'),
    'Pedestrian': (5, 'Person'),
    'Person_sitting': (6, 'Person'),
    'Tram': (7, 'Vehicle'),
    'Misc': (8, 'Misc'),
    'DontCare': (9, 'DontCare'),
}
'''
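# Custom classes. Every class name that appears in label_2/*.txt must be a key
# here, otherwise kitti_to_tfrecords.py stops with a KeyError.
KITTI_LABELS = {
    'none': (0, 'Background'),
    'Person': (1, 'Person'),
    'Car': (2, 'Car'),
}
KITTI_DONTCARE = 9  # Only meaningful with the default labels above.
# Number of object classes, excluding background (8 for the default labels).
NUM_CLASSES = 2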
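With custom classes, the raw KITTI label files still contain the original names (Pedestrian, Van, DontCare, ...), and `_process_image` looks each name up in KITTI_LABELS directly, so any unmapped name aborts the conversion with a KeyError. One option is to remap names before the lookup. The sketch below assumes you want Pedestrian and Person_sitting mapped to Person, Car and Van mapped to Car, and every other object dropped; `RAW_TO_CUSTOM` and `remap_label` are hypothetical names, not part of SSD-Tensorflow:

```python
# Custom label table, as in kitti_common.py above.
KITTI_LABELS = {
    'none': (0, 'Background'),
    'Person': (1, 'Person'),
    'Car': (2, 'Car'),
}
# Number of object classes, excluding the background entry.
NUM_CLASSES = len(KITTI_LABELS) - 1

# Hypothetical mapping from raw KITTI class names to the custom classes.
RAW_TO_CUSTOM = {
    'Pedestrian': 'Person',
    'Person_sitting': 'Person',
    'Car': 'Car',
    'Van': 'Car',
}


def remap_label(raw_name):
    """Returns (label_id, custom_name), or None if the object should be skipped."""
    custom = RAW_TO_CUSTOM.get(raw_name)
    if custom is None:
        return None  # e.g. 'DontCare', 'Cyclist', 'Truck', 'Tram', 'Misc'
    return KITTI_LABELS[custom][0], custom


print(NUM_CLASSES)                # 2
print(remap_label('Pedestrian'))  # (1, 'Person')
print(remap_label('DontCare'))    # None
```

Inside `_process_image`, such a call would replace `KITTI_LABELS[data[0]]`, and a `None` result should `continue` before anything is appended, so the label, bbox and 3D lists of an image stay the same length.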
My kitti_to_tfrecords.py:
# Copyright 2015 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts KITTI data to TFRecords file format with Example protos.

The raw KITTI data set is expected to reside in PNG files located in the
directory 'image_2'. Similarly, bounding box annotations are supposed to be
stored as text files in 'label_2'.

This TensorFlow script converts the training data into a sharded data set of
TFRecord files named <name>_000.tfrecord, <name>_001.tfrecord, ..., each
holding up to SAMPLES_PER_FILES records. Each record within a TFRecord file
is a serialized Example proto. The Example proto contains the following fields:

    image/encoded: string containing PNG encoded image in RGB colorspace
    image/height: integer, image height in pixels
    image/width: integer, image width in pixels
    image/channels: integer, specifying the number of channels, always 3
    image/shape: list of 3 integers, [height, width, channels]
    image/format: string, specifying the format, always 'PNG'

    object/label: list of integer specifying the classification index.
    object/label_text: list of string descriptions.
    object/truncated, object/occluded, object/alpha: KITTI per-object fields.
    object/bbox/xmin: list of float, left edge divided by the image width
    object/bbox/xmax: list of float, right edge divided by the image width
    object/bbox/ymin: list of float, top edge divided by the image height
    object/bbox/ymax: list of float, bottom edge divided by the image height
    object/dimensions/{height,width,length}: 3D object dimensions.
    object/location/{x,y,z}, object/rotation_y: 3D location and rotation.

Note that the length of xmin is identical to the length of xmax, ymin and ymax
for each example.
"""
import os
import os.path
import sys
import random

import numpy as np
import tensorflow as tf

from datasets.dataset_utils import int64_feature, float_feature, bytes_feature
from datasets.kitti_common import KITTI_LABELS

DEFAULT_IMAGE_DIR = 'image_2/'
DEFAULT_LABEL_DIR = 'label_2/'

# TFRecords conversion parameters.
RANDOM_SEED = 4242
SAMPLES_PER_FILES = 512

def _png_image_shape(image_data, sess, decoded_png, inputs):
    rimg = sess.run(decoded_png, feed_dict={inputs: image_data})
    return rimg.shape


def _process_image(directory, name, f_png_image_shape,
                   image_dir=DEFAULT_IMAGE_DIR, label_dir=DEFAULT_LABEL_DIR):
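    """Processes an image and its annotation file.

    Args:
      directory: KITTI dataset directory;
      name: image name, without extension;
      f_png_image_shape: function returning the shape of a PNG buffer.
    Returns:
      image_data: string, PNG encoding of the RGB image;
      shape: [height, width, channels] of the image;
      followed by the per-object lists labels, labels_text, truncated,
      occluded, alpha, bboxes, dimensions, locations and rotation_y.
    """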
    # Read the PNG image file.
    filename = os.path.join(directory, image_dir, name + '.png')
    image_data = tf.gfile.FastGFile(filename, 'rb').read()
    shape = list(f_png_image_shape(image_data))

    # Get object annotations.
    labels = []
    labels_text = []
    truncated = []
    occluded = []
    alpha = []
    bboxes = []
    dimensions = []
    locations = []
    rotation_y = []

    # Read the txt label file, if it exists.
    filename = os.path.join(directory, label_dir, name + '.txt')
    if os.path.exists(filename):
        with open(filename) as f:
            label_data = f.readlines()
        for l in label_data:
            data = l.split()
            if len(data) > 0:
                # Label.
                labels.append(int(KITTI_LABELS[data[0]][0]))
                labels_text.append(data[0].encode('ascii'))
                # truncated, occluded and alpha.
                truncated.append(float(data[1]))
                occluded.append(int(data[2]))
                alpha.append(float(data[3]))
                # bbox.
                bboxes.append((float(data[4]) / shape[1],
                               float(data[5]) / shape[0],
                               float(data[6]) / shape[1],
                               float(data[7]) / shape[0]
                               ))
                # 3D dimensions.
                dimensions.append((float(data[8]),
                                   float(data[9]),
                                   float(data[10])
                                   ))
                # 3D location and rotation_y.
                locations.append((float(data[11]),
                                  float(data[12]),
                                  float(data[13])
                                  ))
                rotation_y.append(float(data[14]))

    return (image_data, shape, labels, labels_text, truncated, occluded,
            alpha, bboxes, dimensions, locations, rotation_y)


def _convert_to_example(image_data, shape, labels, labels_text,
                        truncated, occluded, alpha, bboxes,
                        dimensions, locations, rotation_y):
    """Build an Example proto for an image example.

    Args:
      image_data: string, PNG encoding of RGB image;
      shape: 3 integers, image shape in pixels;
      labels: list of integers, identifier for the ground truth;
      labels_text: list of strings, human-readable labels;
      truncated, occluded, alpha: per-object KITTI fields;
      bboxes: list of bounding boxes; each box is a tuple of floats
          (xmin, ymin, xmax, ymax), divided by the image width/height;
      dimensions, locations, rotation_y: per-object 3D annotations.
    Returns:
      Example proto
    """
    # Transpose bboxes, dimensions and locations.
    bboxes = list(map(list, zip(*bboxes)))
    dimensions = list(map(list, zip(*dimensions)))
    locations = list(map(list, zip(*locations)))
    # Iterators.
    it_bboxes = iter(bboxes)
    it_dims = iter(dimensions)
    its_locs = iter(locations)

    image_format = b'PNG'
    example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': int64_feature(shape[0]),
            'image/width': int64_feature(shape[1]),
            'image/channels': int64_feature(shape[2]),
            'image/shape': int64_feature(shape),
            'image/format': bytes_feature(image_format),
            'image/encoded': bytes_feature(image_data),
            'object/label': int64_feature(labels),
            'object/label_text': bytes_feature(labels_text),
            'object/truncated': float_feature(truncated),
            'object/occluded': int64_feature(occluded),
            'object/alpha': float_feature(alpha),
            'object/bbox/xmin': float_feature(next(it_bboxes, [])),
            'object/bbox/ymin': float_feature(next(it_bboxes, [])),
            'object/bbox/xmax': float_feature(next(it_bboxes, [])),
            'object/bbox/ymax': float_feature(next(it_bboxes, [])),
            'object/dimensions/height': float_feature(next(it_dims, [])),
            'object/dimensions/width': float_feature(next(it_dims, [])),
            'object/dimensions/length': float_feature(next(it_dims, [])),
            'object/location/x': float_feature(next(its_locs, [])),
            'object/location/y': float_feature(next(its_locs, [])),
            'object/location/z': float_feature(next(its_locs, [])),
            'object/rotation_y': float_feature(rotation_y),
            }))
    return example


def _add_to_tfrecord(dataset_dir, name, tfrecord_writer, f_png_image_shape,
                     image_dir=DEFAULT_IMAGE_DIR, label_dir=DEFAULT_LABEL_DIR):
    """Loads data from image and annotations files and add them to a TFRecord.

    Args:
      dataset_dir: Dataset directory;
      name: Image name to add to the TFRecord;
      tfrecord_writer: The TFRecord writer to use for writing.
    """
    l_data = _process_image(dataset_dir, name, f_png_image_shape,
                            image_dir, label_dir)
    example = _convert_to_example(*l_data)
    tfrecord_writer.write(example.SerializeToString())

# Original single-file version:
# def _get_output_filename(output_dir, name):
#     return '%s/%s.tfrecord' % (output_dir, name)
def _get_output_filename(output_dir, name, idx):
    """Path of the idx-th output shard, e.g. output_dir/kitti_train_000.tfrecord."""
    return '%s/%s_%03d.tfrecord' % (output_dir, name, idx)


def run(dataset_dir, output_dir, name='kitti_train', shuffling=False):
    """Runs the conversion operation.

    Args:
      dataset_dir: The dataset directory where the dataset is stored.
      output_dir: Output directory.
      name: Basename of the output shards (name_000.tfrecord, ...).
      shuffling: Whether to shuffle the image order before sharding.
    """
    # Create the output directory if needed (the writers below fail otherwise).
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)

    # Dataset filenames, and shuffling. Only PNG images are converted.
    path = os.path.join(dataset_dir, DEFAULT_IMAGE_DIR)
    filenames = sorted(f for f in os.listdir(path) if f.endswith('.png'))
    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames)

    # PNG decoding, used only to read the image shape.
    inputs = tf.placeholder(dtype=tf.string)
    decoded_png = tf.image.decode_png(inputs)
    with tf.Session() as sess:
        fidx = 0
        i = 0
        while i < len(filenames):
            # One output shard per iteration: <name>_<fidx>.tfrecord.
            tf_filename = _get_output_filename(output_dir, name, fidx)
            with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
                j = 0
                # Write at most SAMPLES_PER_FILES images into this shard.
                while i < len(filenames) and j < SAMPLES_PER_FILES:
                    sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
                    sys.stdout.flush()
                    img_name = os.path.splitext(filenames[i])[0]
                    _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer,
                                     lambda x: _png_image_shape(x, sess, decoded_png, inputs))
                    i += 1
                    j += 1
            fidx += 1
    print('\nFinished converting the KITTI dataset!')
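For reference, this is what `_process_image` does with one line of a label_2 file: the class name, truncation, occlusion and alpha come first, then the 2D box as left, top, right, bottom in pixels, divided by the image width (x) or height (y), then the 3D dimensions, location and rotation_y. The TensorFlow-free sketch below mirrors that parsing; `parse_kitti_line` and `KittiObject` are hypothetical names, and the sample line and the 1242x375 image size are made-up illustrative values:

```python
from collections import namedtuple

KittiObject = namedtuple(
    'KittiObject',
    ['name', 'truncated', 'occluded', 'alpha', 'bbox',
     'dimensions', 'location', 'rotation_y'])


def parse_kitti_line(line, height, width):
    """Parses one label_2 line the same way _process_image does.

    bbox is (xmin, ymin, xmax, ymax), divided by the image width/height.
    Returns None for empty lines.
    """
    data = line.split()
    if not data:
        return None
    bbox = (float(data[4]) / width, float(data[5]) / height,
            float(data[6]) / width, float(data[7]) / height)
    return KittiObject(
        name=data[0],
        truncated=float(data[1]),
        occluded=int(data[2]),
        alpha=float(data[3]),
        bbox=bbox,
        dimensions=tuple(float(v) for v in data[8:11]),  # height, width, length
        location=tuple(float(v) for v in data[11:14]),   # x, y, z
        rotation_y=float(data[14]),
    )


# Illustrative label line for a 1242x375 image (values made up).
obj = parse_kitti_line(
    'Car 0.00 0 -1.57 621.00 75.00 931.50 300.00 1.50 1.60 3.80 1.00 1.70 20.00 -1.55',
    height=375, width=1242)
print(obj.name, obj.bbox)  # Car (0.5, 0.2, 0.75, 0.8)
```

`_convert_to_example` then transposes the per-image list of such boxes with `zip(*bboxes)` into the four xmin/ymin/xmax/ymax lists stored in the Example.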

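That completes the dataset code. Next, edit datasets/dataset_factory.py: import the new kitti module and add an entry for it at the end of datasets_map, as I did: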
from datasets import kitti  # new: the KITTI dataset module

datasets_map = {
    'cifar10': cifar10,
    'imagenet': imagenet,
    'pascalvoc_2007': pascalvoc_2007,
    'pascalvoc_2012': pascalvoc_2012,
    'kitti': kitti,
}

I believe that is all I changed; if you still run into problems, leave a comment.
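From the SSD-Tensorflow directory, my training command is as follows (DATASET_DIR points to the TFRecords produced by the conversion command below):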
DATASET_DIR=tfrecords_kitti
TRAIN_DIR=logs/finetune_kitti/
CHECKPOINT_PATH=./checkpoints/ssd_300_vgg.ckpt
python3 train_ssd_network.py \
    --train_dir=${TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --dataset_name=kitti \
    --dataset_split_name=train \
    --model_name=ssd_300_vgg \
    --checkpoint_path=${CHECKPOINT_PATH} \
    --save_summaries_secs=600 \
    --save_interval_secs=600 \
    --weight_decay=0.0005 \
    --optimizer=adam \
    --learning_rate=0.001 \
    --batch_size=16 \
    --gpu_memory_fraction=0.5
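The command above sets no step limit, so as far as I know training keeps running until it is stopped. If you prefer a fixed number of epochs, the step count follows from the split size and the batch size. A quick calculation, assuming the 7481 training images from SPLITS_TO_SIZES and --batch_size=16; the resulting value would go into train_ssd_network.py's --max_number_of_steps flag, assuming your version has it. `steps_for_epochs` is a hypothetical helper:

```python
import math


def steps_for_epochs(num_samples, batch_size, epochs):
    """Training steps needed to see every sample `epochs` times (last batch rounded up)."""
    return math.ceil(num_samples / batch_size) * epochs


steps_per_epoch = steps_for_epochs(7481, 16, 1)
print(steps_per_epoch)                  # 468
print(steps_for_epochs(7481, 16, 100))  # 46800
```

Since batches are drawn from a shuffled input queue rather than strict epochs, treat the result as an approximate epoch count.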

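To convert the dataset, go to the SSD-Tensorflow directory and put the downloaded KITTI dataset under KITTI/, so that the images are in KITTI/training/image_2/ and the labels in KITTI/training/label_2/. The command relies on tf_convert_data.py accepting --dataset_name=kitti; if your copy only dispatches pascalvoc, import kitti_to_tfrecords there and add a branch that calls kitti_to_tfrecords.run(FLAGS.dataset_dir, FLAGS.output_dir, FLAGS.output_name). The conversion command is: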
DATASET_DIR=./KITTI/training/
OUTPUT_DIR=./tfrecords_kitti
python3 tf_convert_data.py \
    --dataset_name=kitti \
    --dataset_dir=${DATASET_DIR} \
    --output_name=kitti_train \
    --output_dir=${OUTPUT_DIR}
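Before running the conversion command above, especially with custom classes, it can save a failed run to scan the label files first. A missing label file is not an error in `_process_image` (the image is simply written with no objects), but an unknown class name is. The stdlib-only sketch below assumes the image_2/label_2 layout and the custom label keys from kitti_common.py; `preflight` and `KNOWN_CLASSES` are hypothetical names:

```python
import os
from collections import Counter

# Must match the keys of KITTI_LABELS in kitti_common.py (custom labels here).
KNOWN_CLASSES = {'none', 'Person', 'Car'}


def preflight(dataset_dir, image_dir='image_2', label_dir='label_2',
              known=KNOWN_CLASSES):
    """Scans a KITTI training directory before conversion.

    Returns (missing, unknown): image names with no label file, and a
    Counter of class names found in label files that are not in `known`.
    """
    images = sorted(os.path.splitext(f)[0]
                    for f in os.listdir(os.path.join(dataset_dir, image_dir))
                    if f.endswith('.png'))
    missing, unknown = [], Counter()
    for name in images:
        path = os.path.join(dataset_dir, label_dir, name + '.txt')
        if not os.path.exists(path):
            missing.append(name)
            continue
        with open(path) as f:
            for line in f:
                fields = line.split()
                if fields and fields[0] not in known:
                    unknown[fields[0]] += 1
    return missing, unknown
```

Calling `preflight('./KITTI/training/')` on the raw KITTI training set with these custom keys would report names such as Pedestrian and DontCare, which then need to be relabeled or remapped before conversion.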


And that's it.

References

[1] SDC-Vehicle-Detection. https://github.com/balancap/SDC-Vehicle-Detection


