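// 网上看了很多人的答案，都说最优解是 O(n)。这里想到了一种更快的算法，复杂度是 O((log n)^2)：
// 对左右子树进行二分，找到最后一层最右边的那个结点即可。
// (English: most answers online say O(n) is optimal for counting the nodes of a
// complete binary tree. This faster approach runs in O((log n)^2): binary-search
// between the left and right subtrees to find the rightmost node on the last level.)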
#include <iostream>
#include <cmath>
#include <stdlib.h>
using namespace std;
// A binary-tree node that owns its children: destroying a node recursively
// destroys its whole subtree.
struct Node {
  Node* left;   // owned left child, NULL when absent
  Node* right;  // owned right child, NULL when absent

  // Always start with no children, so that `new Node` (default-initialization)
  // is as safe as `new Node()`; otherwise the destructor would delete garbage.
  Node() : left(NULL), right(NULL) {}

  ~Node() {
    // delete on a NULL pointer is a no-op, so no checks are needed.
    delete left;
    delete right;
  }

 private:
  // Non-copyable (declared, never defined): a copy would share the same
  // children and both destructors would delete them.
  Node(const Node&);
  Node& operator=(const Node&);
};
// Returns the number of nodes on the path that starts at `root` and always
// follows the left child; 0 for a NULL root. For a complete binary tree, whose
// bottom level fills from the left, this is the depth of the tree.
int GetDepth(Node* root) {
  int levels = 0;
  for (Node* cur = root; cur != NULL; cur = cur->left) {
    ++levels;
  }
  return levels;
}
// Walks down from `root`, taking the right child when there is one and
// otherwise the left child, until reaching a node with no children.
// For each step, *index becomes the 1-based position of the new node within
// its level (children of position i are 2i-1 and 2i), then *depth is
// incremented. On entry *depth/*index describe `root`; a NULL root leaves
// them untouched.
void GetMostRightDepthAndIndex(Node* root, int* depth, int* index) {
  for (Node* cur = root; cur != NULL;) {
    Node* next = NULL;
    int offset = 0;  // +2 for a right child, +1 for a left child
    if (cur->right) {
      next = cur->right;
      offset = 2;
    } else if (cur->left) {
      next = cur->left;
      offset = 1;
    } else {
      break;  // childless node: the walk ends here
    }
    *index = (*index - 1) * 2 + offset;
    ++*depth;
    cur = next;
  }
}
// Walks down from `root`, taking the left child when there is one and
// otherwise the right child, until reaching a node with no children.
// For each step, *index becomes the 1-based position of the new node within
// its level (children of position i are 2i-1 and 2i), then *depth is
// incremented. On entry *depth/*index describe `root`; a NULL root leaves
// them untouched.
void GetMostLeftDepthAndIndex(Node* root, int* depth, int* index) {
  while (root) {
    if (root->left) {
      root = root->left;
      *index = (*index - 1) * 2 + 1;  // position of a left child
    } else if (root->right) {
      // Fall back to the right child only when there is no left child.
      root = root->right;
      *index = (*index - 1) * 2 + 2;  // position of a right child
    } else {
      return;
    }
    *depth += 1;
  }
}
// Number of nodes on levels 1..depth-1 of a perfect binary tree, i.e. how many
// nodes precede level `depth` in level order (0 for depth <= 1).
// Requires depth <= 31 so the shift fits in an int.
static int NodesAboveLevel(int depth) {
  return depth > 1 ? (1 << (depth - 1)) - 1 : 0;
}

// Reports a tree shape that cannot occur in a complete binary tree and
// terminates the program with status 1.
static void ExitIllegalTree() {
  std::cerr << "illegal tree" << '\n';
  exit(1);
}

// Counts the nodes of the complete binary tree rooted at `root` and ADDS the
// count to *node_num (the caller should zero it first). A NULL root adds 0.
//
// Positions are 1-based: `cur_depth` is the level of `root` (the tree root is
// level 1) and `cur_index` is its position within that level. `max_depth` is
// the depth of the whole tree (see GetDepth).
//
// At each level, the right-leaning path of the left subtree and the
// left-leaning path of the right subtree are compared to decide which subtree
// holds the last node of the bottom level, and only that subtree is searched.
// This is O(log n) recursion steps of O(log n) walks each, i.e. O((log n)^2).
// When the last node is found at level d, position p, the total is
// NodesAboveLevel(d) + p.
//
// Prints "illegal tree" and exits with status 1 when it detects a shape that
// cannot occur in a complete binary tree.
void GetNodeNum(Node* root, int cur_depth, int cur_index,
                int* node_num, int max_depth) {
  if (!root) {
    return;  // empty tree: nothing to add
  }
  if (!root->left) {
    // In a complete tree a node with no left child is a leaf. The search only
    // descends toward the last node, so this is that node.
    *node_num += NodesAboveLevel(cur_depth) + cur_index;
    return;
  }
  // root->left is non-null here, so the left walk always starts one level down.
  int ml_depth = cur_depth + 1;
  int ml_index = (cur_index - 1) * 2 + 1;
  int mr_depth = root->right ? cur_depth + 1 : cur_depth;
  int mr_index = (cur_index - 1) * 2 + 2;
  GetMostRightDepthAndIndex(root->left, &ml_depth, &ml_index);
  GetMostLeftDepthAndIndex(root->right, &mr_depth, &mr_index);

  if (ml_depth == mr_depth) {
    if (ml_depth == max_depth) {
      // Both paths reach the bottom level: the left subtree is perfect and
      // the last node lies in the right subtree.
      GetNodeNum(root->right, cur_depth + 1, (cur_index - 1) * 2 + 2,
                 node_num, max_depth);
    } else if (ml_depth < max_depth) {
      // The right subtree has no bottom-level nodes, so the last node lies
      // in the left subtree.
      GetNodeNum(root->left, cur_depth + 1, (cur_index - 1) * 2 + 1,
                 node_num, max_depth);
    } else {
      ExitIllegalTree();
    }
  } else if (ml_depth > mr_depth) {
    if (ml_depth == max_depth && mr_depth == max_depth - 1) {
      // The left subtree's right-leaning path ends at the last node.
      *node_num += NodesAboveLevel(ml_depth) + ml_index;
    } else {
      // Previously this case returned silently with no count and no error.
      ExitIllegalTree();
    }
  } else {
    ExitIllegalTree();
  }
}
int main() {
/* Input: create a */
// depth 1
Node* root = new Node();
// depth 2
root->left = new Node();
root->right = new Node();
// depth 3
root->left->left = new Node();
root->left->right = new Node();
root->right->left = new Node();
root->right->right = new Node();
// depth 4
root->left->left->left = new Node();
root->left->left->right= new Node();
root->left->right->left = new Node();
root->left->right->right = new Node();
root->right->left->left = new Node();
root->right->left->right = new Node();
root->right->right->left = new Node();
root->right->right->right = new Node();
// depth 5
root->left->left->left->left = new Node();
int root_depth = 1;
int root_index = 1;
int max_depth = GetDepth(root);
int node_num = 0;
GetNodeNum(root, root_depth, root_index, &node_num, max_depth);
std::cout << "node_num = " << node_num << endl;
delete root;
}