题目:http://acm.hdu.edu.cn/showproblem.php?pid=3887
大意:求在当前点的所有后代中,后代点的序号大小<当前点的序号大小,并统计他们的个数,这就是f[i]。
思路:dfs序+树状数组/线段树
dfs序:https://www.cnblogs.com/stxy-ferryman/p/7741970.html
dfs序:dfs是深度优先的,所以对于一个点,它会先遍历完它的所有子节点,再去遍历他的兄弟节点以及其他所以对于一棵树的dfs序来说,这个点和他所有的子节点会被存储在连续的区间之中。
怎样统计后代序号比它小的个数之和?
首先经dfs遍历之后每个点以及他的子节点都被保存在一个连续的区间内,原来的序号变成了该连续区间的下标,找比它序号小的就直接往前面找并且前面点的区间必须包含在该区间内(dfs后一个点的子树所在的区间一定是在该点的区间内)
已根节点7为例:dfs后in[7]=1,out[7]=15,节点为7的区间包含所有的点,所以直接在前面找,7钱面有6个点,所以f[7]=6;
节点10:in[10]=2,out[10]=5,子区间有三个,in[14]=3,out[14]=5; in[2]=4,out[2]=4;
in[13]=5,out[13]=5;但是10前面只有节点2的区间在节点10内。所以f[10]=1;
上代码:
树状数组:
#pragma comment(linker,"/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#define max(a,b) a>b? a:b
#define len 100005
using namespace std;
int head[len],n,root,k,cnt,in[len],out[len];
struct node
{
int j,next;
}s[len*2];
struct text
{
int l,r,sum;
}ans[len];
int c[len];
void add(int a,int b)
{
s[k].j=b;
s[k].next=head[a];
head[a]=k++;
}
void dfs(int x,int pre)
{
in[x]=++cnt;
for(int i=head[x];i!=-1;i=s[i].next)
{
int j=s[i].j;
if(j!=pre) dfs(j,x);
}
out[x]=cnt;
}
int lowbit(int x)
{
return x&(-x);
}
void update(int x,int p)
{
for(int i=x;i<=n;i+=lowbit(i)) c[i]+=p;
return ;
}
int sum(int x)
{
int sum=0;
for(int i=x;i>0;i-=lowbit(i)) sum+=c[i];
return sum;
}
int main()
{
while(scanf("%d%d",&n,&root)!=EOF&&n&&root)
{
cnt=0;
k=0;
int a,b;
memset(head,-1,sizeof(head));
memset(c,0,sizeof(c));
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
}
dfs(root,-1);
for(int i=1;i<=n;i++)
{
printf("%d%s",sum(out[i])-sum(in[i]-1),i==n? "\n":" ");
update(in[i],1);
}
}
return 0;
}
线段树:
dfs我用的结构体结果爆栈了,借用vector
大神原代码:https://www.cnblogs.com/sosi/p/3722458.html
#pragma comment(linker,"/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#include<vector>
#define max(a,b) a>b? a:b
#define len 200005
using namespace std;
int n,root,k,cnt,in[len],out[len];
bool vis[len];
vector<vector<int> > t(len);
struct text
{
int l,r,sum;
}ans[len<<2];
void dfs(int x)
{
vis[x]=1;
in[x]=++cnt;
for(int i=0;i<t[x].size();i++)
{
if(!vis[t[x][i]]) dfs(t[x][i]);
}
out[x]=cnt;
}
void pushup(int i)
{
ans[i].sum=ans[i*2].sum+ans[i*2+1].sum;
}
void build(int l,int r,int i)
{
ans[i].l=l;
ans[i].r=r;
ans[i].sum=0;
if(ans[i].l==ans[i].r) return ;
int mid=(ans[i].l+ans[i].r)/2;
build(l,mid,i*2);
build(mid+1,r,i*2+1);
}
void update(int l,int r,int i)
{
if(ans[i].l==l&&ans[i].r==r)
{
ans[i].sum=1;
return ;
}
int mid=(ans[i].l+ans[i].r)/2;
if(mid>=r) update(l,r,i*2);
else if(mid<l) update(l,r,i*2+1);
else
{
update(l,mid,i*2);
update(mid+1,r,i*2+1);
}
pushup(i);
}
int query(int l,int r,int i)
{
if(ans[i].l==l&&ans[i].r==r) return ans[i].sum;
int mid=(ans[i].l+ans[i].r)/2;
if(mid>=r) query(l,r,i*2);
else if(mid<l) query(l,r,i*2+1);
else return query(l,mid,i*2)+query(mid+1,r,i*2+1);
}
int main()
{
while(scanf("%d%d",&n,&root)!=EOF&&n&&root)
{
cnt=0;
k=0;
int a,b;
memset(vis,0,sizeof(vis));
for(int i=1;i<=n;i++) t[i].clear();
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&a,&b);
t[a].push_back(b);
t[b].push_back(a);
}
dfs(root);
build(1,n,1);
for(int i=1;i<=n;i++)
{
printf("%d%s",query(in[i],out[i],1),i==n? "\n":" ");
update(in[i],in[i],1);
}
}
return 0;
}