看全网还没有一篇攻略,本文是第一个,有心人当点赞下,有问题可以下方留言,互相交流。如转载请注明出处,不枉解决各种各样的bug
环境:
v100,cuda10.1,tensorflow2.1.0 ,python3.7.7
(只保证这个版本是可行的,其他尝试了很多,报了各种匪(kan)夷(bu)所(dong)思的bug)
我的数据集是iabelme标注的,文件格式是xml,所以要转成tfrecord格式
注:这个是官方版的,不是pytorch的,https://github.com/google/automl/tree/master/efficientdet
参照:
1.数据集准备
简单来说:就是用dataset下create_pascal_tfrecord.py进行转换,注意修改类名
命令:
PYTHONPATH=".:$PYTHONPATH" python dataset/create_pascal_tfrecord.py --data_dir=VOCdevkit --year=VOC2012 --output_path=mytfrecord/pascal --set=trainval
2.训练:
python main.py --mode=train_and_eval \
--num_classes=10 \
--training_file_pattern=mytfrecord/train*.tfrecord \
--validation_file_pattern=mytfrecord/val*.tfrecord \
--val_json_file=mytfrecord/json_val.json \
--model_name=efficientdet-d3 \
--model_dir=tmp/efficientdet-d3-scratch \
--ckpt=efficientdet-d3 \
--train_batch_size=4 \
--eval_batch_size=4 --eval_samples=1024 \
--hparams="use_bfloat16=false,num_classes=10,moving_average_decay=0" \
--use_tpu=False
需要注意的是num_classes,是自己的类数,模型也要提前下载解压好,放在根目录下。
还有model_dir前面的 '/'去掉,否则会到根目录,而不是当前目录(感觉有点坑,害的我以为一开始预测没成功)
还有训练不显示loss,需要tensorboard 显示
cd到model_dir文件夹下
tensorboard --logdir=./,这将打开该 *. tfevents文件夹下的文件
3.map测试
python main.py --mode=eval --num_classes=10 --training_file_pattern=mytfrecord/train*.tfrecord --validation_file_pattern=mytfrecord/val*.tfrecord --val_json_file=mytfrecord/json_val.json --model_name=efficientdet-d3 --model_dir=tmp/efficientdet-d3 --ckpt=/data/test/automl/efficientdet/efficientdet-d3 --train_batch_size=2 --eval_batch_size=2 --eval_samples=1024 --hparams="use_bfloat16=false,num_classes=10,moving_average_decay=0" --use_tpu=False
结果大概是这个样子:
4.用自己的图片预测
python model_inspect.py --runmode=infer --model_name=efficientdet-d3 --input_image_size=1920 --max_boxes_to_draw=100 --min_score_thresh=0.4 --ckpt_path=tmp/efficientdet-d3 --input_image=testdata/4.jpg --output_image_dir=outimg/ --num_classes=10 --enable_ema=False
一定要加后面两个,别问我为啥,我也不知道,整了好久,issue差点都翻了一遍
好吧你们要的理由:
可参考https://github.com/google/automl/issues/249
哦,对了,预测的时候,请把inference.py的类名改回来。
没什么问题的话,跟上图差不多就预测完成了。
预测的图片在output_image_dir 下一个叫0.jpg的,看名字不开心的自己去改^_^
还有如果是一堆图片的,自己用inference.py改改