题目:
2019-2020 ICPC Asia Nanchang Regional Onsite Contest
题意:
给定一颗树,求有多少个点对(x,y)满足:(1)二者不在一条链上(2)val[x]+val[y] = 2*val[lca(x,y)](3)二者之间的距离 <= k
分析:
dfs 依次枚举 lca,x,y属于不同的儿子子树内,遍历每个儿子子树统计答案;对每个 val 开一颗平衡树维护子树内每个点到 lca 的距离,树上启发式合并即可,复杂度O(NlogNlogN)
代码:
#include <bits/stdc++.h>
#define x first
#define y second
#define pii pair<int,int>
#define sz(x) (int)(x).size()
#define Max(x,y) (x)>(y)?(x):(y)
#define Min(x,y) (x)<(y)?(x):(y)
#define all(x) (x).begin(),(x).end()
using namespace std;
typedef long long LL;
const int maxm = 2e6+26;
const int maxn = 1e5+15;
struct data{int l,r,v,size,fix,w;}tr[maxm];
int len,root[maxn];
void update(int k){tr[k].size=tr[tr[k].l].size+tr[tr[k].r].size+tr[k].w;}
void rturn(int &k){int t=tr[k].l;tr[k].l=tr[t].r;tr[t].r=k;tr[t].size=tr[k].size;update(k);k=t;}
void lturn(int &k){int t=tr[k].r;tr[k].r=tr[t].l;tr[t].l=k;tr[t].size=tr[k].size;update(k);k=t;}
void Insert(int &k,int x) //插入数x
{
if(k==0){len++;k=len;tr[k].size=tr[k].w=1;tr[k].v=x;tr[k].fix=rand();return;}
tr[k].size++;
if(tr[k].v==x)tr[k].w++;
else if(x>tr[k].v){Insert(tr[k].r,x);if(tr[tr[k].r].fix<tr[k].fix)lturn(k);}
else {Insert(tr[k].l,x);if(tr[tr[k].l].fix<tr[k].fix)rturn(k);}
}
void Delete(int &k,int x) //删除数x
{
if(k==0) return;
if(tr[k].v==x)
{
if(tr[k].w>1){tr[k].w--;tr[k].size--;return;}
if(tr[k].l*tr[k].r==0)k=tr[k].l+tr[k].r;
else if(tr[tr[k].l].fix<tr[tr[k].r].fix) rturn(k),Delete(k,x);
else lturn(k),Delete(k,x);
}
else if(x>tr[k].v) tr[k].size--,Delete(tr[k].r,x);
else tr[k].size--,Delete(tr[k].l,x);
}
/*
查找x的排名
返回 <= x 的数量,等于x的数有多个,只算一个
*/
int Findrank(int k,int x)
{
if(k==0) return 0;
if(tr[k].v==x) return tr[tr[k].l].size+1;
else if(x>tr[k].v) return tr[tr[k].l].size+tr[k].w+Findrank(tr[k].r,x);
else return Findrank(tr[k].l,x);
}
int Findkth(int k,int x) //查找排名为x的数
{
if(k==0) return 0;
if(x<=tr[tr[k].l].size) return Findkth(tr[k].l,x);
else if(x>tr[tr[k].l].size+tr[k].w) return Findkth(tr[k].r,x-tr[tr[k].l].size-tr[k].w);
else return tr[k].v;
}
/**************以上为平衡树模板*****************/
int n,k,u,val[maxn],head[maxn],tot;
struct edge{
int to,nxt;
}e[maxn];
inline void add(int u,int v){
e[++tot] = (edge){v,head[u]};
head[u] = tot;
}
int sz[maxn],son[maxn],p[maxn];
void dfs(int x){
sz[x] = 1;
for(int i = head[x];i > 0;i=e[i].nxt){
int v = e[i].to;
dfs(v); sz[x] += sz[v];
if(sz[v] > sz[son[x]]) son[x] = v;
}
}
LL ans;
void calAns(int x,int z,int deep){ //统计答案
int vy = 2*val[z]-val[x];
if(vy >= 0 && vy <= n){
int tep = Findrank(root[vy],k-deep-p[z]+1);
if(tep && Findkth(root[vy],tep)>k-deep-p[z]) tep--;
ans += tep;
}
for(int i = head[x];i > 0;i=e[i].nxt){
int v = e[i].to;
calAns(v,z,deep+1);
}
}
void addPoint(int x,int z,int deep){
Insert(root[val[x]],deep-p[z]);
for(int i = head[x];i > 0;i=e[i].nxt){
int v = e[i].to;
addPoint(v,z,deep+1);
}
}
void delPoint(int x,int z,int deep){
Delete(root[val[x]],deep-p[z]);
for(int i = head[x];i > 0;i=e[i].nxt){
int v = e[i].to;
delPoint(v,z,deep+1);
}
}
void dsu(int x,int flag){
for(int i = head[x];i > 0;i=e[i].nxt){
int v = e[i].to;
if(v != son[x]) dsu(v,0);
}
if(son[x]) dsu(son[x],1),p[x] = p[son[x]]+1;
for(int i = head[x];i > 0;i=e[i].nxt){
int v = e[i].to;
if(v != son[x]){
calAns(v,x,1);
addPoint(v,x,1);
}
}
Insert(root[val[x]],0-p[x]);
if(!flag) delPoint(x,x,0); //清除整颗非重儿子子树的影响
}
int main(){
scanf("%d %d",&n,&k);
for(int i = 1;i <= n; ++i) scanf("%d",val+i);
for(int i = 2;i <= n; ++i){
scanf("%d",&u); add(u,i);
}
dfs(1); dsu(1,1);
printf("%lld\n",ans*2);
return 0;
}