题目
题解
一打眼看过去就是环形石子合并的板子题。
环形石子合并的思想就不再详述了,区间dp + 化曲为直思想
千万不要想当然地理解为就是环形石子合并稍微变了个型,代码不变。如果这么想就大错特错了。
方法一: 对应代码1
样例中的四个石子
如果直接按照石子进行合并来实现的话(见代码1):
为四个石子编号
通过手动实现几个过程,总结出转移方程:
dp[i][j]
表示从第i
个石子合并到第j
个石子的最大花费。
dp[1][4] = max{
dp[1][1] + dp[2][4] + stone[1].head*stone[2].head*stone[4].tail,
dp[1][2] + dp[3][4] + stone[1].head*stone[3].head*stone[4].tail,
dp[1][3] + dp[3][4] + stone[1].head*stone[4].head*stone[4].tail
}
好了,可以推出
dp[i][j] = max(dp[i][j], dp[i][k]+ dp[k+1][j] + stone[i].head*stone[k+1].head*stone[j].tail)
k的范围:i <= k < j
之后的代码与石子合并从dp从得到答案的方法一样。
方法二: 对应代码2
不想用结构体,就想在环形石子合并代码的基础上改改!
下面就介绍另一种方法。
为每个数编号了,不再为石子编号。
我们要类比上面的转移方程,得到新的转义方程:
dp[i][j]
表示从第i
个石子合并到第j
个石子的最大花费。
dp[1][4] = max{
dp[1][2] + dp[2][4] + num[1]*num[2]*num[4],
dp[1][3] + dp[3][4] + num[1]*num[3]*num[4]
}
可以推出
dp[i][j] = max(dp[i][j], dp[i][k]+ dp[k][j] + num[i]*num[k+1]*num[j])
k的范围:i < k < j
同时应当注意几点:
- k的范围
- 区间长度要从3开始遍历,因为当区间长度为2时,在
i
和j
的遍历中是不会被更新的。 - 千万要注意的就是从dp中找最大值的代码,这次我们要找的是区间长度为
n+1
的dp中的最大值。之所以要以长度为n+1
进行尺取,是因为上述转移方程计算出的并非环形合并的最大花费,用样例来说:四个石子,转移方程的dp[1][4]
算出的是从第一个数合并到第四个数的最大花费,但其实我们还要将第四个数与第一个数构建一个联系,这是上述转移方程无法体现的。为了解决这一问题,我们只需以n+1
进行尺取,这样就相当于将首尾沟通起来了,算出的也就是题目要求的花费了。
代码1
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 210;
ll ans, dp[N][N], a[N];
int n;
struct stone {
ll h, t;
} s[N];
int main()
{
cin>>n;
for(int i = 1;i <= n;i ++) cin>>a[i];
for(int i = 1;i < n;i ++) {
s[i].h = s[i+n].h = a[i];
s[i].t = s[i+n].t = a[i+1];
}
s[n].h = s[n+n].h = a[n];
s[n].t = s[n+n].t = a[1];
for(int m = 1;m <= 2*n;m ++) {
for(int i = 1;i+m-1 <= 2*n;i ++) {
int j = i+m-1;
for(int k = i;k < j;k ++) {
dp[i][j] = max(dp[i][j], dp[i][k] + dp[k+1][j] + s[i].h*s[k+1].h*s[j].t);
}
}
}
for(int i = 1;i <= n;i ++) ans = max(ans, dp[i][i+n-1]);
printf("%lld", ans);
return 0;
}
代码2
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 210;
ll ans, dp[N][N], a[N];
int n;
int main()
{
cin>>n;
for(int i = 1;i <= n;i ++) cin>>a[i], a[i+n] = a[i];
for(int m = 3;m <= 2*n;m ++) // !
for(int i = 1;i+m-1 <= 2*n;i ++) {
int j = i+m-1;
for(int k = i+1;k < j;k ++) // !
dp[i][j] = max(dp[i][j], dp[i][k] + dp[k][j] + a[i]*a[k]*a[j]);
}
for(int i = 1;i <= n;i ++) ans = max(ans, dp[i][i+n]); // !
printf("%lld", ans);
return 0;
}