http://www.lydsy.com/JudgeOnline/problem.php?id=4381
若步长小于sqrt(n)则可以预处理每个点走某种步长走到跟的权值和然后减去LCA上面的部分;若步长大于sqrt(n)则暴力走,为了避免LCA算重,可以先防止两个点走到LCA,然后再特判能否走到LCA上。第一种情况要注意不要计算走过头的点。
用长链剖分进行预处理就可以o(1)查询某个点的K级祖先。
#include<cstdio>
#include<cmath>
#define gm 50001
using namespace std;
int n,size;
inline int* alloc(size_t cnt)
{
static int pool[gm<<2],*ptr=pool;
int *res=ptr; ptr+=cnt;
return res;
}
int hi[gm],ord[gm];
int fa[gm][16];
int pre[gm][250];
int a[gm],b[gm],c[gm],dep[gm],top[gm],bot[gm],*up[gm],*down[gm];
int maxd;
inline int jump(int x,int k)
{
if(!k) return x; if(k>dep[x]) return 0;
x=fa[x][hi[k]]; k-=1<<hi[k];
int dt=dep[x]-dep[top[x]];
return k>dt?up[top[x]][k-dt]:down[top[x]][dt-k];
}
inline int LCA(int a,int b)
{
if(dep[a]>dep[b]) a=jump(a,dep[a]-dep[b]);
else b=jump(b,dep[b]-dep[a]);
if(a==b) return a;
for(int i=hi[maxd];~i;--i)
if(fa[a][i]!=fa[b][i]) a=fa[a][i],b=fa[b][i];
return fa[a][0];
}
struct e
{
int t;
e *n;
e(int t,e *n):t(t),n(n){}
}*f[gm];
int ct=0;
void dfs(int x)
{
bot[x]=ord[++ct]=x;
for(e *i=f[x];i;i=i->n)
{
if(*fa[x]==i->t) continue;
*fa[i->t]=x;
dep[i->t]=dep[x]+1;
dfs(i->t);
if(dep[bot[i->t]]>dep[bot[x]]) bot[x]=bot[i->t];
}
for(e *i=f[x];i;i=i->n)
{
if(*fa[x]==i->t||bot[x]==bot[i->t]) continue;
int y=i->t,len=dep[bot[y]]-dep[y]+1;
up[y]=alloc(len); down[y]=alloc(len);
int kre=len;
for(int j=bot[y];j!=*fa[y];j=*fa[j]) top[j]=y,down[y][--kre]=j;
kre=0;
for(int j=y;j&&kre<len;j=*fa[j]) up[y][kre++]=j;
}
if(x==1)
{
maxd=dep[bot[1]]-dep[1]; size=sqrt(maxd);
int len=maxd+1;
up[1]=alloc(len); down[1]=alloc(len);
int kre=len;
for(int j=bot[1];j;j=*fa[j]) top[j]=1,down[1][--kre]=j;
for(int i=1;(1<<i)<=maxd;++i)
for(int j=1;j<=n;++j)
fa[j][i]=fa[fa[j][i-1]][i-1];
for(int i=2;i<=n;++i) hi[i]=hi[i>>1]+1;
for(int i=1;i<=n;++i)
{
x=ord[i];
for(int i=1;i<=size;++i)
pre[x][i]=a[x]+pre[jump(x,i)][i];
}
}
}
int walk(int x,int y,int c)
{
int res=0,lca=LCA(x,y),lex=dep[x]-dep[lca],ley=dep[y]-dep[lca];
if(c>size)
{
res=a[x]+a[y];
while(lex>c)
{
x=jump(x,c);
res+=a[x];
lex-=c;
}
if(lex+ley-1>=c)
{
y=jump(y,ley-(lex+ley-1)/c*c+lex);
res+=a[y]; ley=(lex+ley-1)/c*c-lex;
while(ley>c)
{
y=jump(y,c);
res+=a[y];
ley-=c;
}
}
if(x!=lca&&y!=lca&&(lex==c||ley==c)) res+=a[lca];
}
else
{
if(lex+ley-1<c) res=a[x]+a[y];
else
{
res=a[y]; y=jump(y,ley-(lex+ley-1)/c*c+lex); ley=(lex+ley-1)/c*c-lex;
res+=pre[x][c]-pre[jump(x,lex/c*c+c)][c];
if(ley>=0) res+=pre[y][c]-pre[jump(y,ley/c*c+c)][c];
if(lex%c==0) res-=a[lca];
}
}
return res;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;++i) scanf("%d",a+i);
for(int i=1;i<n;++i)
{
int u,v; scanf("%d%d",&u,&v);
f[u]=new e(v,f[u]); f[v]=new e(u,f[v]);
}
dfs(1);
for(int i=1;i<=n;++i) scanf("%d",b+i);
for(int i=1;i<n;++i)
{
scanf("%d",c+i);
printf("%d\n",walk(b[i],b[i+1],c[i]));
}
return 0;
}