题目
n(n<=1e5)个点的树,给定一个参数k(k<=1e5),代表两点距离上限
点i上有一个权值
询问树上有序对(x,y)的数量,满足以下条件:
①x不等于y
②x不是y的祖先
③y不是x的祖先
④x和y的距离不超过k,即
⑤
思路来源
https://blog.csdn.net/qq_43202683/article/details/104108315
题解
启发式合并,然后考虑枚举到轻儿子的时候,答案如何统计,
轻儿子统计的应该是重儿子子树或者已经被插入的轻儿子子树上的答案,
设当前枚举到的lca为点u,轻儿子子树的点为x,
对于,其要统计的点(v,d)的数量
由于v是固定的,考虑对每一个v建一棵树,
由于均摊是n个点,所以考虑动态开点线段树/平衡树,支持询问该树内<=x的树有多少个,
这里采用动态开点线段树,插入时开一条x值对应的链,
询问时同线段树的区间询问,注意判一下v和d的合法范围,应该都在[0,n]之间,
统计完后,把轻儿子的答案加入
最后插入根节点lca,这样②③就能满足了
由于是有序对,最终答案*2
代码
#include<bits/stdc++.h>
using namespace std;
typedef pair<int,int> P;
typedef long long ll;
typedef double db;
#define fi first
#define se second
#define pb push_back
#define vi vector<int>
#define SZ(x) (int)(x.size())
#define sci(x) scanf("%d",&(x))
#define all(v) (v).begin(),(v).end()
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
ll modpow(ll x,ll n,ll mod){ll res=1;for(;n;n>>=1,x=x*x%mod)if(n&1)res=res*x%mod;return res;}
const db eps=1e-8,PI=acos(-1.0);
const int N=1e5+10,INF=0x3f3f3f3f,mod=1e9+7;//998244353
vector<int>E[N];
ll res;
int rt[N],ls[N*200],rs[N*200],num[N*200],c;
void upd(int &p,int l,int r,int x,int v){
if(!p)p=++c;
num[p]+=v;
if(l==r){
return;
}
int mid=(l+r)/2;
if(x<=mid)upd(ls[p],l,mid,x,v);
else upd(rs[p],mid+1,r,x,v);
}
int ask(int p,int l,int r,int ql,int qr){
if(!p)return 0;//注意 树可能没有这个区间对应的点
if(ql>qr)return 0;
if(ql<=l && r<=qr){
return num[p];
}
int ans=0,mid=(l+r)/2;
if(ql<=mid)ans+=ask(ls[p],l,mid,ql,qr);
if(qr>mid)ans+=ask(rs[p],mid+1,r,ql,qr);
return ans;
}
int n,k,w[N],f;
int d[N];
int sz[N];
int st[N],ed[N],dfn[N],tot;
void dfs(int u,int fa){
sz[u]=1;
st[u]=++tot;
dfn[tot]=u;
for(int v:E[u]){
if(v!=fa){
d[v]=d[u]+1;
dfs(v,u);
sz[u]+=sz[v];
}
}
ed[u]=tot;
}
void dfs(int u,int fa,bool keep){
int mx=-1,son=-1;
for(int v:E[u]){
if(v!=fa&&sz[v]>mx)
mx=sz[v],son=v;
}
for(int v:E[u]){
if(v!=fa&&v!=son){
dfs(v,u,0);
}
}
if(son!=-1){
dfs(son,u,1);
}
for(int v:E[u]){
if(v!=fa&&v!=son){
for(int i=st[v];i<=ed[v];i++){
int x=dfn[i],oth=2*w[u]-w[x];
if(oth>=0 && oth<=n){
res+=ask(rt[oth],0,n,0,min(n,2*d[u]+k-d[x]));
}
}
for(int i=st[v];i<=ed[v];i++){
int x=dfn[i];
upd(rt[w[x]],0,n,d[x],1);
}
}
}
upd(rt[w[u]],0,n,d[u],1);
if(keep==0){
for(int i=st[u];i<=ed[u];i++){
int x=dfn[i];
upd(rt[w[x]],0,n,d[x],-1);
}
}
}
int main(){
scanf("%d%d",&n,&k);
for(int i=1;i<=n;++i){
scanf("%d",&w[i]);
}
for(int i=2;i<=n;++i){
scanf("%d",&f);
E[f].pb(i);//E[v].pb(u);
}
dfs(1,-1);
dfs(1,-1,0);
printf("%lld\n",2ll*res);
return 0;
}