【题目描述】
给你一个长度为n的数列A,请你计算里面有多少个四元组(a,b,c,d)满足:
a≠b≠c≠d,1≤a<b≤n,1≤c<d≤n,Aa<Ab,Ac>Ad
【输入格式】
输入文件第一行有一个整数N,第二行有N个整数A1,A2?An
【输出格式】
【输入1】
4
2 4 1 3
【输出1】
1
【输入2】
4
1 2 3 4
【输出2】
0
【数据约定】
15% n<=100
100%n<=50000
A在int范围里
嗯,第一眼看题就知道(以lkb的水平)这题不可做,果断骗分。
本来写了n^4的纯暴力的,然而担心会挂,于是想了点优化。可以通过鬼畜的预处理(先n^2任意找出每个数左边/右边比它大/小的数的个数),然后n^3枚举另外三个数,最后得了15分。(不过事后验证,其实不用优化的n^4的暴力也可以的15分。)
那么,这道题目的正解其实是树状数组+容斥原理。
首先,把输入数据离散化,范围变成1~n。
我们可以用四个树状数组维护Ai的左边、右边的数中大于、小于它的数的个数,也就相当于求出了一个可供a、b、c、d使用的数据库。
具体的方法就是:树状数组Ti表示数字i出现过的次数,比如要求比Ai小的数有多少个,就求sum(Ai-1)即可,求完之后记得把维护T[Ai]的值。至于求大于Ai的数,就用数的总数减去sum(Ai)(也就是有多少个小于等于Ai的数),即可。
求出这些数据之后,我们就可以把每个数的左边比它小的数的数量(其实也就是ab点对数)和左边比它大的数的数量(cd点对数)分别累加,然后把两个总和相乘,即可求出所有方案数。
但,还要注意到a!=b!=c!=d这个条件。在ab和cd分开考虑的情况下,可能会出现哪些相同的情况呢?不妨来打个表。
| a | b | c | D |
A | \ | 之前已排除 | smallLeft*smallRight | smallLeft*bigLeft |
B | \ | \ | bigRight*smallRight | bigRight*bigLeft |
C | \ | \ | \ | 之前已排除 |
D | \ | \ | \ | \ |
于是再根据容斥原理,把之前重复计算过的这些无效方案减去即可。
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define lowbit(x) (x & (-x))
const int maxn = 5e5 + 7;
struct Tnode {
int x, y;
//x是数
//y是下标
} a[maxn];
//离散化时用
int n;
int t[maxn]; //树状数组
long long smallLeft[maxn], bigLeft[maxn];
long long smallRight[maxn], bigRight[maxn];
//left和right表示方向
//small和big表示是储存比它大还是小的数的数量
bool cmp(Tnode i, Tnode j) {
return i.x < j.x;
}
//树状数组操作
int sum(int p) {
int res = 0;
while(p) {
res += t[p];
p -= lowbit(p);
}
return res;
}
void add(int p, int v) {
while(p <= n) {
t[p] += v;
p += lowbit(p);
}
}
int main() {
freopen("world.in", "r", stdin);
freopen("world.out", "w", stdout);
cin >> n;
for(int i = 1; i <= n; i++) {
cin >> a[i].x;
a[i].y = i;
}
//把a数组里面的值离散化
//变成b数组
sort(a + 1, a + n + 1, cmp);
int now = 0;
int b[maxn];
for(int i = 1; i <= n; i++) {
if(i == 1 || a[i].x != a[i - 1].x) ++now;
b[a[i].y] = now;
}
for(int i = 1; i <= n; i++) {
smallLeft[i] = sum(b[i] - 1); //左边比自己小的数
bigLeft[i] = (i - 1) - sum(b[i]); //左边比自己大的数
add(b[i], 1); //维护树状数组
}
//一点小技巧:个人感觉从右往左求有点烦,于是调转整个数组,
//照样从左求起,全部求完后再一次调转回来。
reverse(b + 1, b + n + 1);
memset(t, 0, sizeof t);
for(int i = 1; i <= n; i++) {
smallRight[i] = sum(b[i] - 1); //右边比自己小的数
bigRight[i] = (i - 1) - sum(b[i]); //右边比自己大的数
add(b[i], 1); //维护树状数组
}
reverse(smallRight + 1, smallRight + n + 1);
reverse(bigRight + 1, bigRight + n + 1);
long long ans1 = 0;
long long ans2 = 0;
long long sub = 0;
for(int i = 1; i <= n; i++) {
ans1 += smallLeft[i]; //累计ab点对数
ans2 += bigLeft[i]; //累计cd点对数
sub += bigRight[i] * smallRight[i] + //容斥原理减去无效部分
smallRight[i] * smallLeft[i] +
bigLeft[i] * bigRight[i] +
bigLeft[i] * smallLeft[i];
}
long long ans = ans1 * ans2 - sub;
cout << ans << endl;
return 0;
}