Problem
Solution
是一道很有意思的题目。
首先你可以写出一个最暴力的方程,设
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示前j个位置分i段的最小代价
f
[
i
]
[
j
]
=
min
(
f
[
i
−
1
]
[
k
]
+
(
s
u
m
[
j
]
−
s
u
m
[
k
]
)
 
m
o
d
 
p
)
f[i][j]=\min (f[i-1][k]+(sum[j]-sum[k])\bmod p)
f[i][j]=min(f[i−1][k]+(sum[j]−sum[k])modp)
考虑优化,注意到模数很小,不妨按照前缀和模p的余数进行分组
记 g [ i ] [ j ] = min r ≤ n o w , s u m [ r ] ≡ j ( f [ i ] [ r ] ) g[i][j]=\min_{r\leq now,sum[r]\equiv j}(f[i][r]) g[i][j]=minr≤now,sum[r]≡j(f[i][r])
则可以做到 O ( n k p ) : O(nkp): O(nkp): f [ i ] [ j ] = min r = 0 p ( g [ i − 1 ] [ r ] + ( s u m [ j ] − r )   m o d   p ) f[i][j]=\min_{r=0}^p (g[i-1][r]+(sum[j]-r)\bmod p) f[i][j]=minr=0p(g[i−1][r]+(sum[j]−r)modp),套路地分类讨论一下即可用树状数组维护然后得到一个 O ( n k log p ) O(nk\log p) O(nklogp)的优秀算法。
∀ r ≤ s u m [ j ] f [ i ] [ j ] = min ( g [ i − 1 ] [ r ] − r + s u m [ j ] ) \forall_{r\leq sum[j]}f[i][j]=\min(g[i-1][r]-r+sum[j]) ∀r≤sum[j]f[i][j]=min(g[i−1][r]−r+sum[j])
∀ r > s u m [ j ] f [ i ] [ j ] = min ( g [ i − 1 ] [ r ] − r + s u m [ j ] + p ) \forall_{r>sum[j]}f[i][j]=\min(g[i-1][r]-r+sum[j]+p) ∀r>sum[j]f[i][j]=min(g[i−1][r]−r+sum[j]+p)
然后接下来怎么办呢。。首先答案肯定是同余与sum[n]的,考虑一下我们什么时候能得到最优解,假设sum表示前缀和模p的值,那么如果存在一个以sum[n]结尾的不降子序列长度大于等于k,那么就可以使得代价最小为sum[n],因为我们只需要分别以这些sum结尾来分段即可。
但是这还是可能被卡。那么再来讨论一下,如果存在k-1个sum值相同,那么就可以构造分别以这些sum值和n结尾来分段,这样就只有第一段和第k段做贡献了,而由于有了第一个判定,答案不可能为sum[n],而两端的贡献又不可能大于2p,所以答案就是sum[n]+p啦。根据抽屉原理,我们可以知道当 n > p ∗ ( k − 2 ) n> p*(k-2) n>p∗(k−2)时必定存在这样的情况,那么就都可以被判定掉了。如此以来n就缩小到了可做的范围内。
Code
#include <algorithm>
#include <cstring>
#include <cstdio>
#define rg register
#define lowbit(x) ((x)&(-(x)))
using namespace std;
typedef long long ll;
const int maxn=210,INF=0x3f3f3f3f;
template <typename Tp> inline int getmin(Tp &x,Tp y){return y<x?x=y,1:0;}
template <typename Tp> inline int getmax(Tp &x,Tp y){return y>x?x=y,1:0;}
template <typename Tp> inline void read(Tp &x)
{
x=0;int f=0;char ch=getchar();
while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if(ch=='-') f=1,ch=getchar();
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
if(f) x=-x;
}
int n,k,mod,a[1000010],sum[1000010],rk[maxn],f[maxn][40010];
struct BIT{
int a[maxn];
BIT(){memset(a,0x3f,sizeof(a));}
void update(int p,int v){++p;for(;p<=mod;p+=lowbit(p)) getmin(a[p],v);}
int query(int p){int res=INF;++p;for(;p>0;p-=lowbit(p)) getmin(res,a[p]);return res;}
}L[maxn],R[maxn];
inline int pls(int x,int y){return x+y>=mod?x+y-mod:x+y;}
int lis()
{
memset(rk,0x3f,sizeof(rk));
for(rg int i=1;i<=n;i++) *upper_bound(rk+1,rk+k+1,sum[i])=sum[i];
return rk[k]<=sum[n];
}
void dp()
{
memset(f,0x3f,sizeof(f));
L[0].update(0,0);
for(rg int j=1;j<=n;j++)
for(rg int i=k;i;i--)
{
getmin(f[i][j],L[i-1].query(sum[j])+sum[j]);
getmin(f[i][j],R[i-1].query(mod-sum[j])+sum[j]+mod);
L[i].update(sum[j],f[i][j]-sum[j]);
R[i].update(mod-sum[j],f[i][j]-sum[j]);
}
printf("%d\n",f[k][n]);
}
int main()
{
read(n);read(k);read(mod);
if(k>n){printf("%d\n",INF);return 0;}
for(rg int i=1;i<=n;i++){read(a[i]);sum[i]=(sum[i-1]+a[i])%mod;}
if(lis()) printf("%d\n",sum[n]);
else if(n>k*mod) printf("%d\n",sum[n]+mod);
else dp();
return 0;
}