题面
题意
给出一个长度为n的序列,将它划分成一个或多个子串,使每个串中只出现过一次的数字个数小于等于m,求方案个数。
做法
记dp[i]表示前i个的答案,
f
(
i
)
f(i)
f(i)表示
i
i
i到
n
n
n(
n
n
n表示此时的右端点)的只出现一次的数字个数,则
d
p
[
n
]
=
∑
f
(
i
)
<
=
m
d
p
[
i
−
1
]
dp[n]=\sum_{f(i)<=m}dp[i-1]
dp[n]=∑f(i)<=mdp[i−1]
因此只要维护符合条件的dp值之和即可,发现在此时维护的序列后面加上一个数后,会使一个区间的f值+1,一个区间的f值-1,因此可以用分块进行维护,维护每个块区间加了多少,此时的答案是多少。
代码
#include<bits/stdc++.h>
#define ll long long
#define S 410
#define N 100100
#define M 998244353
using namespace std;
ll n,m,s,num[N],ad[S],pos[N],last[N],dp[N],ans[S],cnt[N];
int sum[S][N];
inline ll get(ll u){return u/s+1;}
template<class T,class TT>inline void Add(T &u,TT v){u=((ll)u+v)%M;}
inline void chg(ll u,ll v)
{
ll t=get(u);
Add(sum[t][cnt[u]],M-dp[u]);
if(cnt[u]+ad[t]<=m) Add(ans[t],M-dp[u]);
cnt[u]+=v;
Add(sum[t][cnt[u]],dp[u]);
if(cnt[u]+ad[t]<=m) Add(ans[t],dp[u]);
}
inline void add(ll u,ll v,ll w)
{
if(u>v) return;
ll i,p=get(u),q=get(v);
if(p==q)
{
for(i=u;i<=v;i++)
{
chg(i,w);
}
return;
}
for(i=u;get(i)==p;i++) chg(i,w);
for(i=v;get(i)==q;i--) chg(i,w);
for(i=p+1;i<q;i++)
{
if(w>0) if(m-ad[i]>=0) Add(ans[i],M-sum[i][m-ad[i]]);
ad[i]+=w;
if(w<0) if(m-ad[i]>=0) Add(ans[i],sum[i][m-ad[i]]);
}
}
int main()
{
ll i,j,t;
cin>>n>>m;
s=sqrt(n);
for(i=1;i<=n;i++)
{
scanf("%lld",&num[i]);
}
dp[0]=sum[1][0]=ans[1]=1;
for(i=1;i<=n;i++)
{
last[i]=pos[num[i]];
add(last[last[i]],last[i]-1,-1);
add(last[i],i-1,1);
for(t=get(i),j=i-1;j>=0&&get(j)==t;j--) if(cnt[j]+ad[t]<=m) Add(dp[i],dp[j]);
for(j=t-1;j>=1;j--) Add(dp[i],ans[j]);
Add(sum[t][0],dp[i]);
if(ad[t]<=m) Add(ans[t],dp[i]);
pos[num[i]]=i;
}
cout<<dp[n];
}