勉勉强强理解了线段树合并
一道线段树合并的模板题 其实按照dfs序暴力合并就行
其实这题可以直接用树状数组秒杀 只是刚开始没发现 写到一半才意识到(话说为什么老金一眼就看出来了)
可能自己对树上dfs序的各种操作还理解的不够深刻
于是 假装自己学会了线段树合并...
线段树合并解法的代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
int n,cnt;
int ls[40*maxn],rs[40*maxn],tree[40*maxn],root[40*maxn],tot;
int h[maxn],a[maxn],mp[maxn],num,ans[maxn];
struct Edge
{
int y,next;
}edge[maxn];
void addedge(int x,int y)
{
edge[++num].y=y;
edge[num].next=h[x];
h[x]=num;
}
void push_up(int k)
{
tree[k]=tree[ls[k]]+tree[rs[k]];
}
int insert(int l,int r,int x)
{
int node=++tot;
if(l==r)
{
tree[node]=1;
return node;
}
int mid=(l+r)/2;
if(x<=mid)
ls[node]=insert(l,mid,x);
else rs[node]=insert(mid+1,r,x);
push_up(node);
return node;
}
int merge(int l,int r,int u,int v)
{
if(!u||!v) return u+v;
int node=++tot;
if(l==r)
{
tree[node]=tree[u]+tree[v];
return node;
}
int mid=(l+r)>>1;
ls[node]=merge(l,mid,ls[u],ls[v]);
rs[node]=merge(mid+1,r,rs[u],rs[v]);
push_up(node);
return node;
}
int query(int L,int R,int l,int r,int k)
{
if(L<=l&&r<=R)
return tree[k];
int mid=(l+r)/2;
int ans=0;
if(L<=mid) ans+=query(L,R,l,mid,ls[k]);
if(R>mid) ans+=query(L,R,mid+1,r,rs[k]);
return ans;
}
int dfs(int x)
{
for(int i=h[x];i;i=edge[i].next)
{
dfs(edge[i].y);
root[x]=merge(1,cnt,root[x],root[edge[i].y]);
}
ans[x]=query(a[x]+1,cnt,1,cnt,root[x]);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
mp[i]=a[i];
}
sort(mp+1,mp+1+n);
cnt=unique(mp+1,mp+1+n)-mp-1;
for(int i=1;i<=n;i++)
a[i]=lower_bound(mp+1,mp+1+cnt,a[i])-mp;
for(int i=2;i<=n;i++)
{
int x;
scanf("%d",&x);
addedge(x,i);
}
for(int i=1;i<=n;i++)
root[i]=insert(1,cnt,a[i]);
dfs(1);
for(int i=1;i<=n;i++)
printf("%d\n",ans[i]);
return 0;
}
树状数组解法的代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
int h[maxn],tree[maxn],a[maxn],mp[maxn],ans[maxn],cnt,num;
struct Edge
{
int y,next;
}edge[maxn];
void addedge(int x,int y)
{
edge[++num].y=y;
edge[num].next=h[x];
h[x]=num;
}
int lowbit(int x)
{
return x&-x;
}
void add(int x)
{
while(x<=cnt)
{
tree[x]+=1;
x+=lowbit(x);
}
}
int get(int x)
{
int res=0;
while(x)
{
res+=tree[x];
x-=lowbit(x);
}
return res;
}
int dfs(int x)
{
int t=get(cnt)-get(a[x]);
add(a[x]);
for(int i=h[x];i;i=edge[i].next)
{
dfs(edge[i].y);
}
ans[x]=get(cnt)-get(a[x])-t;
}
int main()
{
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
mp[i]=a[i];
}
sort(mp+1,mp+1+n);
cnt=unique(mp+1,mp+1+n)-mp-1;
for(int i=1;i<=n;i++)
a[i]=lower_bound(mp+1,mp+1+cnt,a[i])-mp;
for(int i=2;i<=n;i++)
{
int x;
scanf("%d",&x);
addedge(x,i);
}
dfs(1);
for(int i=1;i<=n;i++)
printf("%d\n",ans[i]);
return 0;
}