这题是NOIP2016 Day1 T2.。。从昨天晚上搞到今天上午。。我现在很慌。。看题解看了半天看不懂,最后还是wcx daolao讲懂的。
看到树上的路径,很容易想到拆成两条路径,即起点到LCA和LCA到终点。
对于起点S到LCA的,要让位于i点的观察员看到,则需满足deep[i]+w[i]=deep[s],对于每一个观察员来说,deep[i]+w[i]为定值,所以只需在i的子树中找到满足的点即可,考虑到对i的子树进行操作,我们想到dfs序,in[i]到out[i]这段区间即为i的子树。我们可以对每个深度开一棵线段树,对于每个观察员i,只需在deep[i]+w[i]的线段树中统计玩家数量。那么到底如何统计呢。。我们可以用树上差分。对于每一条路径,将起点权值加1,将LCA的父亲的权值减1,这样的前缀和即为在这棵深度线段树上能被观察到的人数。。
对于LCA到终点的路径,只需满足deep[s]-2*deep[lca(s,t)]=w[i]-deep[i],所以只需在深度为w[i]-deep[i]的线段树中寻找。。
注意开线段树时要动态开点,不然果断MLE。而且每次求完上升路径时,要清空线段树
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cstring>
using namespace std;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<3)+(x<<1)+ch-'0';
ch=getchar();
}
return x*f;
}
struct bian
{
int qi,zhong,next;
};
bian c[600010];
bool vis[300010];
int n,m,x,y,jishu=0,jishu1=0,jishu2=0;
int in[300010],ans[300010],out[300010],zu[300010],xuan[300010],ying[300010],sum[7000000],lc[7000000],rc[7000000],root[7000000];
int w[300010],head[300010],deep[300010],size[300010],fa[300010],p[300010][20];
void add(int x,int y)
{
c[++jishu].qi=x;
c[jishu].zhong=y;
c[jishu].next=head[x];
head[x]=jishu;
}
void init()
{
for(int j=0;(1<<j)<=n;++j)
for(int i=1;i<=n;++i)
p[i][j]=-1;
for(int i=1;i<=n;++i)
p[i][0]=fa[i];
for(int j=1;(1<<j)<=n;++j)
for(int i=1;i<=n;++i)
if(p[i][j-1]!=-1)
p[i][j]=p[p[i][j-1]][j-1];
}
void dfs(int u,int f,int t)
{
deep[u]=t;
fa[u]=f;
for(int i=head[u];i;i=c[i].next)
if(c[i].zhong!=f)
dfs(c[i].zhong,u,t+1);
}
int lca(int a,int b)
{
int i,j;
if(deep[a]<deep[b])
swap(a,b);
for(i=0;(1<<i)<=deep[a];++i);
i--;
for(j=i;j>=0;--j)
if(deep[a]-(1<<j)>=deep[b])
a=p[a][j];
if(a==b) return a;
for(j=i;j>=0;--j)
if(p[a][j]!=-1&&p[a][j]!=p[b][j])
{
a=p[a][j];
b=p[b][j];
}
return fa[a];
}
void dfs1(int x)
{
in[x]=++jishu1;
vis[x]=1;
for(int i=head[x];i;i=c[i].next)
if(!vis[c[i].zhong])
dfs1(c[i].zhong);
out[x]=jishu1;
}
void clear()
{
jishu2=0;
memset(lc,0,sizeof(lc));
memset(rc,0,sizeof(rc));
memset(sum,0,sizeof(sum));
memset(root,0,sizeof(root));
}
void update(int l,int r,int i,int w,int &now)
{
if(!i) return ;
if(!now) now=++jishu2;
sum[now]+=w;
if(l==r) return ;
int mid=(l+r)>>1;
if(i<=mid)
update(l,mid,i,w,lc[now]);
else
update(mid+1,r,i,w,rc[now]);
}
int query(int l,int r,int L,int R,int i)
{
if(!i) return 0;
if(L<=l&&r<=R) return sum[i];
int mid=(l+r)>>1;
if(R<=mid) return query(l,mid,L,R,lc[i]);
else if(L>mid) return query(mid+1,r,L,R,rc[i]);
else return query(l,mid,L,mid,lc[i])+query(mid+1,r,mid+1,R,rc[i]);
}
int main()
{
n=read();m=read();
for(int i=1;i<n;++i)
{
x=read();y=read();
add(x,y);add(y,x);
}
dfs(1,-1,0);
dfs1(1);
init();
for(int i=1;i<=n;++i)
w[i]=read();
for(int i=1;i<=m;++i)
{
xuan[i]=read();ying[i]=read();
zu[i]=lca(xuan[i],ying[i]);
}
for(int i=1;i<=m;++i)
{
int shen=deep[xuan[i]];
update(1,n,in[xuan[i]],1,root[shen]);
update(1,n,in[p[zu[i]][0]],-1,root[shen]);
}
for(int i=1;i<=n;++i)
ans[i]+=query(1,n,in[i],out[i],root[deep[i]+w[i]]);
clear();
for(int i=1;i<=m;++i)
{
int shen=deep[xuan[i]]-2*deep[zu[i]]+2*n;
update(1,n,in[ying[i]],1,root[shen]);
update(1,n,in[zu[i]],-1,root[shen]);
}
for(int i=1;i<=n;++i)
ans[i]+=query(1,n,in[i],out[i],root[w[i]-deep[i]+2*n]);
for(int i=1;i<=n;++i)
{
if(i!=n)
printf("%d ",ans[i]);
else
printf("%d",ans[i]);
}
return 0;
}