Little M’s attack plan
题意:给一颗带有权值的树,然后q个询问,每个询问x和k,回答与点x的距离在k以内的所有点的权值和。
思路:容斥原理,以下是对于每个询问的回答,dp[i][j]代表点i的子树中的所有点距离点i在j个单位距离的点的权值和。
一开始开了个dp[MAX_N][110]的数组,数组太大而且会T。
scanf("%d%d",&x,&k);
long long ans=0;
int pre=x;
for(j=k;j>=0;j--){
ans+=dp[pre][j];
if(pre==1)
break;
if(j>=2)
ans-=dp[pre][j-2];
pre=fa[pre];
}
printf("%lld\n",ans);
然后因为q询问一共5000个,可以离线看哪些dp[i][j]可以用到,然后跑dfs,树状数组存储遍历到的所有点某个深度的权值和,跑到一个点,如果这个点有询问,就进行查询当前点的深度deep[x]到deep[x]+k的所有点的权值和,然后当跑完子树再回溯到这个点时再查询一遍,这时候那这个值减去之前的值就是这个点的子树上的答案,即dp[x][k]。
如果是用vector,注意哪些dp[i][j]已经被放在vector里面了,防止重复计算,因为容易出错。去重!!!
#include<iostream>
#include<cstdio>
#include<map>
#include<vector>
using namespace std;
const int MAX_N=1010000;
map<int,long long>ma[MAX_N];
map<int,bool>mb[MAX_N];
int head[MAX_N],ver[2*MAX_N],Next[2*MAX_N];
int tot;
vector<int>v[MAX_N];
void Add(int x,int y){
ver[++tot]=y;Next[tot]=head[x];head[x]=tot;
}
long long sum[MAX_N],a[MAX_N];
int nn;
int deep[MAX_N];
int fa[MAX_N];
struct skt{
int x,k;
}b[5010];
void add(int p,long long x){
while(p<=nn){
sum[p]+=x;
p+=p&-p;
}
}
long long ask(int p){
long long ans=0;
while(p){
ans+=sum[p];
p-=p&-p;
}
return ans;
}
void dfs1(int x){
for(int i=head[x];i;i=Next[i]){
int y=ver[i];
if(y==fa[x])
continue;
fa[y]=x;
deep[y]=deep[x]+1;
nn=max(nn,deep[y]);
dfs1(y);
}
}
void dfs2(int x){
int i;
for(i=0;i<v[x].size();i++){
int k=v[x][i];
ma[x][k]=ask(deep[x]+k)-ask(deep[x]-1);
}
add(deep[x],a[x]);
for(i=head[x];i;i=Next[i]){
int y=ver[i];
if(y==fa[x])
continue;
dfs2(y);
}
for(i=0;i<v[x].size();i++){
int k=v[x][i];
ma[x][k]=ask(deep[x]+k)-ask(deep[x]-1)-ma[x][k];
}
}
int main(void){
int n,i,j,x,y,q;
scanf("%d",&n);
for(i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
for(i=1;i<n;i++){
scanf("%d%d",&x,&y);
Add(x,y);
Add(y,x);
}
deep[1]=1;
dfs1(1);
nn+=110;
scanf("%d",&q);
for(i=0;i<q;i++){
scanf("%d%d",&b[i].x,&b[i].k);
int pre=b[i].x;
for(j=b[i].k;j>=0;j--){
if(!mb[pre][j]){
v[pre].push_back(j);
mb[pre][j]=true;
}
if(pre==1)
break;
if(j>=2){
if(!mb[pre][j-2]){
v[pre].push_back(j-2);
mb[pre][j-2]=true;
}
}
pre=fa[pre];
}
}
dfs2(1);
for(i=0;i<q;i++){
long long ans=0;
int pre=b[i].x;
for(j=b[i].k;j>=0;j--){
ans+=ma[pre][j];
if(pre==1)
break;
if(j>=2)
ans-=ma[pre][j-2];
pre=fa[pre];
}
printf("%lld\n",ans);
}
return 0;
}