题目描述
一天,Chika 对大小接近的点对产生了兴趣,她想搞明白这个问题的树上版本,你能帮助她吗?Chika 会给 你一棵有根树,这棵树有 n 个结点,被编号为 1 n,1 号结点是根。每个点有一个权值,i 号结点的权值为 a[i]。如果 u 是 v 的祖先结点,并且 abs(a[u]−a[v]) ≤K,那么 (u,v) 被称作一个“** 大小接近的点对 **”。 对于树上的每个结点 i,你都需要计算以其为根的子树中的“大小接近的点对”的数量。你需要知道:
(1) abs(x) 代表 x 的绝对值。
(2) 每个结点都是其自身的祖先结点.
输入
输入文件的第一行包含两个整数 n (1≤n≤105) 和 k (1≤k≤109),代表树中结点总数, 以及“大小接近的点对”的大小之差的上界。
第二行包含 n 个整数,第 i 个整数是 a[i] (1≤ a[i] ≤109),代表 i 号结点的权值。
第三行包含 n−1 个整数,第 i 个整数是 i+1 号结点的父结点。
输出
输出应该包含n行,每一行包括一个整数。第i行的整数代表以i为根的子树中的“大小接近的点对”的数量。
样例输入 Copy
7 5 2 4 4 1 4 6 4 1 2 3 1 2 3
样例输出 Copy
19 11 5 1 1 1 1
思路:这道题如果暴力的做法就是遍历树,每次返回一个数组,然后二分求合法答案的个数TLE。zdw大佬用树状数组代替的返回的数组,核心思想就是先将所有数离散化,记录他们的位置,然后遍历树的时候,首先先进行树状数组的查询操作,也就是把以前树种的”合法答案“(其实是不合法的,不在这个节点的子树上),然后把这个点的值加入树状数组,继续遍历树,返回到这个节点时,再进行树状数组的查询(它的子树都加入到树状中了),然后减去第一次查询的结果即可。
代码如下:
#include<bits/stdc++.h>
#define ll long long
#define N 300010
using namespace std;
vector<ll>v[100010];
map<ll,int> u;
int n,k;
ll a[100010],ans[100010],l[N],r[N],Sum[N];
ll low_bit(ll x)
{
return x&(-x);
}
void add(ll x)
{
for(ll i=x+1; i<N; i+=low_bit(i))
Sum[i]++;
}
ll getsum(ll x)
{
ll sum=0;
for(ll i=x+1; i>0; i-=low_bit(i))
sum+=Sum[i];
return sum;
}
ll dfs(int x)
{
ll sum=0;
sum-=getsum(r[x])-getsum(l[x]); //先用树状数组记录一下已经出现过的(l[x],r[x])的个数,因为算当前子树时,不能被这些值影响,所以先减去
add(a[x]); //只加这个节点的位置,所以查询的时候左右端点并不影响结果
for(int i=0; i<v[x].size(); i++) //dfs遍历子树
sum+=dfs(v[x][i]); //当然要累加了
sum+=getsum(r[x])-getsum(l[x]);//现在这个节点的子树已经全部都在树状数组里了,所以我们应该用这个值减去不是这个子树的节点,由于上面已经处理过了,所以这里直接加就好
return ans[x]=sum;//赋值
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1; i<=n; i++) //输入权值
{
scanf("%lld",&a[i]);
l[i]=a[i]-k; //左边能到哪
r[i]=a[i]+k; //右边能到哪
u[a[i]]=u[l[i]]=u[r[i]]=1; //map标记一下
}
int cnt=1;
map<ll,int>::iterator it;
for(it=u.begin(); it!=u.end(); ++it) //把所有点按顺序排好,并记录是第几个点
it->second=cnt++;
//离散化:a[i] 代表这个节点的权值在所有节点,以及左右端点的从小到大的排名
for(int i=1; i<=n; i++) //重新对权值以及两个端点进行处理
{
l[i]=u[l[i]]; //左端现在是第几个,从小到大
r[i]=u[r[i]]; //同理
a[i]=u[a[i]];
}
for(int i=1; i<n; i++) //建树
{
int x;
scanf("%d",&x);
v[x].push_back(i+1);
}
ll sum=dfs(1);//跑一遍树
for(int i=1; i<=n; i++)
printf("%lld\n",ans[i]);
return 0;
}