题意:
给定一个长度为 n n n 的数组 a a a,求有多少划分方式使得每段区间中出现次数为 1 1 1 的数字个数不大于 k k k,答案模 998244353 998244353 998244353。 ( n , k , a i ≤ 1 0 5 ) (n, k, a_i \leq 10^5) (n,k,ai≤105)
链接:
https://codeforces.com/contest/1129/problem/D
解题思路:
很容易得到
d
p
dp
dp 状态转移方程,记
f
[
i
]
f[i]
f[i] 为前缀
i
i
i 的划分答案,则
f
[
0
]
=
1
,
f
[
i
]
=
∑
j
=
1
i
f
[
j
−
1
]
∗
[
c
n
t
(
j
,
i
)
≤
k
]
f[0] = 1, f[i] = \sum\limits_{j = 1}^{i} f[j - 1] * [cnt(j, i) \leq k]
f[0]=1,f[i]=j=1∑if[j−1]∗[cnt(j,i)≤k]
其中,
c
n
t
(
j
,
i
)
cnt(j, i)
cnt(j,i) 表示区间
[
j
,
i
]
[j, i]
[j,i] 频率为
1
1
1 的数字个数,可以动态维护
i
i
i 作为右端点时每个
j
j
j 到
i
i
i 的
c
n
t
cnt
cnt 值。
i − 1 → i i - 1 \rightarrow i i−1→i,记 l a s t [ a [ i ] ] 、 l l a s t [ a [ i ] ] last[a[i]]、llast[a[i]] last[a[i]]、llast[a[i]] 为 a [ i ] a[i] a[i] 上次、上上次出现的位置,则相应更新操作为区间 [ l a s t [ a [ i ] ] + 1 , i ] [~last[a[i]] + 1, i~] [ last[a[i]]+1,i ] 加 1 1 1,区间 [ l l a s t [ a [ i ] ] + 1 , l a s t [ a [ i ] ] ] [~llast[a[i]] + 1, last[a[i]]~] [ llast[a[i]]+1,last[a[i]] ] 减 1 1 1,那么 c n t ( j , i ) cnt(j, i) cnt(j,i) 的值对应一个单点值。
然后分块维护 c n t ( j , i ) ≤ k cnt(j, i) \leq k cnt(j,i)≤k 对应的 f [ j − 1 ] f[j - 1] f[j−1] 的和即可。
参考代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define sz(a) ((int)a.size())
#define pb push_back
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
const int maxn = 1e5 + 5;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;
const int maxm = 333;
int a[maxn], last[maxn], llast[maxn], f[maxn];
int id[maxn], li[maxm], ri[maxm], g[maxn], add[maxm], sum[maxm][maxn], ans[maxm];
int n, k;
void build(){
int len = sqrt(n);
for(int i = 1; i <= n; ++i) id[i] = (i - 1) / len + 1;
for(int i = 1; i <= id[n]; ++i) li[i] = (i - 1) * len + 1, ri[i] = i * len;
ri[id[n]] = n;
}
inline void MOD(int &x){
x %= mod, x += x < 0 ? mod : 0;
}
void pushDown(int p){
for(int i = li[p]; i <= ri[p]; ++i) sum[p][g[i]] = 0;
ans[p] = 0;
for(int i = li[p]; i <= ri[p]; ++i){
g[i] += add[p];
sum[p][g[i]] += f[i - 1], MOD(sum[p][g[i]]);
ans[p] += g[i] <= k ? f[i - 1] : 0, MOD(ans[p]);
}
add[p] = 0;
}
void update2(int l, int r, int val){
int p = id[l];
pushDown(p);
for(int i = l; i <= r; ++i){
sum[p][g[i]] -= f[i - 1], MOD(sum[p][g[i]]);
ans[p] -= g[i] <= k ? f[i - 1] : 0, MOD(ans[p]);
g[i] += val;
sum[p][g[i]] += f[i - 1], MOD(sum[p][g[i]]);
ans[p] += g[i] <= k ? f[i - 1] : 0, MOD(ans[p]);
}
}
void update(int l, int r, int val){
int p1 = id[l], p2 = id[r];
if(p1 == p2){
update2(l, r, val);
return;
}
for(int i = p1 + 1; i < p2; ++i){
if(val == 1 && k - add[i] >= 0) ans[i] -= sum[i][k - add[i]], MOD(ans[i]);
else if(val == -1 && k + 1 - add[i] >= 0) ans[i] += sum[i][k + 1 - add[i]], MOD(ans[i]);
add[i] += val;
}
update2(l, ri[p1], val);
update2(li[p2], r, val);
}
int query(){
int ret = 0;
for(int i = 1; i <= id[n]; ++i) ret += ans[i], MOD(ret);
return ret;
}
int main(){
ios::sync_with_stdio(0); cin.tie(0);
cin >> n >> k;
for(int i = 1; i <= n; ++i) cin >> a[i];
build();
f[0] = 1;
for(int i = 1; i <= n; ++i){
update(last[a[i]] + 1, i, 1);
if(last[a[i]]) update(llast[a[i]] + 1, last[a[i]], -1);
llast[a[i]] = last[a[i]], last[a[i]] = i;
f[i] = query();
}
cout << f[n] << endl;
return 0;
}