上一章我们了解清楚了计算误差的接口是如何实现的,这一章我们来分析一下如何实现BA的搭建。
首先,依旧是把源码完全附上
#include <iostream>
#include <ceres/ceres.h>
#include "common.h"
#include "SnavelyReprojectionError.h"
using namespace std;
void SolveBA(BALProblem &bal_problem);
int main(int argc, char **argv) {
if (argc != 2) {
cout << "usage: bundle_adjustment_ceres bal_data.txt" << endl;
return 1;
}
BALProblem bal_problem(argv[1]);
bal_problem.Normalize();
bal_problem.Perturb(0.1, 0.5, 0.5);
bal_problem.WriteToPLYFile("initial.ply");
SolveBA(bal_problem);
bal_problem.WriteToPLYFile("final.ply");
return 0;
}
void SolveBA(BALProblem &bal_problem) {
const int point_block_size = bal_problem.point_block_size();
const int camera_block_size = bal_problem.camera_block_size();
double *points = bal_problem.mutable_points();
double *cameras = bal_problem.mutable_cameras();
// Observations is 2 * num_observations long array observations
// [u_1, u_2, ... u_n], where each u_i is two dimensional, the x
// and y position of the observation.
const double *observations = bal_problem.observations();
ceres::Problem problem;
for (int i = 0; i < bal_problem.num_observations(); ++i) {
ceres::CostFunction *cost_function;
// Each Residual block takes a point and a camera as input
// and outputs a 2 dimensional Residual
cost_function = SnavelyReprojectionError::Create(observations[2 * i + 0], observations[2 * i + 1]);
// If enabled use Huber's loss function.
ceres::LossFunction *loss_function = new ceres::HuberLoss(1.0);
// Each observation corresponds to a pair of a camera and a point
// which are identified by camera_index()[i] and point_index()[i]
// respectively.
double *camera = cameras + camera_block_size * bal_problem.camera_index()[i];
double *point = points + point_block_size * bal_problem.point_index()[i];
problem.AddResidualBlock(cost_function, loss_function, camera, point);
}
// show some information here ...
std::cout << "bal problem file loaded..." << std::endl;
std::cout << "bal problem have " << bal_problem.num_cameras() << " cameras and "
<< bal_problem.num_points() << " points. " << std::endl;
std::cout << "Forming " << bal_problem.num_observations() << " observations. " << std::endl;
std::cout << "Solving ceres BA ... " << endl;
ceres::Solver::Options options;
options.linear_solver_type = ceres::LinearSolverType::SPARSE_SCHUR;
options.minimizer_progress_to_stdout = true;
ceres::Solver::Summary summary;
ceres::Solve(options, &problem, &summary);
std::cout << summary.FullReport() << "\n";
}
1、首先我们来看一下构造problem时所用到的参数都是怎么计算的。
//构建problm
problem.AddResidualBlock(cost_function, loss_function, camera, point);
//cost_function 观测数据(误差)
cost_function = SnavelyReprojectionError::Create(observations[2 * i + 0], observations[2 * i + 1]);
//loss_function 核函数
ceres::LossFunction *loss_function = new ceres::HuberLoss(1.0);
//camera,point 式子中bal_problem.XXXXX为观测个数
double *camera = cameras + camera_block_size * bal_problem.camera_index()[i];
double *point = points + point_block_size * bal_problem.point_index()[i];
其中比较难以理解的是待优化参数camera和point的式子是怎么得来的。
2、我们来看一下两者式子中第一个参数cameras(points)代表着什么。
void SolveBA(BALProblem &bal_problem) {
const int point_block_size = bal_problem.point_block_size();
const int camera_block_size = bal_problem.camera_block_size();
double *points = bal_problem.mutable_points();
double *cameras = bal_problem.mutable_cameras();
按图索骥,我们来解释一下这个VALProblem的作用是什么
根据同站老哥的解释,这整个类的功能就是对原始的txt数据进行分割存储,然后提供对txt数据的读取写入和生成PLY文件功能。
//从这里可以看出来parameters_[ ]这个数组存储就是待优化的所有值,用法也是当个纯指针在用,因为下方用法全是指针加偏移量
//排列方式就是16个9维相机=144个数码一列,这些就是相机的值
//返回数据中相机位姿数据列的开头位置
const double* cameras() const{ return parameters_; }
//紧接着下面,从parameters_开始,加上上方的144偏移量,就到了路标的数据。
//返回路标点数据列的开头位置
const double* points() const{ return parameters_ + camera_block_size()//9 * num_cameras_//16; }
加上mutablb_的前缀只是让这些数据是可变的,而不是上述代码中加上了const的常量。
总的来说,指针camera指向了数据集中相机位姿的数据,指针points指向了数据集中路标点的数据。
3、那么加式中后半段(加数)camera_block_size * bal_problem.camera_index()[i] 就很好理解了,作用就是让指针指向下一个数据的位置。
double *camera = cameras + camera_block_size * bal_problem.camera_index()[i];
double *point = points + point_block_size * bal_problem.point_index()[i];
cost_function = SnavelyReprojectionError::Create(observations[2 * i + 0], observations[2 * i + 1]);
const double *observations = bal_problem.observations();
这个SnavelyReprojectionError类正是我们上一讲所分析的计算误差的类。