曾今考过的一场模拟赛中的题目,当时的我只会打暴力,还是那种非常低级的暴力,连 80 分都没有……
这题的正解是 LCA +树上差分。观察题目给出的子问题,20%的数据起点为1,20%的数据终点为1。根据这些信息我们想到了什么?没错,把一条路径分成 st 到 lca 的路径和 lca 到 ed 的路径。
考虑一个在
st
到
lca
的路径上的观察员
i
观察到了这个人,可以得到等式
突然发现上面得到的两个等式的左端都为常数,和观察员无关,那么是否我们可以用差分的思想来统计答案呢?
结论是可以的,那么在对所有给出的路径打完标记后直接遍历一遍整棵树就可以得到每个点的答案了。
p.s.在 lca 到 ed 上打的标记可能为负,需要把标记数组整体向正方向平移一个常数的单位,保证下标都为正。
又p.s.BZOJ上行末不能有空格,否则会PE……(太坑)
附上AC代码:
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <vector>
using namespace std;
const int N=3e5+10;
struct side{
int to,nt;
}s[N<<1];
vector <int> a[N],d1[N],d2[N];
int n,m,x,y,num,h[N],w[N],lca,len,c[N],c1[N<<1],c2[N<<1],ans[N];
int d[N],f[N],sz[N],hs[N],top[N];
inline char nc(){
static char ch[100010],*p1=ch,*p2=ch;
return p1==p2&&(p2=(p1=ch)+fread(ch,1,100010,stdin),p1==p2)?EOF:*p1++;
}
inline void read(int &a){
static char c=nc();int f=1;
for (;!isdigit(c);c=nc()) if (c=='-') f=-1;
for (a=0;isdigit(c);a=a*10+c-'0',c=nc());
a*=f;return;
}
inline void add(int x,int y){
s[++num]=(side){y,h[x]},h[x]=num;
s[++num]=(side){x,h[y]},h[y]=num;
}
inline void so1(int x,int fa){
d[x]=d[f[x]=fa]+1,sz[x]=1;
for (int i=h[x]; i; i=s[i].nt)
if (s[i].to!=fa){
so1(s[i].to,x),sz[x]+=sz[s[i].to];
if (sz[s[i].to]>sz[hs[x]]) hs[x]=s[i].to;
}
return;
}
inline void so2(int x,int fa){
top[x]=fa;
if (hs[x]) so2(hs[x],fa);
for (int i=h[x]; i; i=s[i].nt)
if (s[i].to!=f[x]&&s[i].to!=hs[x]) so2(s[i].to,s[i].to);
return;
}
inline int query(int x,int y){
for (int fx=top[x],fy=top[y]; fx!=fy; x=f[fx],fx=top[x])
if (d[fx]<d[fy]) swap(fx,fy),swap(x,y);
return d[x]<d[y]?x:y;
}
inline void work(int x,int fa){
int cnt1=c1[d[x]+w[x]],cnt2=c2[d[x]-w[x]+N];
c1[d[x]]+=c[x];
for (int i=0; i<a[x].size(); ++i) ++c2[a[x][i]+N];
for (int i=h[x]; i; i=s[i].nt) if (s[i].to!=fa) work(s[i].to,x);
ans[x]=c1[d[x]+w[x]]+c2[d[x]-w[x]+N]-cnt1-cnt2;
for (int i=0; i<d1[x].size(); ++i){
--c1[d1[x][i]];
if (d[x]+w[x]==d1[x][i]) --ans[x];
}
for (int i=0; i<d2[x].size(); ++i) --c2[d2[x][i]+N];
return;
}
int main(void){
read(n),read(m);
for (int i=1; i<n; ++i) read(x),read(y),add(x,y);
so1(1,0),so2(1,1);
for (int i=1; i<=n; ++i) read(w[i]);
for (int i=1; i<=m; ++i){
read(x),read(y),lca=query(x,y),len=d[x]+d[y]-(d[lca]<<1),++c[x];
a[y].push_back(d[y]-len),d1[lca].push_back(d[x]),d2[lca].push_back(d[y]-len);
}
work(1,0);
for (int i=1; i<n; ++i) printf("%d ",ans[i]);printf("%d",ans[n]);
return 0;
}