石子合并
题目
Time Limit:1000MS Memory Limit:131072KB
- Description
现在有n堆石子,第i堆有ai个石子。现在要把这些石子合并成一堆,每次只能合并相邻两个,每次合并的代价是两堆石子的总石子数。求合并所有石子的最小代价。
- Input
第一行包含一个整数T(T<=50),表示数据组数。
每组数据第一行包含一个整数n(2<=n<=100),表示石子的堆数。
第二行包含n个正整数ai(ai<=100),表示每堆石子的石子数。
- Output
每组数据仅一行,表示最小合并代价。
- Sample Input
2
4
1 2 3 4
5
3 5 2 1 4
- Sample Output
19
33
分析
本题题意很简单,属于经典DP问题石子合并问题中的单线性合并。采用区间DP思想很容易得出,注意dp数组自身的初始化。
耗时:148ms
代码如下:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;
int a[105];
int dp[105][105];
int sum[105];
const int M = 100000000;
int main()
{
int T, n;
scanf("%d", &T);
while(T--)
{
scanf("%d", &n);
sum[0] = 0;
for(int i=1;i<=n;i++)
{
scanf("%d", a+i);
sum[i] = sum[i-1] + a[i];
}
for(int l=2;l<=n;l++)
for(int i=1;i<=n-l+1;i++)
{
int j = i+l-1;
dp[i][j] = M;
for(int k=i;k<j;k++)
dp[i][j] = min(dp[i][j], dp[i][k] + dp[k+1][j] + sum[j]- sum[i-1]);
}
printf("%d\n", dp[1][n]);
}
return 0;
}
显然三层循环使得时间复杂度为O(N^3),根据四边形不等式的优化,可将时间复杂度降为O(N^2)关键在于s[i][j-1]<=s[i][j]<=s[i+1][j]的运用,此处链接:定理证明
耗时:16ms
代码如下:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;
int a[105];
int dp[105][105];
int s[105][105];
int sum[105];
const int M = 100000000;
int main()
{
int T, n;
scanf("%d", &T);
while(T--)
{
scanf("%d", &n);
sum[0] = 0;
for(int i=1;i<=n;i++)
{
scanf("%d", a+i);
sum[i] = sum[i-1] + a[i];
dp[i][i] = 0;
s[i][i] = i;
}
for(int l=2;l<=n;l++)
for(int i=1;i<=n-l+1;i++)
{
int j = i+l-1;
dp[i][j] = M;
for(int k=s[i][j-1];k<=s[i+1][j];k++)
if(dp[i][j] > dp[i][k] + dp[k+1][j] + sum[j]- sum[i-1])
{
dp[i][j] = dp[i][k] + dp[k+1][j] + sum[j]- sum[i-1];
s[i][j] = k;
}
}
printf("%d\n", dp[1][n]);
}
return 0;
}
经过进一步探究发现,本题的时间复杂度可以再进一步优化为O(nlogn),其中涉及到了GarsiaWachs算法(证明详见TAOCP,麻麻问我为什么跪着敲代码……)
基本思想是通过树的最优性得到一个节点间深度的约束,之后证明操作一次之后的解可以和原来的解一一对应,并保证节点移动之后他所在的深度不会改变。
耗时:0ms
代码如下:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;
int a[105];
const int M = 100000000;
int main()
{
int T, n;
int ans, temp;
scanf("%d", &T);
while(T--)
{
scanf("%d", &n);
for(int i=1;i<=n;i++)
scanf("%d", a+i);
a[0] = a[n+1] = M;
ans = 0;
while(n >= 2)
{
int i, j;
for(i = 2; i <= n; i++)
if(a[i-1] < a[i+1])
break;
temp = a[i-1] + a[i];
ans += temp;
for(j = i-1; j && temp > a[j-1]; j--)
a[j] = a[j-1];
a[j] = temp;
for(j = i; j <= n; j++)
a[j] = a[j+1];
n--;
}
printf("%d\n",ans);
}
return 0;
}