最大M子段和
题面
题目描述
N个整数组成的序列a[1],a[2],a[3],…,a[n],将这N个数划分为互不相交的M个子段,并且这M个子段的和是最大的。如果M >= N个数中正数的个数,那么输出所有正数的和。例如:-2 11 -4 13 -5 6 -2,分为2段,11 -4 13一段,6一段,和为26。
输入格式
第1行:2个数N和M,中间用空格分隔。N为整数的个数,M为划分为多少段。(2 <= N , M <= 5000)
第2 - N+1行:N个整数 (-10^9 <= a[i] <= 10^9)
输出格式
输出这个最大和
input
7 2
-2
11
-4
13
-5
6
-2
output
26
题解
动态规划:
f [ i ] [ j ] f[i][j] f[i][j]表示前 i i i个数划分 j j j个段,并且以第 i i i个数结尾得最大子段和。
不难得出, f [ i ] [ j ] = m a x ( f [ i − 1 ] [ j ] , f [ k ] [ j − 1 ] ) + a [ i ] ( 0 < k < i ) f[i][j]=max(f[i-1][j],f[k][j-1])+a[i]\ \ \ \ (0<k<i) f[i][j]=max(f[i−1][j],f[k][j−1])+a[i] (0<k<i).
其中的 f [ i − 1 ] [ j ] f[i-1][j] f[i−1][j]表示当前的第 i i i个数接在第 i − 1 i-1 i−1个数所在子段末尾, f [ k ] [ j − 1 ] f[k][j-1] f[k][j−1]表示独立存在于一个新的子段。
i , j , k i,j,k i,j,k三重循环枚举,复杂度 O ( n 3 ) O(n^3) O(n3),稳定超时。
我们需要在原来的基础上进行优化:
如果我们将 f [ i ] [ j ] f[i][j] f[i][j]在矩阵上联系,就会发现每一个点的转移是从“上一个点”和“在它左边且所有行数小于它的点”。具体是这样的:
比方说我们要转移的是黄色的点,“上一个点”就是蓝色的点,“在它左边且所有行数小于它的点”则是所有绿色的点。对于前者,我们可以直接取到。对于后者,我们可以在先前转移的时候就用数组维护和标记,这样在转移的时候直接调用即可。
具体是这样的,我们设
m
x
[
i
]
[
j
]
mx[i][j]
mx[i][j]表示
1
−
i
1-i
1−i行,第j列中决策
(
f
)
(f)
(f)的最大值。
即所有绿绿的点里面的最优值。
则不难得到: f [ i ] [ j ] = m a x ( m x [ i − 1 ] [ j − 1 ] , f [ i − 1 ] [ j ] ) + a [ i ] . f[i][j]=max(mx[i-1][j-1],f[i-1][j])+a[i]. f[i][j]=max(mx[i−1][j−1],f[i−1][j])+a[i].
因此得到代码( 0 m a r k s 0\ marks 0 marks):
#include<bits/stdc++.h>
using namespace std;
const int maxn=5010;
int n,m,sum=0;
int a[maxn];
int f[maxn][maxn];
int mx[maxn][maxn];
int main(void)
{
cin>>n>>m;
for (int i=1;i<=n;++i)
{
cin>>a[i];
sum+=a[i];
}
if (m>=n) { cout<<sum; exit(0); }
for (int i=1;i<=n;++i)
for (int j=1;j<=m;++j)
{
int x=f[i-1][j]+a[i];
int y=mx[i-1][j-1]+a[i];
f[i][j]=max(x,y);
mx[i][j]=max(mx[i-1][j],f[i][j]);
}
int ans=0;
for (int i=1;i<=n;++i) ans=max(ans,f[i][m]);
cout<<ans<<endl;
return 0;
}
基于题目的特殊性,这道题目卡空间。
通过观察,发现 f [ i ] [ j ] f[i][j] f[i][j]只和 f [ i − 1 ] [ j ] f[i-1][j] f[i−1][j]有关, m x [ i ] [ j ] mx[i][j] mx[i][j]只与 m x [ i − 1 ] [ j ( − 1 ) ] mx[i-1][j(-1)] mx[i−1][j(−1)]有关,要使用滚动数字。
滚动数组是让空间循环利用,因为每一层只需要用到上一层,根据取模的特殊性和周期性,就可以实现节约空间的效果。(这里 m o d 2 mod\ 2 mod 2)
注意: x x x xxx xxx & 1 \ 1 1 = ( x x x ) % 2 = (xxx) \%2 =(xxx)%2
C O D E CODE CODE
#include<bits/stdc++.h>
using namespace std;
#define LL long long
const LL maxn=5010;
LL n,m,sum=0;
LL a[maxn];
LL f[2][maxn];
LL mx[2][maxn];
int main(void)
{
cin>>n>>m;
for (LL i=1;i<=n;++i)
{
cin>>a[i];
sum+=a[i];
}
LL ans=0;
if (m>=n) { cout<<sum; exit(0); }
for (LL i=1;i<=n;++i)
for (LL j=1;j<=m;++j)
{
LL x=f[i-1&1][j]+a[i];
LL y=mx[i-1&1][j-1]+a[i];
f[i&1][j]=max(x,y);
mx[i&1][j]=max(mx[i-1&1][j],f[i&1][j]);
if (j==m) ans=max(ans,f[i&1][j]);
}
cout<<ans<<endl;
return 0;
}
欢迎点赞和指正不足哦!!