题目描述
在含有
n
{n}
n 个整数的序列
a
1
,
a
2
,
.
.
.
,
a
n
{a_1,a_2,...,a_n}
a1,a2,...,an 中,当且仅当
i
<
j
<
k
{i<j<k}
i<j<k 且
a
i
<
a
j
<
a
k
{a_i<a_j<a_k}
ai<aj<ak ,才称作thair
,求一个序列中thair
的个数。
输入格式
开始一行一个正整数 n {n} n。
以后一行 n {n} n 个整数 a 1 , a 2 , . . . , a n {a_1,a_2,...,a_n} a1,a2,...,an
输出格式
一行一个整数表示thair
的个数。
输入输出样例
输入 #1
4
2 1 3 4
输出 #1
2
输入 #2
5
1 2 2 3 4
输出 #2
7
思路
- 作为一个三元的上升序列,我们很容易想到子序列枚举中间的元素。
L
[
i
]
{L[i]}
L[i] 为
a
[
i
]
{a[i]}
a[i] 左边小于
a
[
i
]
{a[i]}
a[i] 的元素个数。
R
[
i
]
{R[i]}
R[i] 为
a
[
i
]
{a[i]}
a[i] 右边大于
a
[
i
]
{a[i]}
a[i] 的元素个数。
- 乘法原理 :以 a [ i ] {a[i]} a[i] 为中间元素的合法序列个数为 L [ i ] ∗ R [ i ] {L[i] * R[i]} L[i]∗R[i]。
- 我们可以拿权值线段树来维护一段值域内数的个数,借此来计算出 L {L} L 和 R {R} R 数组。
代码
// #include <bits/stdc++.h>
#include <iostream>
#include <algorithm>
using namespace std;
#define ls u<<1
#define rs u<<1|1
#define int long long
const int N = 100010;
struct node {
int l, r;
int sum;
} tr[N << 2];
int n, a[N];
int L[N], R[N];
void pushup(int u) {
tr[u].sum = tr[ls].sum + tr[rs].sum;
}
void build(int u, int l, int r) {
tr[u] = {l, r};
if(l == r) { tr[u].sum = 0; return ;}
int mid = l + r >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
pushup(u);
}
void modify(int u, int x) {
if(tr[u].l >= x && tr[u].r <= x) {
tr[u].sum += 1;
return ;
}
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid) modify(ls, x);
else modify(rs, x);
pushup(u);
}
int query(int u, int l, int r) {
if(tr[u].l >= l && tr[u].r <= r) { return tr[u].sum; }
int mid = tr[u].l + tr[u].r >> 1;
int res = 0;
if(l <= mid) res = query(ls, l, r);
if(r > mid) res += query(rs, l, r);
return res;
}
signed main() {
cin >> n;
int mx = 0;
for (int i = 1; i <= n; i++) cin >> a[i], mx = max(mx, a[i]);
build(1, 1, mx);
for (int i = 1; i <= n; i++) {
modify(1, a[i]);
L[i] = query(1, 1, a[i] - 1);
}
build(1, 1, mx);
for (int i = n; i >= 1; i--) {
modify(1, a[i]);
R[i] = query(1, a[i] + 1, mx);
}
int res = 0;
for (int i = 1; i <= n; i++) res += L[i] * R[i];
cout << res << endl;
return 0;
}
M元上升子序列
DP + 树状数组
思路
-
f [ i ] [ j ] {f[i][j]} f[i][j]:以 a [ j ] {a[j]} a[j] 为结尾的长度为 i {i} i 的上升子序列的个数。
- 状态转移方程: f [ i ] [ j ] = ∑ k < j , a [ k ] < a [ j ] f [ i − 1 ] [ k ] {f[i][j] = \sum_{k<j,a[k]<a[j]}f[i-1][k]} f[i][j]=∑k<j,a[k]<a[j]f[i−1][k]
- 即:所有长度少一位,且结尾在当前之前,且结尾的大小更小,那么就能转移过来。
-
答案: ∑ i = 1 n f [ M ] [ i ] {\sum\limits_{i=1}^{n}f[M][i]} i=1∑nf[M][i],( M {M} M是长度, i {i} i 是结尾位置)
显然,这样复杂度是 O ( N 2 M ) {O(N^2M)} O(N2M)
因为 k < j , a [ k ] < a [ j ] {k<j,a[k]<a[j]} k<j,a[k]<a[j],所以离散化 a {a} a 数组,用树状数组维护。
- 外层枚举长度,建立树状数组。
- 内层枚举结尾位置,在 a [ k ] {a[k]} a[k]下标( b [ j ] {b[j]} b[j])处插入 f [ i − 1 ] [ j ] {f[i-1][j]} f[i−1][j]。
- 内层循环到 j {j} j, f [ i ] [ j ] + = s u m ( b [ j ] − 1 ) {f[i][j]+=sum(b[j]-1)} f[i][j]+=sum(b[j]−1),然后 a d d ( b [ j ] , f [ i − 1 ] [ j ] ) {add(b[j], f[i-1][j])} add(b[j],f[i−1][j]) 。
- j {j} j从小到大循环保证了 k < j {k<j} k<j,查询 f [ i − 1 ] [ b [ j ] − 1 ] {f[i-1][b[j]-1]} f[i−1][b[j]−1]的前缀和保证了 a [ k ] < a [ j ] {a[k]<a[j]} a[k]<a[j]。
const int M = 3;
int n;
int a[N], tr[N], b[N];
int f[M + 1][N];
int sum(int x) {
int res = 0;
for (; x; x -= x & -x) res += tr[x];
return res;
}
void add(int x, int c) {
for (; x <= n; x += x & -x) tr[x] += c;
}
void solve() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> a[i], b[i] = a[i], f[1][i] = 1;
sort(a + 1, a + 1 + n);
int m = unique(a + 1, a + 1 + n) - a - 1;
for (int i = 1; i <= n; i++)
b[i] = lower_bound(a + 1, a + 1 + m, b[i]) - a;
for (int i = 2; i <= M; i++) {
memset(tr, 0, sizeof tr);
for (int j = 1; j <= n; j++) {
f[i][j] = sum(b[j] - 1);
add(b[j], f[i - 1][j]);
}
}
int res = 0;
for (int i = 1; i <= n; i++) res += f[M][i];
cout << res << endl;
}
【变种】上升四元组
- 枚举 l l l,统计 j j j 为中间的三元组数量( v [ j ] v[j] v[j]: i < j < k , a [ i ] < a [ k ] < a [ j ] i<j<k,a[i]<a[k]<a[j] i<j<k,a[i]<a[k]<a[j])
class Solution {
public:
long long countQuadruplets(vector<int>& nums) {
int n = nums.size();
vector<int> v(n, 0);
long long res = 0;
for (int i = 0; i < n; i++) {
for (int j = 0; j < i; j++)
if (nums[j] < nums[i]) res += v[j];
// count:j之前比a[i]小的数个数
// v[j]+=count:j作为中间值,i作为k,统计上三元组个数。
for (int j = 0, count = 0; j < i; j++) {
if (nums[j] > nums[i]) v[j] += count;
count += nums[j] < nums[i];
}
}
return res;
}
};