传送门
题目要求求出符合以下条件的树上点对个数:
1.两点间不能是祖先和子孙的关系
2.x点和y点的权值相加为2倍的他们的lca的权值
3.x点和y点间的距离小于等于给定的k
对于第一个条件,我们使用dsu on tree,每次枚举一个点,计算他为lca的答案即可。注意计算贡献的时候,要先计算一个子树的贡献,再将这个子树上的值加到我们使用的线段树上去,因为这样计算的才是其中一个点在一棵子树上,另一个点在另一颗子树上的答案,这样他们的路径是经过当前枚举的lca的。
对于第二个条件,我们使用动态开点线段树,对每一个权值开一棵线段树,这样可以控制“权值恰好等于”这个精确条件。由于空间可能不够所以使用动态开点。
对于第三个条件,我们在每个权值的线段树上,维护的是每个深度有几个节点,因此使用的是权值线段树,这样可以控制“距离小于等于”这个范围条件。两点间的距离是dep[x] + dep[y] - 2 * dep[lca],我们假设题意中符合条件的点的权值为val,深度要求小于等于k,那么我们就直接在rt[val]这颗线段树上查询0~k的值的个数有几个就行了。
于是这题就算完了,注意由于是点对,所以答案需要X2
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <math.h>
#include <map>
#include <set>
#include <queue>
using namespace std;
#define endl '\n'
const int maxn = 1e5 + 5;
const int maxm = maxn << 1;
const int maxx = maxn * 32;
int head[maxn],nex[maxm],v[maxm],w[maxm],cnt,dep[maxn];
void add(int x,int y){
nex[++cnt] = head[x];
head[x] = cnt;
v[cnt] = y;
}
int rt[maxn],val[maxx],ls[maxx],rs[maxx],idx,son[maxn],sz[maxn];
int n,k,a[maxn],skip;
long long res;
void update(int l,int r,int &o,int p,int x){
if (!o) o = ++idx;
if (l == r){
val[o] += x;
return;
}
int mid = l + r >> 1;
if (p <= mid) update(l,mid,ls[o],p,x);
else update(mid + 1,r,rs[o],p,x);
val[o] = val[ls[o]] + val[rs[o]];
}
int query(int l,int r,int o,int x,int y){
if (!o) return 0;
if (x <= l && r <= y) return val[o];
int mid = l + r >> 1;
int res = 0;
if (x <= mid) res += query(l,mid,ls[o],x,y);
if (y > mid) res += query(mid + 1,r,rs[o],x,y);
return res;
}
void dfs1(int node,int fa){
sz[node] = 1;
dep[node] = dep[fa] + 1;
for (int i = head[node]; i; i = nex[i]) {
int to = v[i];
if (to == fa) continue;
dfs1(to,node);
sz[node] += sz[to];
if (sz[to] > sz[son[node]]) son[node] = to;
}
}
void precal(int node,int fa,int x){
update(0,n,rt[a[node]],dep[node],x);
for (int i = head[node]; i; i = nex[i]) {
int to = v[i];
if (to == fa) continue;
precal(to,node,x);
}
}
void cal(int node,int fa,int lca){
int vall = 2 * a[lca] - a[node];
int dp = k + 2 * dep[lca] - dep[node];
dp = min(dp,n);
if (vall >= 0 && vall <= n && dp >= 0 && dp <= n) res += 2 * query(0,n,rt[vall],0,dp);
for (int i = head[node]; i; i = nex[i]) {
int to = v[i];
if (to == fa) continue;
cal(to,node,lca);
}
}
void dfs(int node,int fa,bool clear){
for (int i = head[node]; i; i = nex[i]) {
int to = v[i];
if (to == fa || to == son[node]) continue;
dfs(to,node,true);
// precal(to,node,-1);
}
if (son[node]) dfs(son[node],node,false);
for (int i = head[node]; i; i = nex[i]) {
int to = v[i];
if (to == fa || to == son[node]) continue;
cal(to,node,node);
precal(to,node,1);
}
update(0,n,rt[a[node]],dep[node],1);
if (clear) precal(node,fa,-1);
}
void solve(){
cin >> n >> k;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
for (int i = 2; i <= n; ++i) {
int x;
cin >> x;
add(x,i);
add(i,x);
}
dfs1(1,0);
dfs(1,0, true);
cout << res << endl;
}
signed main(){
std::ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
#ifdef LOCAL
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
int T = 1;
// cin >> T;
while (T--) solve();
}