项目
- 论文:Unsupervised Learning of Depth and Ego-Motion from Video
- 地址:https://openaccess.thecvf.com/content_cvpr_2017/html/Zhou_Unsupervised_Learning_of_CVPR_2017_paper.html
- 代码:https://github.com/tinghuiz/SfMLearner
运行环境
环境配置在上一篇博客中已经详细描述
- 系统:win10 64位
- GPU:GeForce RTX 2080 SUPER
- CUDA 10.0
- CuDNN 7.4.2
- Python 3.6
- Tensorflow-Gpu 1.13.1
运行demo
- 首先下载kitti_depth_model.tar模型,并放在SfMLearner-master\models文件夹里
- 修改了一下代码,用spyder运行,结果如下图
#demo.py
from __future__ import division
import os
import numpy as np
import PIL.Image as pil
import tensorflow as tf
from SfMLearner import SfMLearner
from utils import normalize_depth_for_display
from pylab import *
tf.reset_default_graph()
img_height=128
img_width=416
ckpt_file = 'models/model-190532'
fh = open('misc/sample.png', 'rb')
I = pil.open(fh)
I = I.resize((img_width, img_height), pil.ANTIALIAS)
I = np.array(I)
sfm = SfMLearner()
sfm.setup_inference(img_height,
img_width,
mode='depth')
saver = tf.train.Saver([var for var in tf.model_variables()])
with tf.Session() as sess:
saver.restore(sess, ckpt_file)
pred = sfm.inference(I[None,:,:,:], sess, mode='depth')
figure(figsize=(15,15))
subplot(1,2,1); imshow(I)
subplot(1,2,2); imshow(normalize_depth_for_display(pred['depth'][0,:,:,0]))
- 第一张为原来代码结果,第二张换了张图,试了一下效果。
准备训练数据
- 下载 kitti-raw-data 和 cityscapes 数据集
- 保持 static_frames.txt 里面的数据与 kitti-raw 对应
- 切换目录,在cmd里运行 此处scipy 版本应为1.2.1,不然会报错
python data/prepare_train_data.py --dataset_dir=data/kitti/kitti-raw/ --dataset_name=kitti_raw_eigen --dump_root=data/resulting/formatted/data_kitti/ --seq_length=3 --img_width=416 --img_height=128 --num_threads=4
- kitti 运行结果及目录,花了大概两个多小时跑完
训练
- 直接运行下面指令报错:UnrecognizedFlagError: Unknown command line flag 'num_source'
python train.py --dataset_dir=data/resulting/formatted/data_kitti/ --checkpoint_dir=checkpoints/ --img_width=416 --img_height=128 --batch_size=4
- 解决方法:在 train.py 文件里加入两行
flags.DEFINE_string("num_source", None, "add configuration")
flags.DEFINE_string("num_scales", None, "add configuration")
- 训练截图
- tensorboard可视化
tensorboard --logdir=checkpoints/ --port=8888
-
进入网址即可
KITTI数据评估
Depth:
- 下载 kitti_eigen_depth_predictions.npy 预测模型
python kitti_eval/eval_depth.py --kitti_dir=data/kitti/kitti-raw/ --pred_file=kitti_eval/kitti_eigen_depth_predictions.npy
- 由于本人使用的是python3.6,原代码应该是python2,所以 depth_evaluation_utils.py 需要做如下修改
- 将140行改为 data[key] = np.array(list(map(float, value.split(' '))))
- 将210行改为 dupe_inds = [item for item, count in Counter(inds).items() if count > 1],得出如下结果
- 第二组结果是加上作者限定的 --max_depth=50,得到的结果更好
Pose:
- 下载 pose_eval_data.tar 模型
python kitti_eval/eval_pose.py --gtruth_dir=kitti_eval/pose_eval_data/pose_data/ground_truth/10/ --pred_dir=kitti_eval/pose_eval_data/pose_data/ours_results/10/
KITTI测试
Depth:
- 下载 kitti_pose_model.tar 模型
- 修改 test_kitti_depth.py,将47行改为 fh = open(test_files[idx], 'rb')
python test_kitti_depth.py --dataset_dir data/kitti/kitti-raw/ --output_dir output/depth/ --ckpt_file models/kitti_depth_model/model-190532
Pose:
- 下载 pose_eval_data.tar 模型
- 下载 kitti 数据集中 data_odometry_color 部分
python test_kitti_pose.py --test_seq 9 --dataset_dir data/kitti/kitti-odometry/data_odometry_color/dataset/ --output_dir output/pose/ --ckpt_file models/kitti_pose_model/model-100280