一、题目
点此看题
题意:
给出一棵树和若干个路径,问对于每个点是多少个路径的
w
i
+
1
w_i+1
wi+1个点。
二、解法
先来推一推柿子,分类讨论一下:
- i i i在 ( u , l c a ) (u,lca) (u,lca)的路径上,如果 d e p [ u ] = d e p [ i ] + w [ i ] dep[u]=dep[i]+w[i] dep[u]=dep[i]+w[i],那么这条路径对 i i i有 1 1 1的贡献。
- i i i在 ( l c a , v ) (lca,v) (lca,v)的路径上,如果 w [ i ] − d i s ( u , l c a ) = d e p [ i ] − d e p [ l c a ] w[i]-dis(u,lca)=dep[i]-dep[lca] w[i]−dis(u,lca)=dep[i]−dep[lca],也就是 d e p [ i ] − w [ i ] = 2 ∗ d e p [ l c a ] − d e p [ u ] dep[i]-w[i]=2*dep[lca]-dep[u] dep[i]−w[i]=2∗dep[lca]−dep[u],那么这条路径对 i i i有 1 1 1的贡献。
这像是树上路径问题,可以考虑树上差分。
我们对于每个点开个桶, t [ u ] [ ] t[u][\thickspace] t[u][],然后直接修改桶,最后大法师一遍统计。
但是这样空间肯定是要爆炸的,所以我们用线段树的形式表示桶,采取动态开点的方式,最后合并的时候也有所考究,如果需要合并的两个线段树为 ( x , y ) (x,y) (x,y),那么如果它们其中有一个为空都直接修改编号(省时省力),否则暴力合并。
这样做的时间复杂度为 O ( n log n ) O(n\log n) O(nlogn),问题本质是讲 n n n个大小为 1 1 1的线段树合并成一颗,由势函数分析可保证复杂度,不展开,有两个细节,代码中会有注释。
#include <cstdio>
#include <iostream>
using namespace std;
const int MAXN = 300005;
int read()
{
int x=0,flag=1;char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^'0'),c=getchar();
return x*flag;
}
int n,m,tot,cnt,f[MAXN],dep[MAXN],fa[MAXN][20];
int rt[MAXN][2],w[MAXN],ans[MAXN];
struct edge
{
int v,next;
}e[2*MAXN];
struct tree
{
int ls,rs,v;
}tr[MAXN*80];
void dfs(int u,int par)
{
fa[u][0]=par;
dep[u]=dep[par]+1;
for(int i=1;i<=19;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==par) continue;
dfs(v,u);
}
}
int get(int u,int v)//倍增lca
{
if(dep[u]<dep[v]) swap(u,v);
for(int i=19;i>=0;i--)
if(dep[fa[u][i]]>=dep[v])
u=fa[u][i];
if(u==v) return u;
for(int i=19;i>=0;i--)
if(fa[u][i]^fa[v][i])
u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
void add(int &x,int l,int r,int f,int v)
{
if(!x) x=++cnt;//动态开点
if(l==r) {tr[x].v+=f;return ;}
int mid=(l+r)>>1;
if(mid>=v)
add(tr[x].ls,l,mid,f,v);
else
add(tr[x].rs,mid+1,r,f,v);
tr[x].v=tr[tr[x].ls].v+tr[tr[x].rs].v;
}
int query(int x,int l,int r,int v)
{
if(!x || v>r) return 0;//细节1:要判,因为dep+w可能大于n
if(l==r) return tr[x].v;
int mid=(l+r)>>1;
if(mid>=v)
return query(tr[x].ls,l,mid,v);
return query(tr[x].rs,mid+1,r,v);
}
void merge(int &x,int y)
{
if(!x||!y) {x=x+y;return ;}
tr[x].v+=tr[y].v;
merge(tr[x].ls,tr[y].ls);
merge(tr[x].rs,tr[y].rs);
}
void count(int u)
{
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==fa[u][0]) continue;
count(v);
merge(rt[u][0],rt[v][0]);
merge(rt[u][1],rt[v][1]);
}
ans[u]+=query(rt[u][0],1,n,dep[u]+w[u]);
ans[u]+=query(rt[u][1],1,2*n,dep[u]-w[u]+n);
//细节2:一定要在这里存答案,因为修改节点编号的话可能会在以后值被再次修改
}
int main()
{
n=read();m=read();
for(int i=2;i<=n;i++)
{
int u=read(),v=read();
e[++tot]=edge{v,f[u]},f[u]=tot;
e[++tot]=edge{u,f[v]},f[v]=tot;
}
for(int i=1;i<=n;i++)
w[i]=read();
dfs(1,0);
for(int i=1;i<=m;i++)
{
int u=read(),v=read();
int lca=get(u,v),t=2*dep[lca]-dep[u]+n;//避免负数
add(rt[u][0],1,n,1,dep[u]);
add(rt[fa[lca][0]][0],1,n,-1,dep[u]);
add(rt[v][1],1,2*n,1,t);
add(rt[lca][1],1,2*n,-1,t);
}
count(1);
for(int i=1;i<=n;i++)
{
printf("%d ",ans[i]);
}
}