题目大意:
现在有一棵二叉树,所有非叶子节点都有两个孩子。在每个叶子节点上有一个权值(有n个叶子节点,满足这些权值为1..n的一个排列)。可以任意交换每个非叶子节点的左右孩子。
要求进行一系列交换,使得最终所有叶子节点的权值按照遍历序写出来,逆序对个数最少。
思路:
考虑交换两个相邻两个区间前后的变化,发现变化的只是分别在两个区间之内的逆序对,对外面的逆序对和某一个区间里面的逆序对个数是没有影响的,所以其它任何的交换也改变不了跨这两个区间的逆序对,所以对于每一个非叶子结点,用贪心的办法每次选择最小的逆序对个数的状态,只要每一个非叶子结点的状态都达到了最优,那么全局一定是最优秀的。
由于求逆序对需要用到权值线段树,所以这里采用权值线段树合并的方式(其实我是为了做线段树合并的题目才知道要用权值线段树)。然后不知道怎样统计逆序对的我用了一种很暴力的方式去统计,即对于两个区间选择size较小的那一个,然后对另一个区间的权值线段树一次一次来求逆序对个数,虽然很暴力,但是时间复杂度在启发式下均摊之后可以接受(在洛谷可以过,但是bzoj过不了)。
其实正确做法是在合并的时候顺带统计,对于每一段线段树内子树的合并,分别用两边的左右儿子的值交叉乘一下就可以得到不同状态下的逆序对的个数了,因为这样保证了每一个位置都被它前面(后面)的
logn
l
o
g
n
个区间计算了一次。
#include<bits/stdc++.h>
#define REP(i,a,b) for(int i=a;i<=b;++i)
#define mid ((l+r)>>1)
typedef long long ll;
using namespace std;
void File(){
freopen("bzoj2212.in","r",stdin);
freopen("bzoj2212.out","w",stdout);
}
const int maxn=4e6+10;
const int maxm=8e6+10;
int n,ch[maxn][2],val[maxn],tot,cnt_mp;
int root[maxn],lc[maxm],rc[maxm],cnt;
map<int,int>mp;
map<int,int>::iterator it;
ll ans,sum[maxm];
void init(int u){
int x;
scanf("%d",&x);
if(!x){
ch[u][0]=++tot;
ch[u][1]=++tot;
init(ch[u][0]);
init(ch[u][1]);
}
else val[u]=x,mp[x]=0;
}
void dfs(int u){
if(ch[u][0] && ch[u][1]){
dfs(ch[u][0]);
dfs(ch[u][1]);
}
else val[u]=mp[val[u]];
}
void update(int &rt,int l,int r,int pos){
if(!rt)rt=++cnt;
if(l==r)++sum[rt];
else{
if(pos<=mid)update(lc[rt],l,mid,pos);
else update(rc[rt],mid+1,r,pos);
sum[rt]=sum[lc[rt]]+sum[rc[rt]];
}
}
int query(int rt,int l,int r,int L,int R){
if(!rt)return 0;
if(L<=l && r<=R)return sum[rt];
int ret=0;
if(L<=mid)ret+=query(lc[rt],l,mid,L,R);
if(R>=mid+1)ret+=query(rc[rt],mid+1,r,L,R);
return ret;
}
ll sum1,sum2;
int merge(int x,int y){
if(!x || !y)return x+y;
int now=++cnt;
sum[now]=sum[x]+sum[y];
sum1+=sum[lc[x]]*sum[rc[y]];
sum2+=sum[rc[x]]*sum[lc[y]];
lc[now]=merge(lc[x],lc[y]);
rc[now]=merge(rc[x],rc[y]);
return now;
}
void solve(int u){
if(ch[u][0] && ch[u][1]){
solve(ch[u][0]);
solve(ch[u][1]);
sum1=sum2=0;
root[u]=merge(root[ch[u][0]],root[ch[u][1]]);
ans+=min(sum1,sum2);
}
else update(root[u],1,cnt_mp,val[u]);
}
int main(){
File();
scanf("%d",&n);
init(++tot);
for(it=mp.begin();it!=mp.end();++it)
mp[it->first]=++cnt_mp;
dfs(1);
solve(1);
printf("%lld\n",ans);
return 0;
}