// ===== bundle_adjustment_ceres.cpp =====
#include <iostream>
#include <ceres/ceres.h>
#include "common.h"
#include "SnavelyReprojectionError.h"
using namespace std;
// Forward declaration; the definition follows main().
void SolveBA(BALProblem &bal_problem);

/**
 * Usage: bundle_adjustment_ceres <bal_data.txt>
 *
 * Steps:
 *   1. load the BAL problem from the file given on the command line;
 *   2. normalize the data;
 *   3. add noise via Perturb(0.1, 0.5, 0.5) (presumably rotation /
 *      translation / point noise levels -- confirm against BALProblem);
 *   4. write the starting state to initial.ply;
 *   5. run bundle adjustment (SolveBA);
 *   6. write the optimized state to final.ply for comparison with step 4.
 *
 * Returns 0 on completion, or prints a usage line and returns 1 when the
 * argument count is not exactly one.
 */
int main(int argc, char **argv)
{
    if (argc == 2) {
        BALProblem bal_problem(argv[1]);
        bal_problem.Normalize();
        bal_problem.Perturb(0.1, 0.5, 0.5);
        bal_problem.WriteToPLYFile("initial.ply");  // snapshot before optimization
        SolveBA(bal_problem);
        bal_problem.WriteToPLYFile("final.ply");    // snapshot after optimization
        return 0;
    }

    cout << "usage: bundle_adjustment_ceres bal_data.txt" << endl;
    return 1;
}
/**
 * Builds a Ceres least-squares problem from `bal_problem` and solves it.
 *
 * One residual block is added per observation. Each block couples:
 *   - the observing camera's parameter block (camera_block_size() doubles,
 *     located via camera_index()), and
 *   - the observed point's parameter block (point_block_size() doubles,
 *     located via point_index()).
 * Its cost is the 2D SnavelyReprojectionError wrapped in a Huber loss with
 * scale 1.0, which limits the pull of large (outlier) residuals.
 *
 * The parameter blocks are pointers into mutable_cameras()/mutable_points(),
 * so the solver updates bal_problem's own arrays in place.
 *
 * Per-iteration progress goes to stdout (minimizer_progress_to_stdout).
 * The linear solver is SPARSE_SCHUR. Ceres' one-line BriefReport() is
 * printed at the end.
 */
void SolveBA(BALProblem &bal_problem) {
    const int point_stride = bal_problem.point_block_size();
    const int camera_stride = bal_problem.camera_block_size();
    double *const all_points = bal_problem.mutable_points();
    double *const all_cameras = bal_problem.mutable_cameras();

    // Flat array of 2 * num_observations() doubles: x and y of each
    // measurement, stored back to back.
    const double *const measurements = bal_problem.observations();

    ceres::Problem problem;
    for (int obs = 0; obs < bal_problem.num_observations(); ++obs) {
        const double *const uv = measurements + 2 * obs;
        ceres::CostFunction *const reprojection =
            SnavelyReprojectionError::Create(uv[0], uv[1]);
        ceres::LossFunction *const robust = new ceres::HuberLoss(1.0);

        // Locate the camera and point this observation ties together.
        double *const cam_block =
            all_cameras + camera_stride * bal_problem.camera_index()[obs];
        double *const pt_block =
            all_points + point_stride * bal_problem.point_index()[obs];

        problem.AddResidualBlock(reprojection, robust, cam_block, pt_block);
    }

    // Summary of the problem that was just assembled.
    std::cout << "bal problem file loaded..." << std::endl
              << "bal problem have " << bal_problem.num_cameras() << " cameras and "
              << bal_problem.num_points() << " points. " << std::endl
              << "Forming " << bal_problem.num_observations() << " observations. " << std::endl
              << "Solving ceres BA ... " << std::endl;

    ceres::Solver::Options solver_options;
    solver_options.linear_solver_type = ceres::SPARSE_SCHUR;  // eliminate point blocks via the Schur complement
    solver_options.minimizer_progress_to_stdout = true;

    ceres::Solver::Summary report;
    ceres::Solve(solver_options, &problem, &report);
    std::cout << report.BriefReport() << "\n";  // one-line summary (not FullReport)
}
// ===== SnavelyReprojectionError.h =====
#ifndef SnavelyReprojection_H
#define SnavelyReprojection_H
#include <iostream>
#include "ceres/ceres.h"
#include "rotation.h"
//代价函数的计算模型
/**
 * Ceres cost functor for the BAL / Bundler reprojection error.
 *
 * Stores one measured image point (observed_x, observed_y). Evaluates the
 * 2-vector residual (predicted - observed) for a given camera and 3D point.
 *
 * Camera parameter block, 9 values:
 *   [0..2] angle-axis rotation
 *   [3..5] translation
 *   [6]    focal length
 *   [7..8] radial distortion coefficients k1, k2 (second- and fourth-order terms)
 * Point parameter block, 3 values: 3D position.
 */
class SnavelyReprojectionError {
public:
    SnavelyReprojectionError(double observation_x, double observation_y)
        : observed_x(observation_x), observed_y(observation_y) {}

    /// Residual evaluation. It is templated so Ceres can evaluate it with
    /// Jets for automatic differentiation. Writes residuals[0..1] and always
    /// returns true.
    template<typename T>
    bool operator()(const T *const camera,
                    const T *const point,
                    T *residuals) const {
        T projected[2];
        CamProjectionWithDistortion(camera, point, projected);
        residuals[0] = projected[0] - T(observed_x);
        residuals[1] = projected[1] - T(observed_y);
        return true;
    }

    /**
     * Projects `point` through `camera` (layout documented on the class) and
     * writes the 2D result to predictions[0..1]:
     *
     *   P = R(camera[0..2]) * point + camera[3..5]
     *   p = -(P.x, P.y) / P.z
     *   r = 1 + k1 * |p|^2 + k2 * |p|^4
     *   predictions = focal * r * p
     *
     * The result is relative to the image centre. Nothing guards against
     * P.z == 0. Always returns true.
     */
    template<typename T>
    static inline bool CamProjectionWithDistortion(const T *camera, const T *point, T *predictions)
    {
        // Rotate into the camera frame with the angle-axis part of `camera`
        // (AngleAxisRotatePoint, from rotation.h), then apply the translation.
        T cam_pt[3];
        AngleAxisRotatePoint(camera, point, cam_pt);
        for (int axis = 0; axis < 3; ++axis) {
            cam_pt[axis] += camera[3 + axis];
        }

        // Perspective division. The negation is part of the BAL camera model.
        const T x_n = -cam_pt[0] / cam_pt[2];
        const T y_n = -cam_pt[1] / cam_pt[2];

        // Radial distortion factor r(p) = 1 + k1*r^2 + k2*r^4, in Horner form.
        const T &k1 = camera[7];
        const T &k2 = camera[8];
        const T rr = x_n * x_n + y_n * y_n;
        const T radial = T(1.0) + rr * (k1 + k2 * rr);

        // Scale by focal length to get the predicted image coordinates.
        const T &f = camera[6];
        predictions[0] = f * radial * x_n;
        predictions[1] = f * radial * y_n;
        return true;
    }

    /// Factory for one observation. Wraps a new functor in an auto-diff cost
    /// function with 2 residuals, a 9-value camera block and a 3-value point
    /// block.
    static ceres::CostFunction *Create(const double observed_x, const double observed_y)
    {
        using AutoDiffCost = ceres::AutoDiffCostFunction<SnavelyReprojectionError, 2, 9, 3>;
        return new AutoDiffCost(new SnavelyReprojectionError(observed_x, observed_y));
    }

private:
    double observed_x;  // measured image x of this observation
    double observed_y;  // measured image y of this observation
};
#endif // SnavelyReprojection.h
// ===== common.h =====
#pragma once
/********************************common头文件******************************************/
/// 从文件读入BAL dataset
//对原始数据的txt进行分割存储;
class BALProblem {
public:
/// load bal data from text file
explicit BALProblem(const std::string &filename, bool us