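// Approach: for small step sizes (step <= B), precompute for every node the sum
// of values over the node and its ancestors taken every `step` levels (strided
// prefix sums along the root path); for large step sizes, jump directly with
// binary lifting. Getting all the edge cases right was painful.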
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
// 64-bit signed integer used for all path sums.
typedef long long int ll;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
if (p1==p2) { p2=(p1=buf)+fread(buf,1,100000,stdin); if (p1==p2) return EOF; }
return *p1++;
}
// Reads the next decimal integer from the nc() stream into x.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. If input ends before any digit appears, x is set to 0
// (instead of looping forever).
inline void read(int &x){
    int c=nc();
    int sign=1;   // int, not char: plain char may be unsigned, where -1 would become 255
    for (;!(c>='0' && c<='9');c=nc()){
        if (c==EOF){ x=0; return; }
        if (c=='-') sign=-1;
    }
    for (x=0;c>='0' && c<='9';c=nc())
        x=x*10+(c-'0');
    x*=sign;
}
const int N=50005;   // array bound for node indices 1..n
const int K=21;      // binary-lifting levels; 2^20 > N
// Number of columns of S. main() picks the step threshold
// B = min(n, (int)(sqrt(n)/log(n)+1)), which is at most 21 for n <= N, so only
// columns 1..21 are ever used. The previous value 305 made S ~122 MB for nothing.
const int M=32;
// Adjacency-list edge; edges live in G[1..2(n-1)] and index 0 terminates a list.
struct edge{
    int u;     // source node
    int v;     // target node
    int next;  // index of the next edge leaving u, or 0
};
edge G[N<<1];
int head[N];   // head[u]: index of the first edge leaving u (0 = none)
int inum;      // number of edges stored so far
inline void add(int u,int v,int p){
G[p].u=u; G[p].v=v; G[p].next=head[u]; head[u]=p;
}
int n;           // number of nodes
int B;           // step threshold: steps <= B use S, larger steps use binary lifting
ll ans;          // not referenced in the visible code
int val[N];      // val[u]: value of node u
int bp[N];       // bp[i]: i-th node of the visiting order
int depth[N];    // depth[u]: the root (node 1) has depth 1; sentinel node 0 has depth 0
int fat[N][K];   // fat[u][k]: 2^k-th ancestor of u, or 0 past the root
ll S[N][M];      // S[u][c]: sum of val over u, its c-th ancestor, its 2c-th ancestor, ...
// Target of edge p; expands at the use site, so a loop variable named p must be in scope.
#define V G[p].v
inline void dfs(int u,int fa){
depth[u]=depth[fa]+1; fat[u][0]=fa;
for (int k=1;k<K;k++)
fat[u][k]=fat[fat[u][k-1]][k-1];
int f=fa;
for (int i=1;i<=B;i++)
S[u][i]=S[f][i]+val[u],f=fat[f][0];
for (int p=head[u];p;p=G[p].next)
if (V!=fa)
dfs(V,u);
}
// Lowest common ancestor of u and v using the binary-lifting table fat[][].
inline int LCA(int u,int v){
    if (depth[v]>depth[u]) swap(u,v);    // make u the deeper node
    const int gap=depth[u]-depth[v];
    for (int k=0;k<K;++k)                // lift u up to v's depth
        if ((gap>>k)&1) u=fat[u][k];
    if (u==v) return u;                  // v is an ancestor of u
    for (int k=K-1;k>=0;--k){            // climb both while their ancestors differ
        if (fat[u][k]==fat[v][k]) continue;
        u=fat[u][k];
        v=fat[v][k];
    }
    return fat[u][0];
}
// Returns the ancestor s levels above u (0 once the walk passes the root),
// composing power-of-two jumps for the low K bits of s.
inline int Fat(int u,int s){
    int cur=u;
    for (int k=0;k<K;++k)
        if ((s>>k)&1)
            cur=fat[cur][k];
    return cur;
}
// Sum of val over the nodes visited when walking the tree path from u to v,
// jumping exactly cp edges at a time. The last jump may be shorter so that the
// walk ends on v. Both u and v are always counted.
//   cp >  B : jump with binary lifting, O((len / cp) * log n)
//   cp <= B : use the strided ancestor sums S[.][cp], O(log n)
// Assumes cp >= 1 (len % cp is evaluated).
inline ll Solve(int u,int v,int cp){
    int lca=LCA(u,v),len=depth[u]+depth[v]-2*depth[lca],t,l1,l2;
    ll ret=0;
    // One jump covers the whole path. Widen before adding so the sum cannot overflow int.
    if (len<=cp) return (ll)val[u]+val[v];
    if (cp>B){
        // u side: nodes at distance 0, cp, 2cp, ... from u that are not above lca
        // (lca itself is included when its distance from u is a multiple of cp).
        ret+=val[u];
        while (1){
            t=Fat(u,cp);
            if (depth[t]>=depth[lca])
                ret+=val[t],u=t;
            else
                break;
        }
        if (v!=lca){
            // v side, walked backwards from v: the end node v, then the nodes at
            // distance len%cp, len%cp+cp, ... from v lying strictly below lca.
            ret+=val[v];
            if (len%cp){
                v=Fat(v,len%cp);
                if (depth[v]>depth[lca])
                    ret+=val[v];
            }
            while (1){
                t=Fat(v,cp);
                if (depth[t]>depth[lca])
                    ret+=val[t],v=t;
                else
                    break;
            }
            return ret;
        }else{
            // v is lca: it was counted above iff len is a multiple of cp.
            if (len%cp) ret+=val[v];
            return ret;
        }
    }
    // cp <= B. S[x][cp]-S[y][cp] sums x and its ancestors at distances cp, 2cp, ...
    // strictly before y, where y is one of those ancestors (or 0 past the root).
    if (u==lca)
        ret+=val[lca];
    else{
        // First multiple of cp strictly beyond lca on u's side, so lca is included
        // exactly when its distance from u is a multiple of cp.
        l1=((depth[u]-depth[lca])/cp+1)*cp;
        t=Fat(u,l1);
        ret+=S[u][cp]-S[t][cp];
    }
    if (v!=lca){
        // If the last jump is short, count v on its own and move v to the last
        // full-jump landing point, which is len%cp levels above v.
        if (len%cp)
            ret+=val[v],v=Fat(v,len%cp);
        if (depth[v]<=depth[lca]) return ret;
        // Smallest multiple of cp reaching lca or beyond, so lca (already handled
        // on u's side) is excluded here.
        l2=((depth[v]-depth[lca]-1)/cp+1)*cp;
        t=Fat(v,l2);
        ret+=S[v][cp]-S[t][cp];
    }
    else{
        // v is lca: it was counted on u's side iff len is a multiple of cp.
        if (len%cp)
            ret+=val[v];
    }
    return ret;
}
// Input: n, node values, n-1 tree edges, the visiting order bp[1..n], then n-1
// step sizes. For each consecutive pair (bp[i], bp[i+1]) it prints Solve(...).
int main(){
    int iu,iv,ic;
#ifdef LOCAL
    // Local debugging only; compile with -DLOCAL to read t.in / write t.out.
    freopen("t.in","r",stdin);
    freopen("t.out","w",stdout);
#endif
    read(n);
    // Step threshold: steps <= B use the precomputed S table, larger steps use
    // binary lifting. log(1) == 0, so n == 1 would divide by zero and then convert
    // +inf to int (undefined behaviour); there are no queries in that case anyway.
    // For n <= 50005 this yields B <= 21, which fits within S's columns.
    B=1;
    if (n>1) B=min(n,(int)(sqrt(n)/log(n)+1));
    for (int i=1;i<=n;i++)
        read(val[i]);
    for (int i=1;i<n;i++)
        read(iu),read(iv),add(iu,iv,++inum),add(iv,iu,++inum);
    dfs(1,0);
    for (int i=1;i<=n;i++)
        read(bp[i]);
    for (int i=1;i<n;i++)
        read(ic),printf("%lld\n",Solve(bp[i],bp[i+1],ic));
    return 0;
}