题目链接https://nanti.jisuanke.com/t/41388
题意
给出一棵树,每个点都有一个权值。Q次询问,每次询问一个点v和k,输出所有距离点v不超过k的点权值和。
题解
很精妙的想法。
先考虑另一个问题,查询子树v内距离点v不超过k的权值和f(v,k)。
在欧拉序上遍历,如果该点有询问f(v,k),在树状数组上查询比点v深度大于等于k的权值和,欧拉序会一个点被访问两次,后一次减去前一次就刚好是f(v,k)。
然后就用f(v,k)来计算本题就行了。f(v,k)+f(fa[v],k-1)-f(v,k-2)+f(fa[fa[v]],k-2)-f(fa[v],k-3)…,这样算出来就是答案。所以把每个询问拆成k个就行了。
还有就是写的不好看是会被卡掉的。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,ll> piir;
struct Node{
int v,k;
};
const int N=1e6+200;
struct Edge{
int v,nxt;
}e[N*2];
int p[N],edn,n;
void add(int u,int v){
e[++edn]=(Edge){v,p[u]};p[u]=edn;
e[++edn]=(Edge){u,p[v]};p[v]=edn;
}
vector<piir>vc[N];
ll val[N];
int vv[N],kk[N],fa[N],d[N];
void dfs(int u){
d[u]=d[fa[u]]+1;
for(int i=p[u];~i;i=e[i].nxt){
int v=e[i].v;
if(v==fa[u]) continue;
fa[v]=u;
dfs(v);
}
}
ll c[N];
int lb(int x){return x&(-x);}
void add(int x,ll y){
while(x<=n+100){
c[x]+=y;
x+=lb(x);
}
}
ll ask(int x){
ll res=0;
while(x){
res+=c[x];
x-=lb(x);
}
return res;
}
void solve(int u){
for(int i=0;i<vc[u].size();i++){
int k=vc[u][i].first;
vc[u][i].second-=ask(d[u]+k)-ask(d[u]-1);
}
add(d[u],val[u]);
for(int i=p[u];~i;i=e[i].nxt){
int v=e[i].v;
if(v==fa[u]) continue;
solve(v);
}
for(int i=0;i<vc[u].size();i++){
int k=vc[u][i].first;
vc[u][i].second+=ask(d[u]+k)-ask(d[u]-1);
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%lld",&val[i]);
p[i]=-1;
}
edn=-1;
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
add(u,v);
}
dfs(1);
int q;
scanf("%d",&q);
for(int i=1;i<=q;i++){
scanf("%d%d",&vv[i],&kk[i]);
vc[vv[i]].push_back(piir(kk[i],0));
int v=vv[i],k=kk[i];
while(fa[v]&&k){
vc[fa[v]].push_back(piir(k-1,0));
vc[v].push_back(piir(k-2,0));
v=fa[v];
k--;
}
}
solve(1);
for(int i=1;i<=n;i++){
sort(vc[i].begin(),vc[i].end());
}
for(int i=1;i<=q;i++){
ll res=0;
auto it=lower_bound(vc[vv[i]].begin(),vc[vv[i]].end(),piir(kk[i],0)); res+=it->second;
int v=vv[i],k=kk[i];
while(fa[v]&&k){
it=lower_bound(vc[fa[v]].begin(),vc[fa[v]].end(),piir(k-1,0)); res+=it->second;
it=lower_bound(vc[v].begin(),vc[v].end(),piir(k-2,0)); res-=it->second;
v=fa[v];
k--;
}
printf("%lld\n",res);
}
}