LEQ
https://atcoder.jp/contests/abc221/tasks/abc221_e
题目大意:给出一个长度为 n n n 的整数序列 a a a ,问有多少个子序列(无需连续) a 1 , a 2 . . . a k a_1,a_2...a_k a1,a2...ak ,满足 a 1 ≤ a k a_1≤a_k a1≤ak 。
那么我们发现可以将问题看成对于每个正序对去算一个值最后求和。那么我们这里正序对维护值的时候就不再是简单的加一了。假设
a
i
a_i
ai 作为正序对小的那一方,与
a
i
a_i
ai 可构成正序对的元素为
{
a
j
1
,
a
j
2
.
.
.
a
j
k
}
\{a_{j1},a_{j2}...a_{jk}\}
{aj1,aj2...ajk} ,那么其满足条件的子序列个数即为:
2
j
1
−
i
−
1
+
2
j
2
−
i
−
1
.
.
.
+
2
j
k
−
i
−
1
=
∑
g
=
0
k
2
j
g
2
i
+
1
2^{j_1-i-1}+2^{j_2-i-1}...+2^{j_k-i-1}=\frac{\sum\limits_{g=0}^{k}2^{j_g}}{2^{i+1}}
2j1−i−1+2j2−i−1...+2jk−i−1=2i+1g=0∑k2jg
所以我们发现我们只要对于求正序对去变个形,单点更新的是 2 下 标 2^{下标} 2下标 ,其它一致,最后再多除一个 2 i + 1 2^{i+1} 2i+1 即算出对于 a i a_i ai 作为 a 1 a_1 a1 的所有满足条件的答案,那么我们自然就可以遍历 a i a_i ai 算出总答案。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 3e5 + 10;
const ll mod = 998244353;
int n, a[N];
ll er[N], inv[N], ans;
ll sum[N << 5];
int lson[N << 5], rson[N << 5], root, tot;
ll ksm(ll base, ll n) {
ll ans = 1;
while(n) {
if (n & 1ll) ans = ans * base % mod;
base = base * base % mod;
n >>= 1ll;
}
return ans;
}
void init(int n) {
er[0] = 1;
for (int i = 1; i <= n; ++i) {
er[i] = 2ll * er[i - 1] % mod;
}
inv[n] = ksm(er[n], mod - 2);
for (int i = n - 1; i >= 0; --i) {
inv[i] = inv[i + 1] * 2ll % mod;
}
}
void update(int &rt, int l, int r, int pos, int idx) {
if (!rt) rt = ++tot;
if (l == r) {
sum[rt] = (sum[rt] + er[idx]) % mod;
return ;
}
int mid = (1ll * l + r) >> 1;
if (pos <= mid) update(lson[rt], l, mid, pos, idx);
else update(rson[rt], mid + 1, r, pos, idx);
sum[rt] = (sum[lson[rt]] + sum[rson[rt]]) % mod;
}
ll query(int rt, int l, int r, int L, int R) {
if (L <= l && r <= R) return sum[rt];
int mid = (1ll * l + r) >> 1;
ll ans = 0;
if (mid >= L) ans = query(lson[rt], l, mid, L, R);
if (mid < R) ans = (ans + query(rson[rt], mid + 1, r, L, R)) % mod;
return ans;
}
int main() {
#ifndef ONLINE_JUDGE
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
init(300001);
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]);
}
for (int i = n; i > 0; --i) {
ll cur = query(root, 1, 1e9, a[i], 1e9) * inv[i + 1] % mod;
ans = (ans + cur) % mod;
update(root, 1, 1e9, a[i], i);
}
printf("%lld", ans);
}