【题目链接】
http://acm.hust.edu.cn/vjudge/problem/viewProblem.action?id=19461
【解题报告】
题目大意:给你一个N元序列,每个元有一个分数,每次可以从左边拿或者从右边拿任意多。A,B轮流拿,拿完为止,求问A最多比B多拿多少。其中N<=100.
因为两个人都足够聪明,因此当他们任意一人面临一个(i,j)的局面时,因为序列总和sum[j]-sum[i-1]是不变的,所以一定会选择拿完之后另一个人面临局面(k,l)得分最少。这是类似于博弈论的思想。
所以我们这样设计dp状态:
dp[i][j]表示面临(i,j)局面的人最多可以拿多少分
那么它可以转移到
S={ dp(i+1,j), dp(i+2,j) … dp(j,j) , dp(i,j-1) , dp(i,j-2) … , dp(i,i) , 0 }
一口气拿完时,另一个人面临的局面就是0.
所以状态转移方程就是dp[i][j]=sum-min;
DFS区间更新即可。
需要注意的是这样做的时间复杂度是O(n^3),仍然有很大的优化空间。
优化:
对于min{ dp(i+1,j), dp(i+2,j) ... dp(j,j) , dp(i,j-1) , dp(i,j-2) ... , dp(i,i) }
设f(i+1,j)=min{ dp(i+1,j), dp(i+2,j) ... dp(j,j) }
设g(i.j-1)=min{ dp(i,j-1) , dp(i,j-2) ... , dp(i,i) }
那么dp[i][j]=sum-min{ f(i+1,j),g(i,j-1),0 }
其中f和g均可以通过递推得出:
f[i][j]=min{ f[i+1][j],dp[i][j] }
g[i][j]=min{ g[i][j-1],dp[i][j] }
所以时间复杂度被降到了O(n^2)
【参考代码】
1.O(N^3)
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<stack>
#include<queue>
#include<vector>
#include<map>
using namespace std;
const int INF=2e9+1e8;
int n;
int dp[100+10][100+10];
int a[100+10],sum[100+10];
int DFS( int i ,int j )
{
if( dp[i][j]!=INF )return dp[i][j];
int minn=INF;
for( int ii=i+1; ii<=j; ii++ )minn=min( minn, DFS( ii,j ) );
for( int jj=j-1; jj>=i; jj-- )minn=min( minn, DFS( i,jj ) );
minn=min( minn,0 ); //i~j全部拿完的状况
return dp[i][j]=sum[j]-sum[i-1]-minn;
}
int main()
{
while( ~scanf("%d",&n) && n )
{
for( int i=1; i<=n; i++ )scanf("%d",&a[i]);
for( int i=1; i<=n; i++ )
for( int j=1; j<=n; j++ )
{
if( i==j )dp[i][j]=a[i];
else dp[i][j]=INF;
}
memset(sum,0,sizeof(sum));
for( int i=1; i<=n; i++ )sum[i]=sum[i-1]+a[i];
// dp[i][j]=sum(i,j)-min( dp[i+1][j], dp[i+2][j],...dp[j][j], dp[i][j-1],...,dp[i][i] );
printf( "%d\n",2*DFS(1,n)-sum[n] );
}
return 0;
}
2.O(N^2)
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<stack>
#include<queue>
#include<vector>
#include<map>
using namespace std;
const int INF=2e9+1e8;
int n;
int dp[100+10][100+10],f[100+10][100+10],g[100+10][100+10];
int a[100+10],sum[100+10];
int main()
{
while( ~scanf("%d",&n) && n )
{
for( int i=1; i<=n; i++ )scanf("%d",&a[i]);
memset(sum,0,sizeof sum);
for( int i=1; i<=n; i++ )sum[i]=sum[i-1]+a[i];
memset(f,0,sizeof f);
memset(g,0,sizeof(g));
for( int i=n; i>=1; i-- )
for( int j=i; j<=n; j++ )
{
int temp=min( f[i+1][j],g[i][j-1] );
temp=min(0,temp);
dp[i][j]=sum[j]-sum[i-1]-temp;
f[i][j]=min( f[i+1][j],dp[i][j] );
g[i][j]=min( g[i][j-1],dp[i][j] );
}
printf( "%d\n",2*dp[1][n]-sum[n] );
}
return 0;
}