c++(kdtree)

#include <time.h>

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

using namespace std;

// A 3-D point with single-precision coordinates.
struct Point {

	// Builds a point from explicit coordinates.
	Point(float xx, float yy, float zz) :x(xx), y(yy), z(zz){}
	// Builds the origin (0, 0, 0).
	Point() :x(0), y(0), z(0){}
	float x;
	float y;
	float z;
	// Exact component-wise equality (no tolerance is applied).
	bool operator == (const Point &p)const
	{
		return (x == p.x && y == p.y && z == p.z );
	}
	// Returns a coordinate by axis index: 0 -> x, 1 -> y, 2 -> z.
	// Any other index (including negative ones) is reported on std::cout and
	// yields 0. Declared const so it also works on const points/references.
	float operator [](int i) const
	{
		switch (i)
		{
		case 0:
			return x;
		case 1:
			return y;
		case 2:
			return z;
		default:
			std::cout << "Point[i]: index out of range" << std::endl;
			return 0;
		}
	}
};
// Stream-insertion operator overloads.
//
// Writes every point of P as one "x y z" line to the given stream `os`
// (the stream that was passed in, not unconditionally std::cout), flushes
// the stream once at the end and returns it so calls can be chained.
std::ostream& operator <<(std::ostream &os, const std::vector<Point> &P)
{
	for (const auto &pt : P)
		os << pt.x << " " << pt.y << " " << pt.z << '\n';
	os.flush();
	return os;
}
// Writes a single point as "x y z" (no trailing newline) to the given stream
// `os`, flushes it and returns it so calls can be chained.
std::ostream& operator <<(std::ostream &os, const Point &P)
{
	os << P.x << " " << P.y << " " << P.z;
	os.flush();
	return os;
}

// One node of a k-d tree over Point. A node stores one point plus links to
// its parent and to its two children; a whole tree is referred to through
// its top node.
//
// NOTE: this class declares no destructor, so child nodes that were
// allocated with `new` are not released by it.
class KdTree {
public:
	Point root;          // the point stored at this node
	KdTree* parent;      // null for the top node of the tree
	KdTree* leftChild;   // null when there is no left subtree
	KdTree* rightChild;  // null when there is no right subtree
	int ClusterIdx;      // cluster label of this node; -1 means "not assigned"

	int attribute;       // index of the split axis (0 = x, 1 = y, 2 = z); -1 until set

	// Default constructor: a detached node with no cluster label and no
	// split axis. Every member is initialised so no field is read
	// uninitialised later on.
	KdTree()
		: parent(nullptr), leftChild(nullptr), rightChild(nullptr),
		  ClusterIdx(-1), attribute(-1)
	{
	}

	// True when this node has no children.
	bool isLeaf() const
	{
		return (rightChild == nullptr && leftChild == nullptr);
	}
	// True when this node has no parent, i.e. it is the top of the tree.
	bool isRoot() const
	{
		return (parent == nullptr);
	}
	// True when this node is the left child of its parent. Compared by node
	// identity, so it is safe for a parent-less node, for a missing sibling
	// and for nodes that happen to store equal points.
	bool isLeft() const
	{
		return parent != nullptr && parent->leftChild == this;
	}
	// True when this node is the right child of its parent (same identity
	// based rules as isLeft()).
	bool isRight() const
	{
		return parent != nullptr && parent->rightChild == this;
	}
};

// Converts a list of points into three coordinate columns.
//
// Returns a vector with exactly three entries: [0] all x values, [1] all y
// values, [2] all z values, each in the same order as the input points.
// An empty input yields three empty columns.
std::vector<std::vector<float> > Transpose(std::vector<Point> Matrix)
{
	std::vector<std::vector<float> > Trans(3);
	// Size each column once up front instead of growing it repeatedly.
	for (auto &column : Trans)
		column.reserve(Matrix.size());

	for (const auto &p : Matrix)
	{
		Trans[0].push_back(p.x);
		Trans[1].push_back(p.y);
		Trans[2].push_back(p.z);
	}
	return Trans;
}

// Returns the element that would sit at index size()/2 if `vec` were sorted
// in ascending order (the upper median for even sizes).
// Returns 0 for an empty input instead of reading out of bounds.
float findMiddleValue(std::vector<float> vec)
{
	if (vec.empty())
		return 0.0f;

	const auto pos = vec.size() / 2;
	// Only the element at `pos` is needed, so a selection is enough; a full
	// sort would do more work for the same answer.
	std::nth_element(vec.begin(), vec.begin() + pos, vec.end());
	return vec[pos];
}

// Returns the sample variance (divisor n - 1) of the values in A.
// Inputs with fewer than two values have no sample variance; 0 is returned
// for them instead of dividing by zero.
float computeVar(std::vector<float> A)
{
	const auto count = A.size();
	if (count < 2)
		return 0.0f;

	// The initial value fixes the accumulator type: an integer 0 here would
	// truncate the running sum to int after every addition.
	const double sum = std::accumulate(A.begin(), A.end(), 0.0);
	const double mean = sum / static_cast<double>(count);

	double accum = 0.0;
	for (const float value : A)
		accum += (value - mean) * (value - mean);

	return static_cast<float>(accum / static_cast<double>(count - 1));
}

// Picks the split axis: the index of the column in `transData` with the
// largest variance (as computed by computeVar). Ties resolve to the lowest
// index; an empty input yields 0.
int findSplitAttribute(std::vector<std::vector<float> > transData)
{
	// Variances are kept as float: storing them as integers would make all
	// variances below 1 compare equal.
	std::vector<float> vars;
	vars.reserve(transData.size());
	for (const auto &column : transData)
		vars.push_back(computeVar(column));

	const auto biggest = std::max_element(vars.begin(), vars.end());
	return static_cast<int>(std::distance(vars.begin(), biggest));
}
// Recursively builds a k-d tree below `tree` from the points in `data`.
//
// tree  : node to fill; child nodes are allocated with `new` as needed.
// data  : points to distribute below (and including) this node.
// depth : recursion depth; only forwarded, the split axis is chosen by
//         variance rather than by depth.
//
// At every level the axis with the largest variance is chosen and the median
// value on that axis becomes the split value. One point carrying that value
// is stored in the node; points with a smaller value go left, points with a
// larger value go right. Further points that share the split value are NOT
// discarded: they are handed to the left subtree until it holds size()/2
// points and to the right subtree afterwards, so no input point is lost and
// both halves stay balanced. The left subtree therefore holds values <= the
// split value and the right subtree values >= the split value.
void buildKdTree(KdTree* tree, std::vector<Point> data, unsigned depth)
{
	typedef std::vector<Point>::size_type Index;

	// Termination: nothing to store, or exactly one point (a leaf).
	if (tree == NULL || data.empty())
		return;
	if (data.size() == 1)
	{
		tree->root = data[0];
		return;
	}

	// Per-axis coordinate columns, used to pick the axis and the split value.
	const std::vector<std::vector<float> > transData = Transpose(data);
	const unsigned splitAttribute = findSplitAttribute(transData);
	const std::vector<float> &splitAttributeValues = transData[splitAttribute];
	const float splitValue = findMiddleValue(splitAttributeValues);

	// Number of points strictly below the split value; whatever is missing
	// to reach size()/2 on the left is filled with points equal to it.
	const Index half = data.size() / 2;
	Index belowCount = 0;
	for (const float value : splitAttributeValues)
		if (value < splitValue)
			++belowCount;
	Index tiesForLeft = (belowCount < half) ? (half - belowCount) : 0;

	std::vector<Point> subset1;
	std::vector<Point> subset2;
	subset1.reserve(half);
	subset2.reserve(data.size() - half);

	bool rootChosen = false;
	for (Index i = 0; i < data.size(); ++i)
	{
		const float value = splitAttributeValues[i];
		if (value < splitValue)
		{
			subset1.push_back(data[i]);
		}
		else if (value == splitValue)
		{
			if (!rootChosen)
			{
				// The first point on the split plane becomes this node.
				tree->root = data[i];
				rootChosen = true;
			}
			else if (tiesForLeft > 0)
			{
				subset1.push_back(data[i]);
				--tiesForLeft;
			}
			else
			{
				subset2.push_back(data[i]);
			}
		}
		else
		{
			subset2.push_back(data[i]);
		}
	}

	// Safety net: if no value compared equal to the split value (possible
	// only with non-finite coordinates), still consume one point here so
	// that every recursion step works on strictly fewer points.
	if (!rootChosen)
	{
		std::vector<Point> &donor = subset2.empty() ? subset1 : subset2;
		tree->root = donor.back();
		donor.pop_back();
	}

	tree->attribute = splitAttribute;
	if (!subset1.empty())
	{
		tree->leftChild = new KdTree;
		tree->leftChild->parent = tree;
		// The subset is not needed afterwards, so hand it over without a copy.
		buildKdTree(tree->leftChild, std::move(subset1), depth + 1);
	}
	if (!subset2.empty())
	{
		tree->rightChild = new KdTree;
		tree->rightChild->parent = tree;
		buildKdTree(tree->rightChild, std::move(subset2), depth + 1);
	}
}

// Prints the subtree below `tree` to standard output, one node per line,
// indented with one tab per level. `depth` is the level of `tree` itself.
// Children are prefixed with " left:" / "right:" on a line of their own.
void printKdTree(KdTree *tree, unsigned depth)
{
	// This node: indentation followed by its point.
	std::cout << std::string(depth, '\t');
	std::cout << tree->root << std::endl;

	// A leaf has nothing more to print.
	if (tree->isLeaf())
		return;

	const std::string childIndent(depth + 1, '\t');

	KdTree *lower = tree->leftChild;
	if (lower != NULL)
	{
		std::cout << childIndent << " left:";
		printKdTree(lower, depth + 1);
	}
	std::cout << std::endl;

	KdTree *upper = tree->rightChild;
	if (upper != NULL)
	{
		std::cout << childIndent << "right:";
		printKdTree(upper, depth + 1);
	}
	std::cout << std::endl;
}

// Appends the point and the cluster label of every node below (and
// including) `tree` to `out` and `idx`, in pre-order: node first, then its
// left subtree, then its right subtree. Entry i of `idx` belongs to entry i
// of `out`.
void getKdTreeValue(KdTree *tree, std::vector<Point> &out, std::vector<int> &idx)
{
	// Explicit stack instead of recursion; the visiting order is the same.
	std::vector<KdTree*> pending(1, tree);
	while (!pending.empty())
	{
		KdTree *node = pending.back();
		pending.pop_back();

		idx.push_back(node->ClusterIdx);
		out.push_back(node->root);

		// Push right before left so that the left subtree is popped first.
		if (node->rightChild != NULL)
			pending.push_back(node->rightChild);
		if (node->leftChild != NULL)
			pending.push_back(node->leftChild);
	}
}
// Looks for a node whose ClusterIdx is still -1 (no cluster assigned).
//
// tree    : top of the subtree to search; may be NULL.
// newTree : set to the first unassigned node in pre-order (node, left
//           subtree, right subtree); left untouched when none is found.
// Returns true when such a node was found.
//
// The search stops at the first hit, so the remainder of the tree is not
// traversed and a later hit cannot overwrite the result.
bool searchUndefinedIdx(KdTree *tree, KdTree* & newTree)
{
	if (tree == NULL)
		return false;
	if (tree->ClusterIdx == -1)
	{
		newTree = tree;
		return true;
	}
	// Logical OR short-circuits: the right subtree is only visited when the
	// left one contained no unassigned node.
	return searchUndefinedIdx(tree->leftChild, newTree)
		|| searchUndefinedIdx(tree->rightChild, newTree);
}
// Distance between two points under the selected metric.
//
// method 0 : Euclidean distance.
// method 1 : Manhattan distance.
// method 2 : "horizontal" point-cloud distance: Euclidean with the squared
//            z difference weighted by 0.1.
// Any other method prints "Invalid method!!" to std::cerr and returns -1.
//
// Squares are formed by plain multiplication in double precision and the
// math functions are called through std::, so the result does not depend on
// a using-directive being in effect (an unqualified abs could otherwise
// resolve to the integer overload and truncate).
float measureDistance(Point point1, Point point2, unsigned method)
{
	const float dx = point1.x - point2.x;
	const float dy = point1.y - point2.y;
	const float dz = point1.z - point2.z;

	const double sqX = static_cast<double>(dx) * dx;
	const double sqY = static_cast<double>(dy) * dy;
	const double sqZ = static_cast<double>(dz) * dz;

	switch (method)
	{
	case 0: // Euclidean distance
	{
		const float res = static_cast<float>(sqX + sqY + sqZ);
		return std::sqrt(res);
	}
	case 1: // Manhattan distance
	{
		return std::fabs(dx) + std::fabs(dy) + std::fabs(dz);
	}
	case 2: // horizontal point-cloud distance (z contributes with weight 0.1)
	{
		const float res = static_cast<float>(sqX + sqY + 0.1 * sqZ);
		return std::sqrt(res);
	}
	default:
	{
		std::cerr << "Invalid method!!" << std::endl;
		return -1;
	}
	}
}

// Collects every point stored below (and including) `tree` whose distance to
// `goal` is at most `radius`, using metric 2 of measureDistance (Euclidean
// with the squared z difference weighted by 0.1). Matches are appended to
// `out`. A NULL `tree` is accepted and yields nothing.
//
// A subtree on the far side of a node's split plane is skipped only when the
// plane itself is farther away than `radius` under that same metric. On the
// z axis the plane distance is therefore scaled by sqrt(0.1); comparing the
// raw z offset would skip subtrees that still contain neighbours.
void searchRadiusNeighbor(KdTree *tree, Point goal, float radius, std::vector<Point> &out)
{
	if (tree == NULL)
		return;

	// Test the point stored at this node.
	const Point &candidate = tree->root;
	if (measureDistance(goal, candidate, 2) <= radius)
		out.push_back(candidate);

	if (tree->isLeaf())
		return;

	// Signed offset of the goal from this node's split plane, plus the
	// factor that converts it into a distance under metric 2.
	double planeOffset = 0.0;
	double axisScale = 1.0;
	bool axisKnown = true;
	switch (tree->attribute)
	{
	case 0:
		planeOffset = goal.x - tree->root.x;
		break;
	case 1:
		planeOffset = goal.y - tree->root.y;
		break;
	case 2:
		planeOffset = goal.z - tree->root.z;
		axisScale = std::sqrt(0.1);
		break;
	default:
		// Unknown split axis: report it and search both sides rather than
		// prune on a meaningless offset. (An explicit flag is used because
		// -1 is a perfectly valid offset and cannot serve as error marker.)
		axisKnown = false;
		std::cout << "出错districtDistance == -1" << std::endl;
		break;
	}

	// The search ball reaches across the split plane: both sides may match.
	const bool crossesPlane = !axisKnown || std::fabs(planeOffset) * axisScale <= radius;
	const bool visitLeft = crossesPlane || planeOffset < 0;
	const bool visitRight = crossesPlane || planeOffset > 0;

	// Each child is checked for NULL on its own before descending into it.
	if (visitLeft && tree->leftChild != NULL)
		searchRadiusNeighbor(tree->leftChild, goal, radius, out);
	if (visitRight && tree->rightChild != NULL)
		searchRadiusNeighbor(tree->rightChild, goal, radius, out);
}

// Radius search that labels as it goes. Every node below (and including)
// `tree` that has no cluster yet (ClusterIdx == -1) and lies within `radius`
// of target_tree's point (metric 0 of measureDistance) is appended to `out`
// and receives target_tree's ClusterIdx. Subtrees on the far side of a split
// plane are skipped when the plane is farther away than `radius`.
void searchRadiusNeighborConditional(KdTree *tree, KdTree *target_tree, float radius, std::vector<KdTree*> &out)
{
	Point nodePoint = tree->root;
	Point goal = target_tree->root;

	// Only nodes without a cluster label are candidates.
	const bool unlabeled = (tree->ClusterIdx == -1);
	if (unlabeled && measureDistance(goal, nodePoint, 0) <= radius)
	{
		out.push_back(tree);
		tree->ClusterIdx = target_tree->ClusterIdx;
	}

	// Nothing below a leaf.
	if (tree->isLeaf())
		return;

	// Signed offset of the goal from this node's split plane.
	const unsigned axis = tree->attribute;
	const float offset = goal[axis] - tree->root[axis];

	KdTree *lower = tree->leftChild;
	KdTree *upper = tree->rightChild;

	if (std::abs(offset) <= radius)
	{
		// The search ball reaches across the plane: visit both sides.
		if (lower != NULL)
			searchRadiusNeighborConditional(lower, target_tree, radius, out);
		if (upper != NULL)
			searchRadiusNeighborConditional(upper, target_tree, radius, out);
		return;
	}

	// Otherwise only the side that contains the goal can hold matches.
	if (offset < 0)
	{
		if (lower != NULL)
			searchRadiusNeighborConditional(lower, target_tree, radius, out);
	}
	else if (offset > 0 && upper != NULL)
	{
		searchRadiusNeighborConditional(upper, target_tree, radius, out);
	}
}



int main()
{
	//读取 .bin 文件 (这里可以写一个接口)
	std::FILE *pFile = fopen("000000.bin", "rb");
	fseek(pFile, 0, SEEK_END);    // file pointer goes to the end of the file  
	long fileSize = ftell(pFile); // file size  
	rewind(pFile);                // rewind file pointer to the beginning 
	float *rawData = new float[fileSize];
	fread(rawData, sizeof(float), fileSize / sizeof(float), pFile);
	long number_of_points = fileSize / 4 / sizeof(float);//the number of  points
	std::vector<Point> pcd;

	for (size_t i = 0; i < number_of_points; i++)
	{
		Point temp;
		temp.x = *(rawData+i*4);
		temp.y = *(rawData+i*4+1);
		temp.z = *(rawData+i*4+2);
		pcd.push_back(temp);
	}
	

	KdTree my_kdtree;
	unsigned dee=0;


	clock_t start, end;
	start = clock();
	buildKdTree(&my_kdtree, pcd, dee);
	end = clock();		//程序结束用时
	double endtime = (double)(end - start) / CLOCKS_PER_SEC;
	cout << "data size" << pcd.size() << endl;
	cout << "Total time(build tree):" << endtime * 1000 << "ms" << endl;	//ms为单位




	return 0;
}

 

 以上是 kd-tree 的 C++ 实现：对约 12 万个点建树大约需要 500 ms，速度仍然偏慢，后续再考虑如何优化（有同学的实现只需约 50 ms）。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值