Problem
lx.lanqiao.cn/problem.page?gpid=T414
Reference
blog.csdn.net/u014800748/article/details/45750737
www.cnblogs.com/jiu0821/p/4493497.html
Meaning
一条直线上 n 堆石子,每次可以合并相邻两堆并形成新的一堆,花费为原来的两堆石子的石子数的和。
求将 n 堆合并成一堆的最小总花费。
Analysis
区间DP,但朴素区间DP会超时,要用四边形不等式优化。
(相关介绍、证明见参考博客)
dp[i][j]:合并区间 [ i , j ] 的所有石子堆的最小花费
s[i][j]:让 dp[i][j] 取得最小值的那个区间断点的位置
状态转移:dp[i][j] = min { dp[i][k] + dp[k+1][j] + sum ( i , j ) | s[i][j-1] <= k <= s[i+1][j] }
其中,sum ( i , j ) 表示区间 [ i , j ] 的所有石子堆的总石子数;
k 就是断点的位置,朴素的做法时,它的范围是 [ i+1 , j-1 ],但优化后变成 [ s(i,j-1) , s(i+1,j) ]。
因为能证明:s(i,j-1) <= s(i,j) <= s(i+1,j),而 s(i,j) 就是当前 k 最终要取的值,所以 k 只需要扫一遍这个区间。
Source code
#include <cstdio>
#include <cstring>
using namespace std;
const int N = 1000;
int a[N+1]; // 石子数
int sum[N+1]; // 石子数的前缀和
int dp[N+1][N+1]; // 合并区间[i,j]的所有石子堆的最小花费
int s[N+1][N+1]; // 让 dp[i][j] 取得最小值的那个区间断点的位置
int main()
{
int n;
scanf("%d", &n);
sum[0] = 0;
for(int i=1; i<=n; ++i)
{
scanf("%d", a+i);
sum[i] = sum[i-1] + a[i];
}
memset(dp, 7, sizeof dp);
for(int i=1; i<=n; ++i)
{
dp[i][i] = 0;
s[i][i] = i;
}
for(int w=1; w<n; ++w)
for(int i=1, j; i+w<=n; ++i)
{
j = i + w;
for(int k=s[i][j-1]; k<=s[i+1][j]; ++k)
if(dp[i][j] > dp[i][k] + dp[k+1][j] + sum[j] - sum[i-1])
{
dp[i][j] = dp[i][k] + dp[k+1][j] + sum[j] - sum[i-1];
s[i][j] = k;
}
}
printf("%d\n", dp[1][n]);
return 0;
}