题目描述
给定一个数组nums,返回一个计数数组count,count[i]表示nums中第i个右边有多少个数小于nums[i]
Example:
nums = [5, 2, 6, 1]
输出[2,1,1,0]
分析解答
此题不难给出O(N^2)的算法,先穷举nums中每个位置i,再穷举右边的数计算有多少个小于nums[i]。难点在于利用数据结构进行优化从而降低时间复杂度。线段树(segment tree)和平衡树(Balanced Binary Tree)是两种可以使用的数据结构。
线段树的每个节点表示一段区间,记录这个区间的某些信息,其基本思想是把区间一分为二,二分为四。。。直到不可再分(因此叶子节点的区间只包含一个数),如此可以把任意区间表示成log(区间大小)个子区间的拼接,以降低查询时间复杂度。在本题中,假设nums中的数字范围在0到maxnum之间,那么建树的区间为[0,maxnum](也就是根节点所表示的区间)。每个节点记录其表示区间内的数字个数。本题涉及两种线段树基本操作:插入和查询。插入操作把nums[i]插入到线段树相应位置,同时对所有经过的区间的sum值进行累加;查询操作需要查询区间[0,nums[i]-1]所包含的数字个数,利用已经建好的线段树把查询区间分割为若干个节点所表示的区间,统计并返回这些节点的sum值之和。
平衡树用途更广,代码复杂度也更高,是一种保持叶子节点深度平衡的二叉搜索树,有多种方法实现,可以参照LeetCode。
参考程序
1.线段树 Segment Tree
class Solution {
public:
struct SegmentTreeNode {
SegmentTreeNode* left, *right;
int start, end;
int count;
SegmentTreeNode(int start, int end, int count)
: start(start), end(end), count(count) {
left = right = NULL;
}
};
SegmentTreeNode* build(int start, int end) {
if (start > end) return NULL;
SegmentTreeNode* root = new SegmentTreeNode(start, end, 0);
if (start != end) {
int mid = start + (end - start) / 2;
root->left = build(start, mid);
root->right = build(mid+1, end);
}
return root;
}
int querySegmentTree(SegmentTreeNode* root, int start, int end) {
if (root->start == start && root->end == end)
return root->count;
int leftcount = 0, rightcount = 0;
int mid = root->start + (root->end - root->start) / 2;
// left half part
if (start <= mid) {
if (mid < end) {
leftcount = querySegmentTree(root->left, start, mid);
} else {
leftcount = querySegmentTree(root->left, start, end);
}
}
// right half part
if (mid < end) {
if (start <= mid) {
rightcount = querySegmentTree(root->right, mid+1, end);
} else {
rightcount = querySegmentTree(root->right, start, end);
}
}
return leftcount + rightcount;
}
void modifySegmentTree(SegmentTreeNode* root, int index, int value) {
if (root->start == index && root->end == index) {
root->count += value;
return ;
}
int mid = root->start + (root->end - root->start) / 2;
if (root->start <= index && index <= mid) {
modifySegmentTree(root->left, index, value);
}
if (mid < index && index <= root->end) {
modifySegmentTree(root->right, index, value);
}
root->count = root->left->count + root->right->count;
}
vector<int> countSmaller(vector<int>& nums) {
vector<int> ret;
SegmentTreeNode* root = build(-1000, 10000);
for (int i=nums.size()-1; i>=0; i--) {
int ans = querySegmentTree(root, -1000, nums[i]-1);
modifySegmentTree(root, nums[i], 1);
ret.push_back(ans);
}
reverse(ret.begin(), ret.end());
return ret;
}
};
注意事项
之前写的线段树居然不能有负区间,debug了好长时间,发现原来是区间中点计算有问题,应该类似于int c = a + (b - a) / 2
2.二叉搜索树 Binary Search Tree
每个节点保存sum(左子树节点个数),dup(副本个数)。当插入一个数时,比它小的数的个数就是沿着树向右转时dup和sum之和,详情请见LeetCode
class Solution {
public:
struct Node {
Node* left, *right;
int val, sum, dup = 1;
Node(int v, int s) : val(v), sum(s) {
left = right = NULL;
}
};
vector<int> countSmaller(vector<int>& nums) {
vector<int> ret(nums.size());
Node* root = NULL;
for (int i=nums.size()-1; i>=0; i--) {
insert(nums[i], root, ret[i], 0);
}
return ret;
}
void insert(int num, Node*& root, int& ret, int preSum) {
if (root == NULL) {
root = new Node(num, 0);
ret = preSum;
} else if (root->val > num) {
root->sum++;
insert(num, root->left, ret, preSum);
} else if (root->val < num) {
insert(num, root->right, ret, preSum + root->dup + root->sum);
} else {
root->dup++;
ret = preSum + root->sum;
}
}
};
迭代版
class Solution {
public:
struct Node {
Node* left, *right;
int val, sum = 0, dup = 0;
Node(int v) : val(v) {
left = right = NULL;
}
};
vector<int> countSmaller(vector<int>& nums) {
vector<int> ret(nums.size());
if (nums.size() == 0) return vector<int>();
Node* root = new Node(nums[nums.size()-1]);
for (int i=nums.size()-1; i>=0; i--) {
ret[i] = insert(root, nums[i]);
}
return ret;
}
int insert(Node* root, int num) {
int ret = 0;
while (root->val != num) {
if (root->val > num) {
root->sum++;
if (root->left == NULL)
root->left = new Node(num);
root = root->left;
} else {
ret += root->dup + root->sum;
if (root->right == NULL)
root->right = new Node(num);
root = root->right;
}
}
root->dup++;
return ret + root->sum;
}
};