题目需要不断通过相邻小朋友的交换,使数列变成非递减数列。可知对于某个位置的数,其左侧
比其大的数和它必有一次交换,其右侧比其小的数和它必有一次交换。则它交换的次数是左侧比
它大的数的个数和右侧比它小的数的个数的和。加入和为x,可知其愤怒值为x*(1+x)/2。所以可以
用树状数组求解。
当然题目的意思,很明显就是冒泡排序、归并排序这些交换排序也可以满足的。但是冒泡排序的
时间复杂度大,不能满足题目要求的时间限制,归并排序基于二分,可成功用户求解。
例如:6 7 8 9 0 1 2 3 进行归并。
发现0<6.则6以及6右侧的7 8 9 都比0大,需要与0交换,则0的交换次数+4.
然后1<6.同样6以及6右侧的7 8 9都比1大,需要与1交换。
则我们会发现left[i] > right[j] 则j对应的元素需要加leftLen-i+1。
而对于left[i]<=right[j] 则右侧序列0 - j-1 都比left[i]。则left[i] 的交换次数需要加上0 ~j-1的元素。
树状数组AC代码:
#include <iostream>
#include <stdio.h>
#include <string.h>
using namespace std;
typedef long long ll;
const int maxn = 100005;
const int maxm = 1000005;
int n,h[maxn],cnt[maxm],sum[maxm];
int lowbit(int x) {
return x&(-x);
}
void deal() {
memset(sum,0,sizeof(sum));
memset(cnt,0,sizeof(cnt));
for(int i = 0; i < n; i++) {
//h[i] ~ maxm 都大于等于h[i]。
for(int k = h[i]; k < maxm; k=k+lowbit(k)) {
cnt[k] += 1;
}
int j = 0;
for(int k = h[i]; k > 0; k=k-lowbit(k)) {
j += cnt[k];
}
sum[i] = i-j+1; //算元素i左侧有多少大于等于元素i的。
}
memset(cnt,0,sizeof(cnt));
for(int i = n-1; i >= 0; i--) {
for(int k = h[i]; k < maxm; k=k+lowbit(k)) {
cnt[k] += 1;
}
for(int k = h[i]-1; k > 0; k=k-lowbit(k)) {
sum[i] += cnt[k];
}
}
}
int main() {
scanf("%d",&n);
for(int i = 0; i < n; i++) {
scanf("%d",&h[i]);
h[i]++;
//h[i]可能是0,导致lowbit(0) 一直是0,所以用树状数组一定要注意这点。
}
deal();
ll ans = 0;
for(int i = 0; i < n; i++) {
ans += ((ll)sum[i]+1)*(ll)sum[i]/2;
}
printf("%lld\n",ans);
return 0;
}
归并排序解法:
#include <iostream>
#include <stdio.h>
#include <string.h>
using namespace std;
typedef long long ll;
const int maxn = 100005;
int arr[maxn],cnt1[maxn],tmp[maxn],cnt2[maxn];
void mergePass(int L1,int R1,int L2,int R2) {
int i,j,k;
i = L1; //i指向左区间左端点
j = L2; //j指向右区间右端点
k = 0;
while(i<=R1 && j<=R2) {
if(arr[i]>arr[j]) {
cnt2[k] = cnt1[j] + (R1-i+1);
tmp[k++] = arr[j++];
}
else {
cnt2[k] = cnt1[i] + (j-L2);
tmp[k++] = arr[i++];
}
}
while(i<=R1) {
cnt2[k] = cnt1[i] + (R2-L2+1);
tmp[k++] = arr[i++];
}
while(j<=R2) {
cnt2[k] = cnt1[j];
tmp[k++] = arr[j++];
}
for(int i = L1; i <= R2; i++) {
arr[i] = tmp[i-L1];
cnt1[i] = cnt2[i-L1];
}
}
void mergeSort(int left,int right) {
if(left<right) {
int mid = (left+right)/2;
mergeSort(left,mid);
mergeSort(mid+1,right);
mergePass(left,mid,mid+1,right);
}
}
int main() {
int n;
while(~scanf("%d",&n)) {
for(int i = 0; i < n; i++) {
scanf("%d",&arr[i]);
}
memset(cnt1,0,sizeof(cnt1));
memset(cnt2,0,sizeof(cnt2));
mergeSort(0,n-1);
ll ans = 0;
for(int i = 0; i < n; i++) {
ans += ((ll)cnt1[i]+1)*(ll)cnt1[i]/2;
}
printf("%lld\n",ans);
}
return 0;
}