原题链接:https://nanti.jisuanke.com/t/42586
题意
有一棵树,每个点都有权值vi,求满足以下三个条件的点对个数
- 这两个点不能互为祖先节点
- 这两个点之间的最短距离不能超过k
- 这两个点的最近公共祖先节点的权值*2等于这两个点权值之和
分析
看到树上统计点对的个数,不难想到树上启发式合并和点分治,接着又两个约束条件,一个是树上路径,还有一个是权值约束,树上路径根据套路可以转化为深度,如果已知dep[u]和路径长度k求dep[v],那么
d
e
p
[
v
]
=
2
∗
d
e
p
[
l
c
a
(
u
,
v
)
]
+
k
−
d
e
p
[
u
]
dep[v] = 2 * dep[lca(u,v)] + k - dep[u]
dep[v]=2∗dep[lca(u,v)]+k−dep[u]
由于可行的深度是在一定的范围,因此暴力枚举肯定是不行的,这时就可以借助线段树实现logn的查询区间和,但如果直接开n颗线段树肯定会mle,因此考虑动态开点,我们建一颗权值线段树代表当前权值的点在某个深度的个数,最后在启发式合并的同时统计就可以了,时间复杂度在n(logn)^2
Code
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cmath>
#include <bitset>
#include <map>
#include <set>
#include <stack>
#include <queue>
//#include <unordered_map>
using namespace std;
#define fi first
#define se second
#define re register
typedef long long ll;
typedef pair<ll, ll> PII;
typedef unsigned long long ull;
const int N = 1e5 + 10, M = 1e6 + 5, INF = 0x3f3f3f3f;
const int MOD = 1e9+7;
int n, K;
int a[N], sz[N], son[N], vis[N], dep[N], dfn[N], rnk[N], tot, h[N];
ll ans;
int rt[N], sum[N<<5], lc[N<<5], rc[N<<5], cnt, cntnode;
void push_up(int now) {
sum[now] = sum[lc[now]] + sum[rc[now]];
}
void modify(int &now, int l, int r, int pos, int val) {
if (!now) now = ++cntnode;
if (l == r) {
sum[now] += val;
return;
}
int mid = (l + r) >> 1;
if (pos <= mid) modify(lc[now], l, mid, pos, val);
else modify(rc[now], mid+1, r, pos, val);
push_up(now);
}
int query(int now, int l, int r, int ql, int qr) {
if (!now) return 0;
if (ql <= l && qr >= r) return sum[now];
int mid = (l + r) >> 1;
int ans = 0;
if (ql <= mid) ans += query(lc[now], l, mid, ql, qr);
if (qr > mid) ans += query(rc[now], mid+1, r, ql, qr);
return ans;
}
struct Edge {
int to, next;
}e[N<<1];
void add(int u, int v) {
e[cnt].to = v;
e[cnt].next = h[u];
h[u] = cnt++;
}
void dfs(int u, int fa) {
rnk[dfn[u] = ++tot] = u, sz[u] = 1; dep[u] = dep[fa] + 1;
for (int i = h[u]; ~i; i = e[i].next) {
int v = e[i].to;
if (v == fa) continue;
dfs(v, u);
sz[u] += sz[v];
if (!son[u] || sz[v] > sz[son[u]])
son[u] = v;
}
}
void count(int u, int fa) {
for (int i = h[u]; ~i; i = e[i].next) {
int v = e[i].to;
if (v == fa || vis[v]) continue;
for (int j = dfn[v]; j < dfn[v] + sz[v]; j++) {
int deep = 2 * dep[u] + K - dep[rnk[j]];
int num = rnk[j];
if (deep >= 0 && a[u] * 2 - a[num] <= n && a[u] * 2 - a[num] >= 0) ans += query(rt[a[u] * 2 - a[num]], 1, n, 1, min(n, deep));
}
for (int j = dfn[v]; j < dfn[v] + sz[v]; j++) {
int num = rnk[j];
modify(rt[a[num]], 1, n, dep[num], 1);
}
}
modify(rt[a[u]], 1, n, dep[u], 1);
}
void dsu(int u, int fa, int keep) {
for (int i = h[u]; ~i; i = e[i].next) {
int v = e[i].to;
if (v == fa || son[u] == v) continue;
dsu(v, u, 0);
}
if (son[u]) dsu(son[u], u, 1), vis[son[u]] = 1;
count(u, fa);
if (son[u]) vis[son[u]] = 0;
if (!keep) {
for (int i = dfn[u]; i < dfn[u] + sz[u]; i++) {
int num = rnk[i];
modify(rt[a[num]], 1, n, dep[num], -1);
}
}
}
void solve() {
cin >> n >> K;
memset(h, -1, sizeof h);
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 2; i <= n; i++) {
int x; cin >> x;
add(i, x), add(x, i);
}
dfs(1, 0);
dsu(1, 0, 1);
cout << ans * 2 << endl;
}
signed main() {
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
#ifdef ACM_LOCAL
freopen("input", "r", stdin);
freopen("output", "w", stdout);
#endif
solve();
return 0;
}