洛谷p1637题目传送门:
上代码
很容易想到 三元上升子序列的个数=每个数前比它小的数的个数*每个数后比它大的数的个数 之和
我们只需要维护 每个数前比它小的数的个数 和 每个数后比它大的数的个数 并记录下来就可以了,
那么怎样用线段树去维护呢?以一组数据 An={ 1,4,5,3 } 为例:
首先我们开个桶sum[n]记录每个数的出现次数,那么一开始这个桶就是 0 0 0 0 0
插入A1=1,此时的桶为 1 0 0 0 0 ,查询A1前有没有比它小的数: 1 0 0 0 0,好吧一个都没有,记smaller[1]=0;
插入A2=4,此时的桶为 1 0 0 1 0 ,查询A2前有没有比它小的数:(1 0 0)1 0,发现此时1比4小,smaller[2]=1;
插入A3=5,此时的桶为 1 0 0 1 1 ,查询A3前有没有比它小的数:(1 0 0 1)1,发现此时1,4比5小,smaller[3]=2;
插入A4=3,此时的桶为 1 0 1 1 1 ,查询A4前有没有比它小的数:(1 0)1 1 1,发现此时1比3小,smaller[4]=1;
按上面的操作步骤,即每次更新,将sum[A[n]]+1,然后sum[1]~sum[A[n]-1]的和即是small[n]的值,这也是用线段树求逆序对的方法
那么用同样的方法求出所有数后比它大的数,得到bigger[1~4],最后ans=ans+smaller[i]*bigger[i],1<=i<=4
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 3e4 + 10;
//nlongn 实现
int n;
ll a[N], b[N];
struct node {
ll l;
ll r;
ll w;
} tr[N << 2];
ll smaller[N];
ll bigger[N];
void pushup(ll u) {
tr[u].w = tr[u << 1].w + tr[u << 1 | 1].w;
}
void build(ll u, ll l, ll r) {
tr[u] = {l, r, 0};
if (l == r) return;
ll mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
return;
}
void add(ll u, ll x) {
if (tr[u].l == tr[u].r && tr[u].l == x) {
tr[u].w++;
return;
}
ll mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) add(u << 1, x);
else add(u << 1 | 1, x);
pushup(u);
}
void clear(ll u, ll x) {
if (tr[u].l == tr[u].r && tr[u].l == x) {
tr[u].w = 0;
return;
}
ll mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) clear(u << 1, x);
else clear(u << 1 | 1, x);
pushup(u);
}
ll query(ll u, ll l, ll r) {
if (tr[u].l >= l && tr[u].r <= r) {
return tr[u].w;
}
ll s = 0;
ll mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) s += query(u << 1, l, r);
if (r > mid) s += query(u << 1 | 1, l, r);
return s;
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%lld", &a[i]), b[i] = a[i];
sort(b + 1, b + 1 + n);
int last = unique(b + 1, b + 1 + n) - b - 1;
for (int i = 1; i <= n; i++) {
a[i] = lower_bound(b + 1, b + 1 + last, a[i]) - b;
}
build(1, 1, last);
for (int i = 1; i <= n; i++) {
add(1, a[i]);
ll te = 0;
te = query(1, 1, a[i] - 1);
smaller[i] = te;
}
for (int i = 1; i <= last; i++) clear(1, i);
for (int i = n; i >= 1; i--) {
add(1, a[i]);
ll te = 0;
te = query(1, a[i] + 1, last);
bigger[i] = te;
}
ll res = 0;
// for (int i = 1; i <= n; i++) cout << smaller[i] << ' ';
// cout << "\n";
// for (int i = 1; i <= n; i++) cout << bigger[i] << ' ';
// cout << "\n";
for (int i = 1; i <= n; i++) res += smaller[i] * bigger[i];
printf("%lld\n", res);
}