WIN10+TensorFlow下的Faster-RCNN训练测试
安装TensorFlow
参考
自行参考这篇博客安装
下载源码
配置需要的环境
1、安装Python开发包
pip install cython
pip install python-opencv
pip install easydict
2、opencv比较难安装,若出现问题可以在下面链接下载想对应的版本OpenCV镜像文件
我这里下载的是opencv_ python-3. 3.1. 11-cp36-cp36m win amd64. whl,然后在anaconda下执行
pip install opencv_ python-3. 3. 1. 11-cp36-cp36m win amd64. whl
进行编译
3、在anaconda命令窗口,进入Faster-RCNN-TensorFlow-Python3\data\coco\PythonAPI目录,执行编译指令以及安装指令:
python setup.py build_ext --inplace
python setup.py build_ext install
然后进入另一个目录Faster-RCNN-TensorFlow-Python3\lib\utils,执行编译指令:python setup.py build_ext --inplace
这里容易出现[building pycocotools. mask extension
error: Unable to find vcvarsall. bat ]错误,这个错误是因为没有Visual C++ 编译环境,需要安装编译环境。
解决办法:(1)最直接的方法是安装VS2015,可参考VS2015
(2)相对简单的方法是安装VisualCppBuildTools_Full.exe,相对于安装VS2015要简单许多,也会省很多空间。百度网盘链接VisualCppBuildTools_Full.exe文件
出现这个提示说明运行成功
下载VOC2007数据集
VOCtrainval_06-Nov-2007
VOCtest_06-Nov-2007
VOCdevkit_08-Jun-2007
将三个压缩文件解压到Faster-RCNN-TensorFlow-Python3\data文件夹下,得到一个名为VOCdevkit的文件夹,将起名字改为VOCDevkit2007。
下载vgg16已训练好的网络模型数据
下载地址vgg16.ckpt
将下载的网络模型放入Faster-RCNN-TensorFlow-Python3\data\imagenet_weights\vgg16.ckpt,其中imagenet_weights文件夹需要自己建立。
这里注意默认下载文件名字为vgg_16.ckpt,需要手动改为vgg16.ckpt否则会出错,也可以选择其他网络模型,对应下载地址为:其它网络模型
训练
训练模型的参数可以在Faster-RCNN-TensorFlow-Python3\lib\config文件夹里的config.py修改,包括训练的总步数、权重衰减、学习率、batch_size等参数。
tf.app.flags.DEFINE_float('weight_decay', 0.0005, "Weight decay, for regularization")
tf.app.flags.DEFINE_float('learning_rate', 0.001, "Learning rate")
tf.app.flags.DEFINE_float('momentum', 0.9, "Momentum")
tf.app.flags.DEFINE_float('gamma', 0.1, "Factor for reducing the learning rate")
tf.app.flags.DEFINE_integer('batch_size', 256, "Network batch size during training")
tf.app.flags.DEFINE_integer('max_iters', 40000, "Max iteration")
tf.app.flags.DEFINE_integer('step_size', 30000, "Step size for reducing the learning rate, currently only support one step")
tf.app.flags.DEFINE_integer('display', 10, "Iteration intervals for showing the loss during training, on command line interface")
tf.app.flags.DEFINE_string('initializer', "truncated", "Network initialization parameters")
tf.app.flags.DEFINE_string('pretrained_model', "./data/imagenet_weights/vgg16.ckpt", "Pretrained network weights")
tf.app.flags.DEFINE_boolean('bias_decay', False, "Whether to have weight decay on bias as well")
tf.app.flags.DEFINE_boolean('double_bias', True, "Whether to double the learning rate for bias")
tf.app.flags.DEFINE_boolean('use_all_gt', True, "Whether to use all ground truth bounding boxes for training, "
"For COCO, setting USE_ALL_GT to False will exclude boxes that are flagged as ''iscrowd''")
tf.app.flags.DEFINE_integer('max_size', 1000, "Max pixel size of the longest side of a scaled input image")
tf.app.flags.DEFINE_integer('test_max_size', 1000, "Max pixel size of the longest side of a scaled input image")
tf.app.flags.DEFINE_integer('ims_per_batch', 1, "Images to use per minibatch")
tf.app.flags.DEFINE_integer('snapshot_iterations', 5000, "Iteration to take snapshot")
参数调整完后,在Faster-RCNN-TensorFlow-Python3的目录下,运行 python train.py,就可以训练生成模型了。
模型训练结束后,在 Faster-RCNN-TensorFlow-Python3\default\voc_2007_trainval\default目录下可以看到训练的模型,一个迭代了40000次,迭代次数可在Faster-RCNN-TensorFlow-Python3\lib\config文件夹里的config.py修改。
在目录下新建output\vgg16\voc_2007_trainval\default文件,将训练生成的文件复制到该文件下,并改名如下:“vgg16.ckpt.meta”,如下图所示:
测试
运行demo.py,需要进行修改
1、将NETS中的“vgg16_faster_rcnn_iter_70000.ckpt”改成“vgg16”,如下所示;
NETS = {'vgg16': ('vgg16.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
2、将DATASETS中的“voc_2007_trainval+voc_2012_trainval”改为“voc_2007_trainval”,如下所示;
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval',)}
3、将def parse_args()函数的两个default分别改成vgg16和pascal_voc,如下所示;
def parse_args():
"""Parse input arguments."""
parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
choices=NETS.keys(), default='vgg16')
parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
choices=DATASETS.keys(), default='pascal_voc')
args = parser.parse_args()
return args
测试结果如下: