在slim中训练模型

这里主要是用slim下的Inception-V3模型对flower数据进行训练和评估.

数据准备:(默认已经安装过TF-slim image models library/下载过flower数据并将其转换为TFRecorder数据)

1)将models/research/slim文件夹(下图左)以及转换后的flowers数据拷贝到程序所在文件夹下(下图右)

2)训练模型的代码被放在slim/train_image_classifier.py文件中.下面编写shell脚本来对flower数据训练Inception-v3模型.

python train_image_classifier.py \
    --train_dir=flowers/train_dir \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --dataset_dir=flowers/data \
    --model_name=inception_v3 \
    --checkpoint_path=flowers/pretrained/inception_v3.ckpt \
    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --trainable_scope=InceptionV3/Logits,InceptionV3/AuxLogits \
    --max_number_of_steps=1000 \
    --batch_size=32 \
    --learning_rate=0.001 \
    --learning_rate_decay_type=fixed \
    --save_interval_secs=300 \
    --save_summaries_secs=2 \
    --log_every_n_steps=10 \
    --optiomizer=rmsprop \
    --weight_decay=0.00004

参数说明:
    --train_dir:训练完成后生成的模型所在路径
    --dataset_name:参与训练的数据集的名字
    --dataset_split_name:数据集中的validation或train
    --dataset_dir : 数据集路径
    --model_name : 模型名字
    --checkpoint_path : 下载的已经训练好的模型路径
    --checkpoint_exclude_scopes : 指定载入预训练模型时哪一层的权重不被载入
    --trainable_scope : 指定对哪一层参数进行训练;
    --max_number_of_steps : 最大运行步数
    --save_interval_secs=300 : 每300s保存一次模型到指定路径
    --save_summaries_secs=2  : 每2s将日志写入
    --log_every_n_steps=10 : 每10s打印一次

3)命令行中执行上面编写的shell脚本,

完成训练后得到训练后的模型在指定训练模型保存路径下

4)将最新获得的model.ckpt作为评测模型--在上图中即 model .ckpt-928,编写模型评测shell脚本如下:

python eval_image_classifier.py \
    --eval_dir=flowers/train_dir/  \
    --dataset_name=flowers \
    --dataset_split_name=validation \
    --dataset_dir=flowers/data \
    --model_name=inception_v3 \
    --checkpoint_path=flowers/train_dir

运行该评测脚本,得到最终模型评估结果如下:

最终评测准确率为87.25%.

总结一下:利用slim训练好的模型进行训练的步骤:1)数据准备(包括模型及数据);2)编写训练脚本,运行得到训练模型;3)编写评测脚本,运行得到评测结果.

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
这篇博客使用TensorFlow框架,使用预训练模型进行猫狗分类。代码在Github上已经公开,可以从[这里](https://github.com/jmhIcoding/dogsVScats)获取。 使用预训练模型进行微调的代码如下,其包括了数据处理、模型构建和训练三个部分: 数据处理部分[^1]: ```python import tensorflow as tf slim = tf.contrib.slim ... def get_dataset(dataset_name, split_name, dataset_dir, file_pattern): """ 获取指定数据集和指定数据集的数据切分 """ file_pattern = os.path.join(dataset_dir, file_pattern % split_name) if dataset_name == 'imagenet': return dataset.get_split(split_name, dataset_dir, file_pattern) elif dataset_name == 'flowers': return flowers.get_split(split_name, dataset_dir, file_pattern) elif dataset_name == 'cifar10': return cifar10.get_split(split_name, dataset_dir, file_pattern) elif dataset_name == 'mnist': return mnist.get_split(split_name, dataset_dir, file_pattern) elif dataset_name == 'cats_vs_dogs': return dogs.get_split(split_name, dataset_dir, file_pattern) else: raise ValueError('Invalid dataset name %s.' % dataset_name) def load_batch(dataset, batch_size, height, width, is_training=True): """ 加载一批数据 """ data_provider = slim.dataset_data_provider.DatasetDataProvider( dataset, shuffle=is_training, common_queue_capacity=2 * batch_size, common_queue_min=batch_size) image_raw, label = data_provider.get(['image', 'label']) image = inception_preprocessing.preprocess_image( image_raw, height, width, is_training=is_training) images, labels = tf.train.batch( [image, label], batch_size=batch_size, num_threads=4, capacity=5 * batch_size) return images, labels ``` 模型构建部分: ```python import tensorflow as tf slim = tf.contrib.slim ... def build_model(inputs, num_classes, is_training=True, scope='vgg_16'): """ 构建VGG16模型 """ with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc: end_points_collection = sc.name + '_end_points' with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], outputs_collections=end_points_collection): net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') net = slim.max_pool2d(net, [2, 2], scope='pool1') ... net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5') net = slim.max_pool2d(net, [2, 2], scope='pool5') net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout6') net = slim.conv2d(net, 4096, [1, 1], scope='fc7') net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout7') net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, normalizer_fn=None, scope='fc8') end_points = slim.utils.convert_collection_to_dict(end_points_collection) return net, end_points ``` 训练部分: ```python import tensorflow as tf slim = tf.contrib.slim ... def run_training(dataset_name, train_dir, dataset_dir, num_classes=2, batch_size=32, num_epochs=10, initial_learning_rate=0.0001): """ 训练模型 """ with tf.Graph().as_default(): tf.logging.set_verbosity(tf.logging.INFO) # 获取数据集 dataset = get_dataset(dataset_name, 'train', dataset_dir, '%s_*.tfrecord') images, labels = load_batch(dataset, batch_size=batch_size, height=224, width=224, is_training=True) # 构建网络 logits, end_points = build_model(images, num_classes=num_classes, is_training=True) # 定义损失函数 one_hot_labels = slim.one_hot_encoding(labels, num_classes) slim.losses.softmax_cross_entropy(logits=logits, onehot_labels=one_hot_labels) total_loss = slim.losses.get_total_loss() # 定义优化器 global_step = tf.train.get_or_create_global_step() learning_rate = tf.train.exponential_decay( initial_learning_rate, global_step, decay_steps=1000, decay_rate=0.96, staircase=True) optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) train_op = slim.learning.create_train_op(total_loss, optimizer, global_step=global_step) # 进行训练 saver = tf.train.Saver(tf.global_variables()) slim.learning.train( train_op, train_dir, log_every_n_steps=1, save_summaries_secs=20, saver=saver, number_of_steps=num_epochs * dataset.num_samples // batch_size, save_interval_secs=120) if __name__ == '__main__': run_training('cats_vs_dogs', '/tmp/cats_vs_dogs', '/path/to/dataset') ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值