算法原理
见附录链接
主要用途
该算法主要用于医疗领域的图像背景去除,其中最关键的参数是半径。选择合适的半径会极大地提升效果。以下是 C++ 版本的实现,供大家参考。
代码
cpp文件
#include "pch.h"
#include "rolling_ball.h"
// Scan directions accepted by filter1D.
enum {
X_DIRECTION = 0, // lines parallel to the x direction
Y_DIRECTION = 1, // lines parallel to the y direction
DIAGONAL_1A = 2, // lines parallel to x=y, starting at the x axis
DIAGONAL_1B = 3, // lines parallel to x=y, starting at the y axis
DIAGONAL_2A = 4, // lines parallel to x=-y, starting at the x axis
DIAGONAL_2B = 5 // lines parallel to x=-y, starting at x=width-1
};
// Picks the shrink factor and the trim percentage that go with the requested
// radius, then builds the ball's height table from them.
RollingBall::RollingBall(int radius) {
    // Values for the largest radius class (radius > 100); smaller classes override them.
    int shrink = 8;
    int trim_percent = 40;
    if (radius <= 10) {
        shrink = 1;
        trim_percent = 24;
    } else if (radius <= 30) {
        shrink = 2;
        trim_percent = 24;
    } else if (radius <= 100) {
        shrink = 4;
        trim_percent = 32;
    }
    c_shrink_factor = shrink;
    c_data.clear();
    buildRollingBall(radius, trim_percent);
}
void RollingBall::buildRollingBall(float ball_radius, int trim_para) {
double rsquare = 0; // 半径平方
int xtrim = 0;
int xval = 0, yval = 0;
double small_ball_radius = ball_radius / c_shrink_factor; // 下采样后球的半径,只有图片缩放后才会启用
int half_width;
if (small_ball_radius < 1) {
small_ball_radius = 1;
}
rsquare = small_ball_radius * small_ball_radius;
xtrim = (int)(trim_para * small_ball_radius) / 100;
half_width = (int)round(small_ball_radius - xtrim);
c_width = 2 * half_width + 1;
// c_data 数据大于0,说明在圆内,否则在圆外
for (int y = 0; y < c_width; y++) {
for (int x = 0; x < c_width; x++) {
xval = x - half_width;
yval = y - half_width;
float temp = rsquare - (float)xval * xval - (float)yval*yval;
float val = temp > 0 ? sqrt(temp) : 0;
c_data.push_back(val);
}
}
}
// Nothing to release by hand: c_data cleans up after itself.
RollingBall::~RollingBall() = default;
// Box filter: convolves src with a kernel_size x kernel_size kernel whose
// entries are all 1 / kernel_size^2 and writes the result, in src's type, to dst.
static void smooth(cv::Mat& src, cv::Mat& dst, int kernel_size = 3) {
    const int taps = kernel_size * kernel_size;
    const float weight = 1.0 / taps;
    cv::Mat box(kernel_size, kernel_size, CV_32FC1, cv::Scalar(weight));
    cv::filter2D(src, dst, src.type(), box);
}
// Processes one line of `pixels` with a sliding parabola of curvature coeff2.
//
// The line consists of `length` samples starting at index `start` and spaced
// `inc` elements apart. The samples are first copied into `cache`; then, for
// each pair of consecutive contact points i1 < i2, the samples strictly
// between them are overwritten in `pixels` with the parabola through both.
//
// pixels          float buffer that is read and modified in place
// start, inc      index of the first sample and stride between samples
// length          number of samples on the line
// coeff2          curvature coefficient of the parabola
// cache           scratch buffer, at least `length` floats
// nextPoint       scratch buffer, at least `length` ints; used as a linked list
//                 of candidate contact points
// correctedEdges  optional 2-element output; when non-null it receives the
//                 estimated values at the two ends of the line
// Returns correctedEdges unchanged (nullptr if nullptr was passed).
//
// NOTE(review): nextPoint[length - 1] is written unconditionally, so length
// must be >= 1 — confirm that callers never pass an empty line.
float* lineSlideParabola(float* pixels, int start, int inc, int length, float coeff2, float* cache, int* nextPoint, float* correctedEdges) {
float minValue = std::numeric_limits<float>::max();
int lastpoint = 0;
// First and last contact points that are not the line's end points.
int firstCorner = length - 1;
int lastCorner = 0;
float vPrevious1 = 0;
float vPrevious2 = 0;
// Threshold for the discrete second-difference test below.
float curvatureTest = 1.999f * coeff2;
// Pass 1: copy the line into cache, track its minimum and chain up the candidates.
for (int i = 0, p = start; i < length; i++, p += inc) {
float v = pixels[p];
cache[i] = v;
if (v < minValue) minValue = v;
// Sample i-1 is a candidate when 2*p[i-1] - p[i-2] - p[i] < curvatureTest.
if (i >= 2 && vPrevious1 + vPrevious1 - vPrevious2 - v < curvatureTest) {
// Link the previous candidate to this one so that pass 2 can skip the rest.
nextPoint[lastpoint] = i - 1;
lastpoint = i - 1;
}
vPrevious2 = vPrevious1;
vPrevious1 = v;
}
// The last sample is always a candidate; INT_MAX terminates the chain.
nextPoint[lastpoint] = length - 1;
nextPoint[length - 1] = std::numeric_limits<int>::max();
// Pass 2: walk from contact point to contact point.
int i1 = 0;
while (i1 < length - 1) {
float v1 = cache[i1];
float minSlope = std::numeric_limits<float>::max();
int i2 = 0;
int searchTo = length;
// The search limit is recomputed only when this counter reaches 0; it is
// reset to -3 each time a smaller slope is found.
int recalculateLimitNow = 0;
// Find the candidate i2 for which the parabola through i1 has the smallest slope.
for (int j = nextPoint[i1]; j < searchTo; j = nextPoint[j], recalculateLimitNow++) {
float v2 = cache[j];
// Chord slope (v2 - v1) / (j - i1) plus the curvature term coeff2 * (j - i1).
float slope = (v2 - v1) / (j - i1) + coeff2 * (j - i1);
if (slope < minSlope) {
minSlope = slope;
i2 = j;
recalculateLimitNow = -3;
}
if (recalculateLimitNow == 0) {
// Tighten the search range using the current slope and the line minimum.
// The `maxSearch > 0` check rejects a non-positive result.
double b = 0.5 * minSlope / coeff2;
int maxSearch = i1 + (int)(b + sqrt(b * b + (v1 - minValue) / coeff2) + 1);
if (maxSearch < searchTo && maxSearch > 0) searchTo = maxSearch;
}
}
if (i1 == 0) firstCorner = i2;
if (i2 == length - 1) lastCorner = i1;
// Overwrite the samples strictly between i1 and i2 with the parabola values.
for (int j = i1 + 1, p = start + j * inc; j < i2; j++, p += inc)
pixels[p] = v1 + (j - i1) * (minSlope - (j - i1) * coeff2);
i1 = i2;
}
if (correctedEdges != nullptr) {
// A corner that lies a quarter of the line or more away from its end is
// replaced by the end point itself.
if (4 * firstCorner >= length) firstCorner = 0;
if (4 * (length - 1 - lastCorner) >= length) lastCorner = length - 1;
float v1 = cache[firstCorner];
float v2 = cache[lastCorner];
// Slope of the straight line through the two corners.
// NOTE(review): divides by zero if lastCorner == firstCorner — confirm this cannot happen.
float slope = (v2 - v1) / (lastCorner - firstCorner);
// Value of that line at index 0.
float value0 = v1 - slope * firstCorner;
// Coefficient of the 6th-order polynomial term.
float coeff6 = 0;
// Midpoint between the two corners.
float mid = 0.5 * (lastCorner + firstCorner);
// Compare against the samples in the middle third of the line.
for (int i = (length + 2) / 3; i <= (2 * length) / 3; i++) {
// Position relative to the midpoint, scaled so the corners sit at -1 and +1.
float dx = ((float)i - mid) * 2.0 / ((float)lastCorner - firstCorner);
// 6th-order polynomial that is zero at firstCorner and lastCorner.
float poly6 = dx * dx * dx * dx * dx * dx - 1.0;
// If the sample lies below the current line-plus-polynomial estimate,
// adjust coeff6 so the estimate passes through it.
if (cache[i] < value0 + slope * i + coeff6 * poly6) {
coeff6 = -(value0 + slope * i - cache[i]) / poly6;
}
}
// Estimated values at the two ends of the line.
float dx = ((float)firstCorner - mid) * 2.0 / ((float)lastCorner - firstCorner);
correctedEdges[0] = value0 + coeff6 * ((float)dx * dx * dx * dx * dx * dx - 1.0) + (float)coeff2 * firstCorner * firstCorner;
dx = ((float)lastCorner - mid) * 2.0 / ((float)lastCorner - firstCorner);
correctedEdges[1] = value0 + (length - 1) * slope + coeff6 * ((float)dx * dx * dx * dx * dx * dx - 1.0) + (float)coeff2 * (length - 1 - lastCorner) * (length - 1 - lastCorner);
}
return correctedEdges;
}
// Runs lineSlideParabola over every line of `src` in the given direction
// (one of the *_DIRECTION / DIAGONAL_* constants). `src.data` is treated as a
// dense row-major float buffer; cache and nextPoint are scratch buffers that
// are forwarded unchanged. An unrecognised direction processes no lines.
void filter1D(Mat& src, int direction, float coeff2, float* cache, int* nextPoint) {
    float* const data = (float*)src.data;
    const int cols = src.cols;
    const int rows = src.rows;

    int firstLine = 0;   // index of the first line to process
    int lineCount = 0;   // exclusive upper bound of the line index
    int lineStep = 0;    // offset between the start pixels of adjacent lines
    int pixelStep = 0;   // offset between neighbouring pixels on one line
    int lineLength = 0;  // samples per line (recomputed per line for diagonals)

    if (direction == X_DIRECTION) {
        lineCount = rows;
        lineStep = cols;
        pixelStep = 1;
        lineLength = cols;
    } else if (direction == Y_DIRECTION) {
        lineCount = cols;
        lineStep = 1;
        pixelStep = cols;
        lineLength = rows;
    } else if (direction == DIAGONAL_1A) {
        lineCount = cols - 2;  // lines shorter than 3 pixels are skipped
        lineStep = 1;
        pixelStep = cols + 1;
    } else if (direction == DIAGONAL_1B) {
        firstLine = 1;
        lineCount = rows - 2;
        lineStep = cols;
        pixelStep = cols + 1;
    } else if (direction == DIAGONAL_2A) {
        firstLine = 2;
        lineCount = cols;
        lineStep = 1;
        pixelStep = cols - 1;
    } else if (direction == DIAGONAL_2B) {
        firstLine = 0;
        lineCount = rows - 2;
        lineStep = cols;
        pixelStep = cols - 1;
    }

    for (int line = firstLine; line < lineCount; line++) {
        int origin = line * lineStep;
        if (direction == DIAGONAL_2B) {
            origin += cols - 1;  // these lines start in the last column
        }
        // Diagonal lines change length from one line to the next.
        if (direction == DIAGONAL_1A) {
            lineLength = min(rows, cols - line);
        } else if (direction == DIAGONAL_1B || direction == DIAGONAL_2B) {
            lineLength = min(cols, rows - line);
        } else if (direction == DIAGONAL_2A) {
            lineLength = min(rows, line + 1);
        }
        lineSlideParabola(data, origin, pixelStep, lineLength, coeff2, cache, nextPoint, nullptr);
    }
}
// Caps the four corner pixels of `src` using edge estimates from
// lineSlideParabola. Each corner collects three estimates (one along its row,
// one along its column, one along its diagonal); the pixel is lowered to their
// mean when it exceeds that mean. The border rows/columns and the diagonals
// are filtered in place by the calls themselves. cache and nextPoint are
// scratch buffers passed straight through.
void correct_corners(Mat& src, float coeff2, float* cache, int* nextPoint) {
    const int w = src.cols;
    const int h = src.rows;
    float* const px = (float*)src.data;

    // Buffer indices of the corners.
    const int topLeft = 0;
    const int topRight = w - 1;
    const int bottomLeft = (h - 1) * w;
    const int bottomRight = w * h - 1;

    float estimate[4];  // sums for: top-left, top-right, bottom-left, bottom-right
    float edge[2];      // filled by every lineSlideParabola call

    // First row.
    lineSlideParabola(px, topLeft, 1, w, coeff2, cache, nextPoint, edge);
    estimate[0] = edge[0];
    estimate[1] = edge[1];
    // Last row.
    lineSlideParabola(px, bottomLeft, 1, w, coeff2, cache, nextPoint, edge);
    estimate[2] = edge[0];
    estimate[3] = edge[1];
    // First column.
    lineSlideParabola(px, topLeft, w, h, coeff2, cache, nextPoint, edge);
    estimate[0] += edge[0];
    estimate[2] += edge[1];
    // Last column.
    lineSlideParabola(px, topRight, w, h, coeff2, cache, nextPoint, edge);
    estimate[1] += edge[0];
    estimate[3] += edge[1];

    // 45-degree lines leaving each corner; only the near-end estimate is used.
    const int diagLength = min(w, h);
    const float coeff2diag = 2 * coeff2;
    lineSlideParabola(px, topLeft, 1 + w, diagLength, coeff2diag, cache, nextPoint, edge);
    estimate[0] += edge[0];
    lineSlideParabola(px, topRight, -1 + w, diagLength, coeff2diag, cache, nextPoint, edge);
    estimate[1] += edge[0];
    lineSlideParabola(px, bottomLeft, 1 - w, diagLength, coeff2diag, cache, nextPoint, edge);
    estimate[2] += edge[0];
    lineSlideParabola(px, bottomRight, -1 - w, diagLength, coeff2diag, cache, nextPoint, edge);
    estimate[3] += edge[0];

    const int corner[4] = { topLeft, topRight, bottomLeft, bottomRight };
    for (int c = 0; c < 4; c++) {
        if (px[corner[c]] > estimate[c] / 3) {
            px[corner[c]] = estimate[c] / 3;
        }
    }
}
void sliding_paraboloid_float_background(const Mat& src, float radius,Mat& background_img, bool correctCorner) {
Mat src_copy = src.clone();
float* pixels = (float*)src_copy.data;
int width = src_copy.cols;
int height = src_copy.rows;
float* cache = new float[max(width, height)];
int* nextPoint = new int[max(width, height)];
float coeff2 = 0.5 / radius; //二项式阶数,越小越尖锐
float coeff2diag = 1.0 / radius; //对角线上二项式阶数
if (correctCorner)
correct_corners(src_copy, coeff2, cache, nextPoint);
// 演不同方向滑动抛物线
filter1D(src_copy, X_DIRECTION, coeff2, cache, nextPoint);
filter1D(src_copy, Y_DIRECTION, coeff2, cache, nextPoint);
filter1D(src_copy, X_DIRECTION, coeff2, cache, nextPoint); //redo for better accuracy
filter1D(src_copy, DIAGONAL_1A, coeff2diag, cache, nextPoint);
filter1D(src_copy, DIAGONAL_1B, coeff2diag, cache, nextPoint);
filter1D(src_copy, DIAGONAL_2A, coeff2diag, cache, nextPoint);
filter1D(src_copy, DIAGONAL_2B, coeff2diag, cache, nextPoint);
filter1D(src_copy, DIAGONAL_1A, coeff2diag, cache, nextPoint);//redo for better accuracy
filter1D(src_copy, DIAGONAL_1B, coeff2diag, cache, nextPoint);
background_img = src_copy;
}
// Precomputes linear-interpolation tables for enlarging a line of
// `small_length` samples to `length` samples by a factor of `shrink_factor`.
// For each output position, p_small_index receives the index of the left
// source sample (capped at small_length - 2 so that index + 1 stays valid) and
// weight receives the weight of that left sample; the right sample gets
// 1 - weight. Weights fall outside [0, 1] near the borders, which extrapolates.
static void interpolation_arrays(int* p_small_index, float* weight, int length, int small_length, int shrink_factor) {
    const int half_step = shrink_factor / 2;
    const int last_left = small_length - 2;
    for (int pos = 0; pos < length; ++pos) {
        int left = (pos - half_step) / shrink_factor;
        if (left > last_left) {
            left = last_left;
        }
        p_small_index[pos] = left;
        // Distance between the pixel centres, measured in small-image pixels.
        const float offset = (pos + 0.5f) / shrink_factor - (left + 0.5f);
        weight[pos] = 1.0 - offset;
    }
}
// Down-samples `src` by `shrink_factor` in both directions. Each output pixel
// is the minimum of the corresponding shrink_factor x shrink_factor tile of
// the input, with tiles clipped at the right and bottom borders. The output
// size is rounded up and `dst` is replaced by a new CV_32FC1 image.
// src.data is read as a dense row-major float buffer.
static void shrink_img(const cv::Mat& src, cv::Mat& dst, int shrink_factor) {
    const int rows = src.rows;
    const int cols = src.cols;
    // Ceiling division, so partial tiles at the border still get an output pixel.
    const int small_rows = (rows + shrink_factor - 1) / shrink_factor;
    const int small_cols = (cols + shrink_factor - 1) / shrink_factor;
    dst = cv::Mat::zeros(small_rows, small_cols, CV_32FC1);
    float* out = (float*)dst.data;
    const float* in = (float*)src.data;

    for (int sy = 0; sy < small_rows; sy++) {
        const int y_begin = shrink_factor * sy;
        for (int sx = 0; sx < small_cols; sx++) {
            const int x_begin = shrink_factor * sx;
            float lowest = std::numeric_limits<float>::max();
            // Scan the tile that maps onto this output pixel.
            for (int dy = 0; dy < shrink_factor && y_begin + dy < rows; dy++) {
                const float* row = in + (y_begin + dy) * cols;
                for (int dx = 0; dx < shrink_factor && x_begin + dx < cols; dx++) {
                    const float candidate = row[x_begin + dx];
                    if (candidate < lowest) {
                        lowest = candidate;
                    }
                }
            }
            out[sx + sy * small_cols] = lowest;
        }
    }
}
// Enlarges the small float image `src` by `enlarge_factor` with bilinear
// interpolation. The output size is taken from `dst` as it is on entry, so the
// caller must pre-size `dst`; `dst` is then replaced by a new CV_32FC1 image.
// src.data is read as a dense row-major float buffer.
//
// NOTE(review): interpolation_arrays caps indices at small_length - 2, which
// is -1 when src is a single row or column; the reads below would then go out
// of bounds — confirm callers never pass such a small image.
static void enlarge_img(const cv::Mat& src, cv::Mat& dst, int enlarge_factor){
int height = src.rows;
int width = src.cols;
int l_height = dst.rows;
int l_width = dst.cols;
cv::Mat temp = Mat::zeros(dst.size(), CV_32FC1);
float* src_data = (float*)src.data;
float* dst_data = (float*)temp.data;
// Two x-interpolated source rows: line0 is the upper one, line1 the lower one.
float* line0 = new float[l_width];
float* line1 = new float[l_width];
// Per-column and per-row source indices and weights.
int* x_indices = new int[l_width];
float* x_weight = new float[l_width];
interpolation_arrays(x_indices, x_weight, l_width, width, enlarge_factor);
int* y_indices = new int[l_height];
float* y_weight = new float[l_height];
interpolation_arrays(y_indices, y_weight, l_height, height, enlarge_factor);
// x-interpolate the first source row into line1; the first loop iteration
// below swaps it into line0.
for (int x = 0; x < l_width; x++)
line1[x] = src_data[x_indices[x]] * x_weight[x] + src_data[x_indices[x] + 1] * (1.0 - x_weight[x]);
// Index of the source row currently held in line0.
int y_small_line0 = -1;
for (int y = 0; y < l_height; y++) {
if (y_small_line0 < y_indices[y]) {
// Advance one source row: the old lower row becomes the upper row and the
// other buffer is refilled with the next source row.
float* swap = line0;
line0 = line1;
line1 = swap;
y_small_line0++;
// Offset of the source row below the one now in line0.
int sy_pointer = (y_indices[y] + 1) * width;
for (int x = 0; x < l_width; x++)
line1[x] = src_data[sy_pointer + x_indices[x]] * x_weight[x] + src_data[sy_pointer + x_indices[x] + 1] * (1.0 - x_weight[x]);
}
// Blend the two x-interpolated rows vertically.
float weight = y_weight[y];
for (int x = 0, p = y * l_width; x < l_width; x++, p++)
dst_data[p] = line0[x] * weight + line1[x] * (1.0 - weight);
}
dst = temp;
delete [] line0;
delete [] line1;
delete [] x_indices;
delete [] y_indices;
delete [] x_weight;
delete [] y_weight;
}
// Rolls `ball` underneath the (already shrunken) float image and replaces each
// pixel with the highest point any ball position reaches there.
//
// shrink_img      image processed in place; its data is used as a dense
//                 row-major float buffer
// ball            supplies the height table (c_data) and its side length (c_width)
// background_img  receives a deep copy of the processed image
//
// Changes from the previous version:
//  - Rows waiting to be processed are seeded with the most negative float.
//    The old code used numeric_limits<float>::min(), which is the smallest
//    POSITIVE float, so any background value at or below zero was clamped.
//  - The row cache is a std::vector; the old new[] buffer was never freed.
//  - The ball's height table is referenced instead of copied.
//  - The trailing loop that copied the buffer onto itself is gone.
static void roll_ball(cv::Mat& shrink_img, RollingBall ball, cv::Mat& background_img) {
    const int width = shrink_img.cols;
    const int height = shrink_img.rows;
    float* pix_data = (float*)shrink_img.data;
    const std::vector<float>& z_data = ball.c_data;
    const int ball_width = ball.c_width;
    const int radius = ball_width / 2;
    // Ring buffer with the original values of the ball_width rows under the ball.
    std::vector<float> cache_data((size_t)width * ball_width);
    const float unprocessed = std::numeric_limits<float>::lowest();

    for (int y = -radius; y < height + radius; y++) {
        // Move the next row into the ring buffer and reset it in the image.
        const int next_line2write_cache = (y + radius) % ball_width;
        const int next_line2read = y + radius;
        if (next_line2read < height && next_line2read >= 0) {
            std::memcpy(cache_data.data() + next_line2write_cache * width,
                        pix_data + next_line2read * width, width * sizeof(float));
            for (int i = 0, p = next_line2read * width; i < width; i++, p++) {
                pix_data[p] = unprocessed;
            }
        }
        // Vertical extent of the ball, clipped to the image.
        int y0 = y - radius;
        if (y0 < 0) y0 = 0;
        const int yball0 = y0 - y + radius;
        int yend = y + radius;
        if (yend >= height) yend = height - 1;

        for (int x = -radius; x < width + radius; x++) {
            // Horizontal extent of the ball, clipped to the image.
            int x0 = x - radius;
            if (x0 < 0) x0 = 0;
            const int xball0 = x0 - x + radius;
            int xend = x + radius;
            if (xend >= width) xend = width - 1;

            // Height at which the ball first touches the image from below.
            double z = std::numeric_limits<double>::max();
            for (int yp = y0, yBall = yball0; yp <= yend; yp++, yBall++) {
                int cachePointer = (yp % ball_width) * width + x0;
                for (int xp = x0, bp = xball0 + yBall * ball_width; xp <= xend; xp++, cachePointer++, bp++) {
                    const float zReduced = cache_data[cachePointer] - z_data[bp];
                    if (z > zReduced) z = zReduced;
                }
            }
            // Raise every covered pixel to the ball surface where it is higher.
            for (int yp = y0, yBall = yball0; yp <= yend; yp++, yBall++) {
                for (int xp = x0, p = xp + yp * width, bp = xball0 + yBall * ball_width; xp <= xend; xp++, p++, bp++) {
                    const float zMin = z + z_data[bp];
                    if (pix_data[p] < zMin) pix_data[p] = zMin;
                }
            }
        }
    }
    background_img = shrink_img.clone();
}
// Rolling-ball background estimate for a float image: shrink by the ball's
// shrink factor, roll the ball over the small image, then enlarge the result
// back into background_img.
static void rolling_ball_float_background(cv::Mat& src, RollingBall ball, cv::Mat& background_img) {
    // enlarge_img reads the target size from its destination, so give
    // background_img the size of src first.
    background_img = src.clone();

    cv::Mat reduced = src.clone();
    cv::Mat reduced_background;
    shrink_img(src, reduced, ball.c_shrink_factor);
    roll_ball(reduced, ball, reduced_background);

    enlarge_img(reduced_background, background_img, ball.c_shrink_factor);
}
// Subtracts an estimated background from `src`.
//
// src            input image; no longer modified by this call (see below)
// radius         ball / paraboloid radius in pixels
// roll_back_img  receives src minus the background, converted to 8-bit
// background_img receives the estimated background, converted to 8-bit
// isLightBack    invert the image before processing and invert both outputs
//                afterwards
// isSmooth       apply a 3x3 box filter before processing
// isParaboloid   true: sliding paraboloid; false: rolling ball
//
// NOTE(review): radius is not validated; the paraboloid path divides by it,
// so callers should pass a positive value.
void subtract_background_rolling_ball(cv::Mat src, int radius, cv::Mat& roll_back_img,
    cv::Mat& background_img, bool isLightBack, bool isSmooth, bool isParaboloid) {
    // A cv::Mat passed by value still shares its pixel buffer with the caller's
    // matrix. The smoothing and inversion below work in place, so detach first
    // to leave the caller's image untouched.
    if (isSmooth || isLightBack) {
        src = src.clone();
    }
    // Pre-smoothing.
    if (isSmooth) {
        smooth(src, src);
    }
    // Invert when the background is light.
    if (isLightBack) {
        gray_reverse(src);
    }
    cv::Mat src_copy;
    src.convertTo(src_copy, CV_32FC1);
    if (isParaboloid) {
        // Sliding-paraboloid method (corner correction disabled).
        sliding_paraboloid_float_background(src_copy, radius, background_img, false);
    }
    else {
        RollingBall ball(radius);
        rolling_ball_float_background(src_copy, ball, background_img);
    }
    roll_back_img = src_copy - background_img;
    roll_back_img.convertTo(roll_back_img, CV_8UC1);
    background_img.convertTo(background_img, CV_8UC1);
    if (isLightBack) {
        gray_reverse(roll_back_img);
        gray_reverse(background_img);
    }
}
h文件
#pragma once
// The rolling ball: a square table of ball heights plus the factor by which
// the image is shrunk before the ball is rolled over it.
class RollingBall
{
public:
// Ball heights, row-major, c_width * c_width entries; 0 outside the ball.
std::vector<float> c_data;
// Side length of the height table (always odd: 2 * half_width + 1).
int c_width;
// Down-sampling factor chosen from the radius (1, 2, 4 or 8).
int c_shrink_factor;
// Selects the shrink factor for `radius` and builds the height table.
RollingBall(int radius);
~RollingBall();
private:
// Fills c_data and c_width; trim_para is the percentage of the (shrunken)
// radius trimmed off the ball's edge.
void buildRollingBall(float ball_radius, int trim_para);
};
// Subtracts an estimated background from src.
//   radius         ball / paraboloid radius
//   roll_back_img  output: src minus background, 8-bit
//   background_img output: estimated background, 8-bit
//   isLightBack    invert before processing and invert the outputs afterwards
//   isSmooth       box-filter src before processing
//   isParaboloid   true = sliding paraboloid, false = rolling ball
void subtract_background_rolling_ball(cv::Mat src, int radius, cv::Mat& roll_back_img, cv::Mat& background_img, bool isLightBack=true, bool isSmooth = true, bool isParaboloid = true);
pch文件
// pch.cpp: 与预编译标头对应的源文件
#include "pch.h"
void show_img(const Mat& input, string title) {
namedWindow(title, WINDOW_NORMAL);
imshow(title, input);
waitKey();
}
// Inverts an 8-bit image in place (value -> 255 - value) on every channel.
//
// Fixes over the previous version:
//  - The CV_8UC3 branch advanced one byte per iteration while touching three,
//    so bytes were inverted several times (some ending up unchanged) and only
//    the first rows*cols + 2 bytes of the 3*rows*cols buffer were visited.
//  - Rows are now addressed through Mat::ptr, so images with padded rows
//    (e.g. ROIs) are handled correctly.
void gray_reverse(Mat& des)
{
    if (des.depth() == CV_8U) {
        const int bytes_per_row = des.cols * des.channels();
        for (int r = 0; r < des.rows; r++) {
            uchar* row = des.ptr<uchar>(r);
            for (int b = 0; b < bytes_per_row; b++) {
                row[b] = 255 - row[b];
            }
        }
        return;
    }
    // Other depths: previous behaviour kept unchanged (byte-wise inversion of
    // the first rows*cols bytes).
    // NOTE(review): this is unlikely to be meaningful for non-8-bit data;
    // gray_reverse16 exists for 16-bit images.
    const int totalnum = des.cols * des.rows;
    for (int k = 0; k < totalnum; k++)
    {
        des.data[k] = 255 - des.data[k];
    }
}
// Inverts a 16-bit single-channel image in place (value -> 65535 - value).
void gray_reverse16(Mat& des)
{
    for (int r = 0; r < des.rows; ++r) {
        for (int c = 0; c < des.cols; ++c) {
            ushort& px = des.at<ushort>(r, c);
            px = 65535 - px;
        }
    }
}
// Converts a 16-bit grayscale image to 8-bit by stretching [min, max] to [0, 255].
//
// src  input image; a 3-channel input is converted to grayscale IN PLACE first
// dst  output; if src is not CV_16UC1 (after the optional grayscale step) it
//      is simply a clone of src
//
// Fix: an image in which every pixel has the same value has max == min; the
// old code then divided by zero and converted NaN to uchar. Such an image now
// yields an all-zero dst.
void img_16convert8(Mat& src, Mat& dst) {
    if (src.channels() == 3) {
        cvtColor(src, src, COLOR_BGR2GRAY);
    }
    if (src.type() != CV_16UC1) {
        dst = src.clone();
        return;
    }
    double mav = 0, miv = 0;
    const int width = src.cols;
    const int height = src.rows;
    dst = Mat::zeros(height, width, CV_8UC1);
    minMaxIdx(src, &miv, &mav);
    const double range = mav - miv;
    if (range <= 0) {
        return;  // flat image: nothing to stretch, dst stays all zero
    }
    for (int i = 0; i < height; i++)
    {
        const ushort* p_img = src.ptr<ushort>(i);
        uchar* p_dst = dst.ptr<uchar>(i);
        for (int j = 0; j < width; ++j)
        {
            // Truncating conversion, as before.
            p_dst[j] = (p_img[j] - miv) / range * 255;
        }
    }
}
// Draws every rectangle in rect_list onto src with gray level 127 and a
// 2-pixel outline.
void draw_rect(Mat& src, vector<Rect> rect_list) {
    for (size_t i = 0; i < rect_list.size(); ++i) {
        rectangle(src, rect_list[i], Scalar(127), 2);
    }
}
// Computes a `bin`-bin uniform histogram of src over the value range [0, bin)
// and returns the index of the fullest bin.
//
// Cleanup over the previous version:
//  - The histogram-drawing code after the first `return` could never run and
//    has been removed.
//  - Unused locals (hbins, sbins, hranges, channels, scale) are gone.
//  - The range array converts `bin` to float explicitly instead of relying on
//    a narrowing int-to-float brace initialisation.
int hist(Mat src, int bin) {
    int histSize = bin;
    float value_range[] = { 0.0f, static_cast<float>(bin) };
    const float* ranges = { value_range };
    MatND hist;
    cv::calcHist(&src, 1, 0, Mat(),  // null channel list, no mask
        hist, 1, &histSize, &ranges,
        true,   // uniform bins
        false); // do not accumulate
    double maxVal = 0;
    Point max_local;
    cv::minMaxLoc(hist, 0, &maxVal, 0, &max_local);
    // The histogram is a column vector, so the bin index is the y coordinate.
    return max_local.y;
}
需要自取