题目地址
考虑枚举LCA,对于某个确定的 LCA ,比如说LCA是Z,假设它有k个儿子,也就是在它之下的k个子树,有贡献的点对(x,y),x和y一定是分别分布在这k个子树里不同的两个子树中,因为如果x,y和在同一个子树 S 中,Z就不可能是x和y的LCA,S肯定是x和y的公共祖先(不一定是最近的公共祖先),而Z是S的父亲,所以x和y的最近公共祖先轮不到Z。
假设当前遍历到第 i 棵子树,那么要统计的就是 这棵子树中的所有点,跟之前遍历的总共 i-1 棵 子树之间的贡献。再来回想一下,我们要求的是 v[x]+v[y]=2v[z] ,x就是当前枚举的点,z就是 当前确定的LCA ,那么v[y]也随之确定了,v[y]=2v[z]-v[x],再考虑第四个条件,x和y的距离不超过k,已知dx,dz,那么dy=k-dx+2*dz 。那么就是求 前 i-1棵子树中深度不超过dy且值为v[y]的点的个数。这个肯定是要用某种数据结构了,因为点权值域是[0,n],所以考虑用n+1 棵 线段树维护(每一种权值对应一颗),线段树的下标对应的是原树中点的深度。查询的时候,就是用v[y] 对应的那棵线段树求deep[z]+1,deep[z]+y]的区间和。当然线段树要采取动态开点的方式,否则空间爆炸。如果是单纯的暴力枚举LCA,再暴力的统计这个LCA的贡献,总体时间复杂度是
n
2
l
o
g
n
n^2log ^n
n2logn,需要用DSU ON TREE ,每次保留重儿子的贡献,减少了许多重复的计算。总体时间复杂度就会降到
n
∗
l
o
g
n
∗
l
o
g
n
n*log^n *log^n
n∗logn∗logn
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1e5 + 5;
typedef long long LL;
inline int read()
{
char c = getchar();
int x = 0, f = 1;
while (c < '0' || c > '9')
{
if (c == '-')
f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
x = x * 10 + c - '0', c = getchar();
return x * f;
}
int N; //树的结点个数
int col[MAXN]; //结点的权值
int hSon[MAXN]; // hSon[x] 代表 x的重儿子
int siz[MAXN]; //子树大小
int nowHson; //nowHnow 当前子树的重儿子(统计新子树 时 初始化
int deep[MAXN];
LL ans = 0;
vector<int> G[MAXN];
int root[MAXN];
int STcnt = 0;
struct node
{
int ls, rs;
int sum;
} tr[int(1e7)];
#define lson(x) tr[x].ls
#define rson(x) tr[x].rs
inline void pushup(int x)
{
tr[x].sum = tr[lson(x)].sum + tr[rson(x)].sum;
}
void update(int L, int R, int c, int l, int r, int &x)
{
if (!x)
x = ++STcnt;
if (L <= l && r <= R)
{
tr[x].sum += c;
return;
}
int mid = l + r >> 1;
if (L <= mid)
update(L, R, c, l, mid, lson(x));
if (R > mid)
update(L, R, c, mid + 1, r, rson(x));
pushup(x);
}
int query(int L, int R, int l, int r, int &x)
{
if (!x)
return 0;
if (L <= l && r <= R)
{
return tr[x].sum;
}
int mid = l + r >> 1;
if (L <= mid && R > mid)
return query(L, R, l, mid, lson(x)) + query(L, R, mid + 1, r, rson(x));
else if (L <= mid)
return query(L, R, l, mid, lson(x));
else if (R > mid)
return query(L, R, mid + 1, r, rson(x));
}
void dfs(int x, int fa)
{ //第一遍 DFS 处理出 所有结点的重儿子
siz[x] = 1;
deep[x] = deep[fa] + 1;
for (int v : G[x])
{
if (v == fa)
continue;
dfs(v, x);
siz[x] += siz[v];
if (siz[v] > siz[hSon[x]])
hSon[x] = v; //轻重链剖分
}
}
stack<int> st;
int k;
void Count(int x, int fa, int val, int z)
{
int dv = 2 * col[z] - col[x]; //目标权值
int dk = k - deep[x] + 2 * deep[z]; //目标最大深度
dk = min(dk, N);
if (val == 1) //加贡献
{
if (fa == z) //统计一条新链
{
while (!st.empty())
{ //把上一条链 的信息更新到 线段树中
int tx = st.top();
update(deep[tx], deep[tx], 1, 1, N, root[col[tx]]);
st.pop();
}
}
if (dv <= N && deep[z]+1<=dk && x!=z)
{
//cerr<<query(deep[z]+1, dk, 1, N, root[dv])<<endl;
ans += query(deep[z]+1, dk, 1, N, root[dv]); // 累加贡献 (之前所有链中 深度 在[deep[z],dk]范围内 权值 为dv的个数
}
st.push(x); //把当前链上的结点 先存在 栈中
}
else //消除贡献
{
update(deep[x], deep[x], -1, 1, N, root[col[x]]);
}
for (int v : G[x])
{
if (v == fa || v == nowHson)
continue;
Count(v, x, val, z);
}
}
void dfs2(int x, int fa, int opt)
{
for (int i = 0; i < G[x].size(); i++)
{
int to = G[x][i];
if (to == fa)
continue;
if (to != hSon[x])
dfs2(to, x, 0); //暴力统计轻边的贡献,opt = 0表示递归完成后消除对该点的影响
}
if (hSon[x])
dfs2(hSon[x], x, 1), nowHson = hSon[x]; //统计重儿子的贡献,不消除影响
Count(x, fa, 1, x); //暴力统计所有轻儿子的贡献
nowHson = 0;
//统计已经完成,但x子树遍历到的最后一条链还存在栈中,也要更新
while (!st.empty())
{ //把最后一条链 的信息更新到 线段树中
int tx = st.top();
update(deep[tx], deep[tx], 1, 1, N, root[col[tx]]);
st.pop();
}
if (!opt)
{ //如果需要删除贡献的话就删掉
Count(x, fa, -1, x);
//sum = 0; Mx = 0; //初始化子树相关信息
}
}
int main()
{
#ifdef DEBUG
freopen("1.in", "r", stdin);
freopen("1.out", "w", stdout);
#endif
N = read();
k = read();
// scanf("%d%d",&N,&k);
for (int i = 1; i <= N; i++)
{
//scanf("%d",&col[i]);
col[i] = read();
}
for (int i = 2; i <= N; i++)
{
int fa = read();
G[fa].push_back(i);
}
dfs(1, 0);
dfs2(1, 0, 0);
printf("%lld", ans * 2);//(x,y)和(y,x)都合法,但ans里只统计了(x,y),所以*2
return 0;
}