题目大意:给你一棵树和一堆询问,每个询问是问从x到y的路径,从x开始每k个节点就加一次当前点权值,最后一步如果小于k就直接加上y的权值,问这条路径总权值是多少
看一眼数据范围猜可能是分块,然后就不会了.....
于是就去膜拜了Claris,果然是分块啊,不过题解好简短,没看懂
只能和xuruifan研究了半天,大概是这么一个思想
把询问按大于根号n和小于根号n分成两类
对于小于根号n的:
我们可以先预处理出来对于每个节点以步伐为i往上跳一次的位置是哪里,然后就可以得出对于每个节点以步伐为i一直向上跳,跳到不能跳为止的权值和是多少
然后对于每个询问,我们可以用两个节点一直向上跳的权值和减去多算了的那些就可以了
预处理时间复杂度O(N $\sqrt{N}$ ),询问单次时间复杂度O(logn)
对于大于根号n的:
先树链剖分
然后对于每个询问暴力查询,如果当前节点和重链顶端的距离超过d,那就O(1)利用剖完的顺序直接跳上去。否则就可以算一下,然后O(跨过的轻链重链个数)跳上去
因为一个节点到根的重链轻链总个数不会超过2logN,所以询问单次时间复杂度是O(logN+N/k),因为k大于 $\sqrt{N}$ ,所以单次询问时间复杂度是O( $\sqrt{N}$ )
恩,写的挺详细了吧......
上一份写得很长的代码吧....
#include<iostream>
#include<cstdio>
#include<cmath>
#define N 50010
using namespace std;
int a[N];
int to[N<<1],nxt[N<<1],pre[N],cnt;
void ae(int ff,int tt)
{
cnt++;
to[cnt]=tt;
nxt[cnt]=pre[ff];
pre[ff]=cnt;
}
int fa[N][120],d[N],siz[N],zs[N];
void build(int x)
{
int maxn=0,maxb=0;
int i,j;
siz[x]=1;d[x]=d[fa[x][1]]+1;
for(i=pre[x];i;i=nxt[i])
{
j=to[i];
if(j==fa[x][1]) continue;
fa[j][1]=x;
build(j);
siz[x]+=siz[j];
if(siz[j]>maxn)
{
maxn=siz[j];
maxb=j;
}
}
zs[x]=maxb;
}
int jut[N][120],top[N],sit[N],fan[N],cn;
int n,m;
void dfs(int x,int tt)
{
cn++;sit[x]=cn;fan[cn]=x;
top[x]=tt;
int i,j;
for(i=1;i<=m;i++)
jut[x][i]=jut[fa[x][i]][i]+a[x];
if(zs[x]) dfs(zs[x],tt);
for(i=pre[x];i;i=nxt[i])
{
j=to[i];
if(j==fa[x][1]||j==zs[x]) continue;
dfs(j,j);
}
}
int LCA(int x,int y)
{
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]]) swap(x,y);
x=fa[top[x]][1];
}
if(d[x]<d[y]) swap(x,y);
return y;
}
int jump(int x,int k)
{
int y=top[x];
while(k>=d[x]-d[y]+1&&x)
{
k-=d[x]-d[y]+1;
x=fa[y][1];y=top[x];
}
if(!x) return 0;
return fan[sit[x]-k];
}
int solve1(int x,int y,int k)
{
int z=LCA(x,y),L=d[x]+d[y]-2*d[z],ret=0;
if(L%k!=0)
{
ret+=a[y];
y=jump(y,L%k);
}
L=d[x]-d[z];
int t;
if(L%k==0)
{
t=jump(z,k);
ret+=jut[x][k]+jut[y][k]-jut[z][k]-jut[t][k];
}
else
{
t=jump(z,k-L%k);
ret+=jut[x][k]-jut[t][k];
L=d[y]-d[z];
t=jump(z,k-L%k);
ret+=jut[y][k]-jut[t][k];
}
return ret;
}
int cal(int x,int k)
{
int ret=0,y;
while(x)
{
ret+=a[x];
y=top[x];
if(d[x]-d[y]>=k) x=fan[sit[x]-k];
else x=jump(x,k);
}
return ret;
}
int solve2(int x,int y,int k)
{
int z=LCA(x,y),L=d[x]+d[y]-2*d[z],ret=0;
if(L%k!=0)
{
ret+=a[y];
y=jump(y,L%k);
}
L=d[x]-d[z];
int t;
if(L%k==0)
{
t=jump(z,k);
ret+=cal(x,k)+cal(y,k)-cal(z,k)-cal(t,k);
}
else
{
t=jump(z,k-L%k);
ret+=cal(x,k)-cal(t,k);
L=d[y]-d[z];
t=jump(z,k-L%k);
ret+=cal(y,k)-cal(t,k);
}
return ret;
}
int b[N],c[N];
int main()
{
scanf("%d",&n);
m=sqrt(n)/2;
int i,j,x,y;
for(i=1;i<=n;i++)
scanf("%d",&a[i]);
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
ae(x,y);ae(y,x);
}
build(1);
for(i=1;i<=n;i++)
for(j=2;j<=m;j++)
fa[i][j]=fa[fa[i][j-1]][1];
dfs(1,1);
for(i=1;i<=n;i++)
scanf("%d",&b[i]);
for(i=1;i<n;i++)
{
scanf("%d",&x);
if(x<=m) printf("%d\n",solve1(b[i],b[i+1],x));
else printf("%d\n",solve2(b[i],b[i+1],x));
}
}