Training on Your Own Dataset
Environment setup
tensorflow-gpu == 1.15.0
keras == 2.3.1 (optional, not required)
CUDA == 10.0
CUDNN == 7.6
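Before going further, it can help to confirm that the installed TensorFlow matches the version listed above. The sketch below compares `tf.__version__` with the pin. The helper names (`parse_version`, `matches_pin`, `installed_tf_version`) are hypothetical. Accepting any 1.15.x release is a deliberate, looser choice than the exact `== 1.15.0` pin. CUDA and cuDNN are not checked here.

```python
from itertools import takewhile

# Pinned version from the environment list above.
PINNED_TF = "1.15.0"


def parse_version(text):
    """'1.15.0' -> (1, 15, 0); suffixes like 'rc1' are ignored, missing parts become 0."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def matches_pin(installed, pinned=PINNED_TF):
    """True when major.minor agree (looser than the exact == pin, by choice)."""
    return parse_version(installed)[:2] == parse_version(pinned)[:2]


def installed_tf_version():
    """Return tf.__version__, or None if TensorFlow cannot be imported."""
    try:
        import tensorflow as tf
    except ImportError:
        return None
    return tf.__version__


if __name__ == "__main__":
    found = installed_tf_version()
    if found is None:
        print("TensorFlow is not installed")
    elif matches_pin(found):
        print("TensorFlow %s matches the 1.15 line" % found)
    else:
        print("TensorFlow %s found, expected %s" % (found, PINNED_TF))
```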
1. Generating the dataset
For SSD object detection, the VOC dataset first has to be converted into TFRecord (.tfrecord) files.
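A quick pre-flight check can catch a malformed dataset folder before conversion. This sketch assumes the standard VOC layout, `<dataset_dir>/Annotations/<name>.xml` plus `<dataset_dir>/JPEGImages/<name>.jpg`. As far as I know, that is the layout SSD-Tensorflow's `pascalvoc_to_tfrecords` reads, but treat it as an assumption. `check_voc_layout` is a hypothetical helper, not part of the repository.

```python
import os


def check_voc_layout(dataset_dir):
    """Return a list of problems found in a VOC-style dataset directory.

    Assumed layout: <dataset_dir>/Annotations/<name>.xml and
    <dataset_dir>/JPEGImages/<name>.jpg (one JPEG per annotation).
    """
    ann_dir = os.path.join(dataset_dir, "Annotations")
    img_dir = os.path.join(dataset_dir, "JPEGImages")
    problems = []
    for d in (ann_dir, img_dir):
        if not os.path.isdir(d):
            problems.append("missing directory: %s" % d)
    if problems:
        return problems
    # Every annotation file should have an image with the same stem.
    for fname in sorted(os.listdir(ann_dir)):
        if not fname.endswith(".xml"):
            continue
        stem = fname[:-4]
        if not os.path.isfile(os.path.join(img_dir, stem + ".jpg")):
            problems.append("no image for annotation: %s" % fname)
    return problems


if __name__ == "__main__":
    import sys
    target = sys.argv[1] if len(sys.argv) > 1 else "./VOC2007/test/"
    issues = check_voc_layout(target)
    print("\n".join(issues) or "layout looks OK")
```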
Run tf_convert_data.py.
# Code to edit (flag definitions in tf_convert_data.py)
import tensorflow as tf
from datasets import pascalvoc_to_tfrecords
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
'dataset_name', 'pascalvoc',
'The name of the dataset to convert.')
tf.app.flags.DEFINE_string(
'dataset_dir', "./VOC2007/test/",
'Directory where the original dataset is stored.')
tf.app.flags.DEFINE_string(
'output_name', 'voc_2007_train',
'Basename used for TFRecords output files.')
tf.app.flags.DEFINE_string(
'output_dir', './tfrecords',
'Output directory where to store TFRecords files.')
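--dataset_name=pascalvoc (no change needed; the default is the VOC format)
--dataset_dir=./VOC2007/test/ (where the dataset is stored)
--output_name=voc_2007_train (base name of the output files)
--output_dir=./tfrecords (directory for the output files)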
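The same four flags can be passed on the command line instead of editing the defaults. The sketch below builds that command and predicts the shard names the converter should produce. It assumes the converter writes `<output_name>_<idx:03d>.tfrecord` with 200 examples per file, which is how I understand SSD-Tensorflow's `pascalvoc_to_tfrecords` to behave; verify this against your copy. The glob `voc_2007_train_*.tfrecord` stands in for whatever pattern the training-side reader uses. It illustrates why `output_name` has to match that reader. All helper names are hypothetical.

```python
import fnmatch
import sys


def convert_command(dataset_dir, output_name, output_dir="./tfrecords"):
    """Build the tf_convert_data.py command line from the four flags above."""
    return [
        sys.executable, "tf_convert_data.py",
        "--dataset_name=pascalvoc",
        "--dataset_dir=%s" % dataset_dir,
        "--output_name=%s" % output_name,
        "--output_dir=%s" % output_dir,
    ]


def expected_shards(output_name, num_samples, samples_per_file=200):
    """Predict shard names, ASSUMING '<output_name>_<idx:03d>.tfrecord'
    with `samples_per_file` examples per shard."""
    count = -(-num_samples // samples_per_file)  # ceiling division
    return ["%s_%03d.tfrecord" % (output_name, i) for i in range(count)]


def readable_by(pattern, shard_names):
    """True if every shard name matches the reader's glob pattern."""
    return all(fnmatch.fnmatch(name, pattern) for name in shard_names)


if __name__ == "__main__":
    cmd = convert_command("./VOC2007/test/", "voc_2007_train")
    print(" ".join(cmd))
    # To actually launch it: subprocess.run(cmd, check=True)
```

For example, 4,952 images would give 25 shards, `voc_2007_train_000.tfrecord` through `voc_2007_train_024.tfrecord`. A name like `voc_2012_train` would not match a `voc_2007_train_*` reader pattern.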
2. Training
Edit the flags in train_ssd_network.py:
tf.app.flags.DEFINE_float(
'loss_alpha', 1., 'Alpha parameter in the loss function.')
tf.app.flags.DEFINE_float(
'negative_ratio', 3., 'Negative ratio in the loss function.')
tf.app.flags.DEFINE_float(
'match_threshold', 0.5, 'Matching threshold in the loss function.')
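The three loss flags follow the SSD paper's formulation, described here rather than the repository's exact code. An anchor counts as positive when its IoU with a ground-truth box exceeds `match_threshold`. Only the highest-loss negatives are kept, at most `negative_ratio` per positive (hard negative mining). The total loss is (L_conf + alpha * L_loc) / N, where N is the number of positives. The pure-Python sketch below illustrates each step. It simplifies matching to a single ground-truth box, and the function names are hypothetical.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (ymin, xmin, ymax, xmax)."""
    ih = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iw = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ih * iw
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def match_anchors(anchors, gt_box, match_threshold=0.5):
    """Positive mask: True where IoU with the ground-truth box exceeds the threshold."""
    return [iou(anchor, gt_box) > match_threshold for anchor in anchors]


def select_hard_negatives(negative_losses, num_positives, negative_ratio=3.0):
    """Indices of the highest-loss negatives, at most negative_ratio * positives."""
    k = min(int(negative_ratio * num_positives), len(negative_losses))
    order = sorted(range(len(negative_losses)),
                   key=lambda i: negative_losses[i], reverse=True)
    return sorted(order[:k])


def total_loss(conf_loss, loc_loss, num_positives, loss_alpha=1.0):
    """(L_conf + alpha * L_loc) / N; defined as 0 when there are no positives."""
    if num_positives == 0:
        return 0.0
    return (conf_loss + loss_alpha * loc_loss) / num_positives
```

Raising `negative_ratio` keeps more background anchors in the classification loss. Raising `match_threshold` makes fewer anchors count as positives.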
# =========================================================================== #
# General Flags.
# =========================================================================== #
tf.app.flags.DEFINE_string(
'train_dir', './logs/',
'Directory where checkpoints and event logs are written to.')
tf.app.flags.DEFINE_integer('num_clones', 1,
'Number of model clones to deploy.')
# Line 48 of the original file: change True -> False so clones run on the GPU, not the CPU
tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
'Use CPUs to deploy clones.')
tf.app.flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
tf.app.flags.DEFINE_integer(
'num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
tf.app.flags.DEFINE_integer(
'log_every_n_steps', 10,
'The frequency with which logs are printed.')
tf.app.flags.DEFINE_integer(
'save_summaries_secs', 60,
'The frequency with which summaries are saved, in seconds.')
tf.app.flags.DEFINE_integer(
'save_interval_secs', 600,
'The frequency with which the model is saved, in seconds.')
tf.app.flags.DEFINE_float(
'gpu_memory_fraction', 0.7, 'GPU memory fraction to use.')
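The sketch below shows what several of these general flags mean in practice. In TF-Slim's `model_deploy`, which SSD-Tensorflow builds on, clones are placed on `/device:CPU:0` when `clone_on_cpu` is true. Otherwise each clone gets its own `/device:GPU:i`. In TensorFlow 1.x, a fraction like `gpu_memory_fraction` is normally passed as `tf.GPUOptions(per_process_gpu_memory_fraction=...)`. I assume the training script does the same, so the budget below is an estimate. All helper names are hypothetical.

```python
def clone_devices(num_clones=1, clone_on_cpu=False):
    """Device string per model clone (mirrors TF-Slim model_deploy placement)."""
    if clone_on_cpu:
        return ["/device:CPU:0"] * num_clones
    return ["/device:GPU:%d" % i for i in range(num_clones)]


def gpu_memory_budget_gb(total_gb, gpu_memory_fraction=0.7):
    """Upper bound on memory TensorFlow may reserve on each GPU."""
    return total_gb * gpu_memory_fraction


def writes_per_hour(save_summaries_secs=60, save_interval_secs=600):
    """(summaries, checkpoints) written during one hour of training."""
    return 3600 // save_summaries_secs, 3600 // save_interval_secs


if __name__ == "__main__":
    print(clone_devices())                     # one clone on the first GPU
    print("%.1f GB" % gpu_memory_budget_gb(8))  # on a hypothetical 8 GB card
    print(writes_per_hour())                   # with the defaults above
```

With the defaults above, one clone runs on `/device:GPU:0`. TensorFlow may use about 5.6 GB of a hypothetical 8 GB card. Training writes 60 summaries and 6 checkpoints per hour.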
# =========================================================================== #