题目大意:
就是给你一颗树,每个点有个权值 v i v_i vi,问你有多少对 ( x , y ) (x,y) (x,y)满足:
- x x x不是 y y y的祖先
- y y y也不是 x x x的祖先
- x x x和 y y y的距离不超过 k k k
- x x x和 y y y最近公共祖先: z z z,满足 v x + v y = 2 v z v_x+v_y=2v_z vx+vy=2vz
解题思路:
首先我一开始以为是点分治,但是发现点分之后公共祖先维护不了!!
那么我们树的结构肯定变不了,那么我们可以固定公共祖先 z z z,那么就是在 z z z这个子树里面,找到两个子树里面点 x , y x,y x,y, x x x和 y y y的距离不超过 k k k,且满足 v x + v y = 2 v z v_x+v_y=2v_z vx+vy=2vz
那么统计子树问题肯定树上启发式合并啦!!
好久没写了现在来回归一下树上启发式合并的步骤
- 就是先跑轻儿子,清空轻儿子信息。
- 然后再跑重儿子
- 回溯的时候再暴力遍历轻儿子
遍历时间是 O ( n l o g n ) O(nlogn) O(nlogn)
为什么呢?我们知道树链剖分上面重链是 log \text{log} log条的,那么轻边也就 log \text{log} log条,那么就对于每个点的信息它最多被撤销 log \text{log} log次
那么我们看看限制是二维的我们把距离调成到根节点的距离
d
i
:
是
i
点
到
根
节
点
的
距
离
d_i:是i点到根节点的距离
di:是i点到根节点的距离那么就有
d
x
+
d
y
−
2
×
d
z
≤
k
d_x+d_y-2\times d_z\leq k
dx+dy−2×dz≤k
那另一个限制是
v
x
+
v
y
−
2
×
v
z
=
0
v_x+v_y-2\times v_z=0
vx+vy−2×vz=0
我们已经知道
z
z
z,然后遍历
x
x
x, 那么就是我们知道了
x
,
z
x,z
x,z的y求有多少个
y
y
y满足条件
d
y
≤
k
+
2
×
d
z
−
d
x
d_y\leq k+2\times d_z-d_x
dy≤k+2×dz−dx
v
y
=
2
×
v
z
−
v
x
v_y=2\times v_z-v_x
vy=2×vz−vx
那么我们就看前面子树所有等于
v
y
v_y
vy里面看有多少
d
y
≤
k
+
2
×
d
z
−
d
x
d_y\leq k+2\times d_z-d_x
dy≤k+2×dz−dx
那就是个二维限制,树套树?权值太大了!!,我们可以对每个
v
y
v_y
vy开个动态开点权值线段树,叶子节点是
d
d
d。然后我们每次就是查询以
r
o
o
t
[
v
y
]
root[v_y]
root[vy]根的子树里面查询
[
1
,
k
+
2
×
d
z
−
d
x
]
[1,k+2\times d_z-d_x]
[1,k+2×dz−dx] 有多少个点.
AC code
#include <bits/stdc++.h>
using namespace std;
#define mid ((l+r)>>1)
const int maxn = 8e5 + 10;
const int len = 2e5 + 10;
typedef long long ll;
typedef pair<int,int> PII;
int idx;
int root[maxn];
struct node {
int lson, rson, num;
}tr[maxn<<2];
inline void pushup(int rt) {
tr[rt].num = tr[tr[rt].lson].num + tr[tr[rt].rson].num;
}
inline void insert(int &rt, int l, int r, int pos, int val) {
if(!rt) rt = ++ idx;
if(l == r) {
tr[rt].num += val;
return;
}
if(pos <= mid) insert(tr[rt].lson,l,mid,pos,val);
else insert(tr[rt].rson,mid+1,r,pos,val);
pushup(rt);
}
inline int ask(int rt, int l, int r, int posl, int posr) {
if(!rt) return 0;
if(posl <= l && posr >= r) return tr[rt].num;
int res = 0;
if(posl <= mid) res += ask(tr[rt].lson,l,mid,posl,posr);
if(posr > mid) res += ask(tr[rt].rson,mid+1,r,posl,posr);
return res;
}
//...........................................
int n, k;
vector<int> G[maxn];
int node[maxn];
int depth[maxn], siz[maxn], son[maxn];
ll ans = 0;
inline void find_son(int u, int fa) {
siz[u] = 1;
depth[u] = depth[fa] + 1;
for(auto it : G[u]) {
find_son(it,u);
siz[u] += siz[it];
if(son[u] == 0 || siz[son[u]] < siz[it]) son[u] = it;
}
}
vector<PII> now; // now里面保存的是权值线段树里面插的点
int dp, vp; // p就是z来着
inline void Count(int u) {
int vy = 2*vp - node[u];
int lim = k + 2 * dp - depth[u];
if(vy >= 0 && lim > 0 && vy <= n && lim > dp) ans += ask(root[vy],1,len,1,lim); // 注意合法性
now.push_back({depth[u],node[u]});
for(auto it : G[u]) Count(it);
}
inline void dfs(int u, int keep) {
for(auto it : G[u]) {
if(it == son[u]) continue;
dfs(it,0);
}
if(son[u]) dfs(son[u],1);
int last = now.size();
dp = depth[u];
vp = node[u];
for(auto it : G[u]) {
if(it == son[u]) continue;
Count(it);
for(int i = last; i < now.size(); ++ i)
insert(root[now[i].second],1,len,now[i].first,1);
last = now.size();
}
if(!keep) {
for(int i = 0; i < now.size(); ++ i)
insert(root[now[i].second],1,len,now[i].first,-1);
now.clear();
} else insert(root[node[u]],1,len,depth[u],1), now.push_back({depth[u],node[u]});
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n >> k;
for(int i = 1; i <= n; ++ i) cin >> node[i];
for(int i = 2; i <= n; ++ i) {
int x;
cin >> x;
G[x].push_back(i);
}
depth[0] = -1;
find_son(1,0);
dfs(1,1);
cout << 2ll * ans;
return 0;
}
/*
7 2
1 2 3 4 5 6 7
1 2 2 1 5 5
*/