http://poj.org/problem?id=3321
一道搁置了好多天的树状数组
题意:有一颗苹果树,主干是1,下面会有分支,每个分支一个编号一直到n。每个分支只能有一个苹果或没有苹果。初始状态是每个分支一个苹果。示例:
5
1 2
1 3
3 5
3 4
它形成的树是这样的:
现在每个分支的苹果都是1。
下面是两种操作,Q 和C
C j 的意思是如果 j 这个枝子上面有苹果就摘下来,如果没有,那么就会长出新的一个
Q j 就是问 j 这个叉及其下面的苹果总数。如 Q 3,那么答案是3,因为3及其下面的分支共有三个苹果。
思路:更新某个节点的值,询问区间的和。这是一道树状数组的题目,但是树状数组对应的是一维数组。那么应该先把这棵树转化为一维数组。用DFS遍历数的同时记录每个节点的起始和结束的编号。相当于时间戳。比如上图遍历的顺序是 1 2 3 5 4,那么其对应的编号是 1 2 3 4 5。用start[]和end[]记录每个节点的起始时间和结束时间。
start[1] = 1,end[1] = 5(代表1上的树枝是1~5),同理start[2] = 2,end[2] = 2,start[3] = 3,end[3] = 5,start[5] = 4,end[5] = 4,start[4] = 5,end[4] = 5.这就转化为一维数组了。
对于 C j:先判断j分支上是否有苹果,就是利用线段树,只需计算sum( start[j] ) - sum( start[j] -1)是否等于1,是1说明有苹果,更新该点,即减去一个苹果,否则就加上一个苹果,属于单点更新问题,注意是对start[j]更新。
对于 Q j:询问j和j的子树上的苹果数,就是区间和.即 sum( end[j] ) - sum( start[j] - 1)。
最后不能用vector,超时。手写结构体数组。
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <vector>
using namespace std;
const int maxn = 100010;
struct node
{
int data;
struct node *next;
};
struct node edge[maxn];
int start[maxn],end[maxn];
int n,m,dep;
int c[maxn];
int Lowbit(int x)
{
return x&(-x);
}
void dfs(int u)//dfs找每个节点的时间戳,起始时间和结束时间
{
start[u] = ++dep;
struct node * tmp = edge[u].next;
while(tmp)
{
if(start[tmp->data] == 0)
dfs(tmp->data);
tmp = tmp->next;
}
end[u] = dep;
}
int sum(int end)
{
int s = 0;
while(end > 0)
{
s += c[end];
end -= Lowbit(end);
}
return s;
}
void update(int pos, int num)
{
while(pos <= n)
{
c[pos] += num;
pos += Lowbit(pos);
}
}
int main()
{
int u,v;
scanf("%d",&n);
for(int i = 0; i < n-1; i++)
{
scanf("%d %d",&u,&v);
struct node *P = new struct node;
P->data = v;
P->next = edge[u].next;
edge[u].next = P;
struct node *Q = new struct node;
Q->data = u;
Q->next = edge[v].next;
edge[v].next = Q;
}
memset(c,0,sizeof(c));
memset(start,0,sizeof(start));
memset(end,0,sizeof(end));
dep = 0;
dfs(1);
for(int i = 1; i <= n; i++)
{
update(i,1);
}
scanf("%d",&m);
char str[2];
int x,res1,res2,res3;
while(m--)
{
scanf("%s %d",str,&x);
res1 = sum(start[x]);
res2 = sum(start[x]-1);
res3 = sum(end[x]);
if(str[0] == 'C')
{
if(res1 - res2 == 1)
update(start[x],-1);
else update(start[x],1);
}
else
{
printf("%d\n",res3-res2);
}
}
return 0;
}