总时间限制: 10000ms 单个测试点时间限制: 1000ms
内存限制: 65536kB
描述
给定N个数的序列a1,a2,...aN,定义一个数对(ai, aj)为“重要逆序对”的充要条件为 i < j 且 ai > 2aj。求给定序列中“重要逆序对”的个数。
输入
第一行为序列中数字的个数N(1 ≤ N ≤ 200000)。
第二行为序列a1, a2 ... aN(0 ≤a ≤ 10000000),由空格分开。输出
输出一个整数,为给序列中“重要逆序对”的个数。
样例输入
10 0 9 8 7 6 5 4 3 2 1样例输出
16提示
如果使用printf输出long long类型,请用%lld
数据范围
对于40%的数据,有N ≤ 1000。
本题的典型解法是采用归并排序。相似题目为“求逆序对数”。
在"求逆序对数"中,归并排序的核心代码如下
void merge(int *s, int *temp, int startIndex, int endIndex, int mid)
{
int i = startIndex, j = mid + 1, k = startIndex;
int pointer = startIndex;
while(i <= mid && j <= endIndex)
{
if(s[i] > s[j])
{
temp[k] = s[j];
sum += mid - i + 1;//sum表示逆序对数
j ++;
}
else
{
temp[k] = s[i];
i ++;
}
k ++;
}
while(i <= mid)
{
temp[k ++] = s[i ++];
}
while(j <= endIndex)
{
temp[k ++] = s[j ++];
}
for(int i = startIndex; i <= endIndex; i ++)
{
s[i] = temp[i];
}
}
但是本题所求为“重要逆序对数”,因此我在上机时的做法如下:
void merge(int *s, int *temp, int startIndex, int endIndex, int mid)
{
int i = startIndex, j = mid + 1, k = startIndex;
int pointer = startIndex;
while(i <= mid && j <= endIndex)
{
if(s[i] > s[j])
{
temp[k] = s[j];
if(s[i] > 2 * s[j])
sum += mid - i + 1;
else
{
for(int t = i + 1; i <= mid; i ++ )
{
if(s[t] > 2 * s[j];
{
sum += mid - t + 1;
break;
}
}
}
j ++;
}
else
{
temp[k] = s[i];
i ++;
}
k ++;
}
while(i <= mid)
{
temp[k ++] = s[i ++];
}
while(j <= endIndex)
{
temp[k ++] = s[j ++];
}
for(int i = startIndex; i <= endIndex; i ++)
{
s[i] = temp[i];
}
}
但是很不幸的是,TLE了。回来后,室友提醒,在里面加了一重循环后,时间复杂度就不是O(nlogn)了,在最坏情况下,应该是O(n^2 logn)。上机时信心满满地以为加了break应该不会TLE,但是...
室友的做法是建立一个虚拟指针pointer,从startIndex开始,一直指到mid结束。时间复杂度并不会增加。
譬如归并(0, 6, 7, 8, 9)和(1, 2, 3, 4, 5)时,s[i]指向左边部分,s[j]指向右边部分。当pointer指向6,s[j] = 1 时,2 * 1<6;因此6右边的7,8,9都是重要逆序;pointer指向6,s[j] = 2时,2 * 2 < 6;因此,6右边的7,8,9都是重要逆序;pointer指向6,s[j] = 3时,2 * 3 = 6,因此需要加pointer右移一位,此时pointer指向的是7,2 * 3 < 7,因此 7右边的8,9都是重要逆序......以此类推,对于每一个s[j],都将pointer移动至s[pointer] > 2 * s[j]的位置i处。
显然pointer最多只需要移动 n / 2次(从startIndex到mid处),因此不会增加复杂度。
完整代码如下:
#include <iostream>
using namespace std;
long long sum = 0;
void merge(int *s, int *temp, int startIndex, int endIndex, int mid)
{
int i = startIndex, j = mid + 1, k = startIndex;
int pointer = startIndex;
while(i <= mid && j <= endIndex)
{
if(s[i] > s[j])
{
temp[k] = s[j];
while(s[pointer] <= 2 * s[j] && pointer <= mid)
{
pointer ++;
}
if(pointer != mid + 1)
{
sum += mid - pointer + 1;
}
j ++;
}
else
{
temp[k] = s[i];
i ++;
}
k ++;
}
while(i <= mid)
{
temp[k ++] = s[i ++];
}
while(j <= endIndex)
{
temp[k ++] = s[j ++];
}
for(int i = startIndex; i <= endIndex; i ++)
{
s[i] = temp[i];
}
}
void mergeSort(int *s, int *temp, int startIndex, int endIndex)
{
int mid = (startIndex + endIndex) / 2;
if(startIndex < endIndex)
{
mergeSort(s, temp, startIndex, mid);
mergeSort(s, temp, mid + 1, endIndex);
merge(s, temp, startIndex, endIndex, mid);
}
}
int s[200005] = {};
int temp[400010] = {};
int main(){
int n;
cin >> n;
for(int i = 0; i < n; i ++)
{
cin >> s[i];
}
mergeSort(s, temp, 0, n - 1);
cout << sum << endl;
return 0;
}
如果代码有问题,还望指出。