树状数组看了我大半天,终于搞懂了,太菜鸡了。
树状数组
/*
树状数组,关键是怎么实现,用两个结构体,一个结构体edge+一个一维数组一起来存储边与节点之间的关系。
就是类似一个邻接表,进行存储。
另一个结构体apple就是储存每个节点的后序序列的区间。
*/
#include<iostream>
#include<cstring>
#include<cstdio>
#define mm(a,b) memset(a,b,sizeof(a))
using namespace std;
const int maxn=100010;
struct node1
{
int next;
int tail;
}edge[maxn];
struct node2
{
int r;
int l;
}apple[maxn];
int s[maxn],cnt,c[maxn],a[maxn];
void dfs(int u)
{
apple[u].l=cnt;
for(int i=s[u];i!=-1;i=edge[i].next)
dfs(edge[i].tail);
apple[u].r=cnt++;
}
inline int lowbit(int x)
{
return x&(-x);
}
void change(int x)
{
if(a[x])
for(int i=x;i<cnt;i+=lowbit(i))
c[i]++;
else
for(int i=x;i<cnt;i+=lowbit(i))
c[i]--;
}
int sum(int x)
{
int i,res=0;
for(int i=x;i>0;i-=lowbit(i))
res+=c[i];
return res;
}
int main()
{
char str[3];
int n,m,t1,t2,t;
scanf("%d",&n);
mm(s,-1);
mm(c,0);
memset(apple,0,sizeof(apple[0])*(n-1));
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&t1,&t2);
edge[i].tail=t2;
edge[i].next=s[t1];
s[t1]=i;
}
cnt=1;
dfs(1);
// printf("123\n");
scanf("%d",&m);
for(int i=1;i<=n;i++)
{
a[i]=1;
change(i);
}
while(m--)
{
scanf("%s%d",str,&t);
if(str[0]=='Q')
printf("%d\n",sum(apple[t].r)-sum(apple[t].l-1));
else
{
a[apple[t].r]=(a[apple[t].r]+1)%2;
change(apple[t].r);
}
}
return 0;
}
线段树
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
#define mm(a,b) memset(a,b,sizeof(a))
using namespace std;
const int maxn=100010;
vector<vector<int> > v(maxn);
int ll[maxn],rr[maxn],tree[4*maxn],cnt;
void dfs(int u)
{
ll[u]=++cnt;
for(int i=0;i<v[u].size();i++)
{
dfs(v[u][i]);
}
rr[u]=cnt;
}
void build(int p,int l,int r)
{
if(l==r)
{
tree[p]=1;
return ;
}
int mid=(l+r)>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
tree[p]=tree[p<<1]+tree[p<<1|1];
}
void update(int p,int l,int r,int x)
{
if(l==r)
{
tree[p]^=1;
return;
}
int mid=(l+r)>>1;
if(x<=mid)
update(p<<1,l,mid,x);
else
update(p<<1|1,mid+1,r,x);
tree[p]=tree[p<<1]+tree[p<<1|1];
}
int find(int p,int l,int r,int x,int y)
{
if(x<=l&&y>=r)
return tree[p];
int mid=(l+r)>>1;
if(y<=mid)
return find(p<<1,l,mid,x,y);
else if(x>mid)
return find(p<<1|1,mid+1,r,x,y);
else
return find(p<<1,l,mid,x,mid)+find(p<<1|1,mid+1,r,mid+1,y);
}
int main()
{
int n;
scanf("%d",&n);
mm(ll,0);
mm(rr,0);
cnt=0;
for(int i=0;i<=n;i++)
v[i].clear();
int t1,t2;
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&t1,&t2);
v[t1].push_back(t2);
}
dfs(1);
build(1,1,n);
int m,t;
scanf("%d",&m);
char str[3];
for(int i=0;i<m;i++)
{
scanf("%s%d",str,&t);
if(str[0]=='Q')
{
printf("%d\n",find(1,1,n,ll[t],rr[t]));
}
else
update(1,1,n,ll[t]);
}
return 0;
}