给定两个数组x,y,每次从数组x顺序取一个数a,从数组y的头或者尾巴取一个数b,求所有a*b的最大和。
举例说明:
x= {1,2,3},y={1,2,1};
依次从x取出的数为 1 ,2,3。
依次从y 取出的数为 1 ,1,2。
输出结果为 1*1+2*1+3*2 = 9;
刚开始想到的就是暴力枚举所有情况,然后再剪枝,结果稳稳地超时了。
代码如下
package com.company.nothing;
import java.util.*;
public class Test {
private int ans = 0;
private Map<String, Integer> map;
public int getMaxValue(int[] nums, int[] values) {
// write code here
map = new HashMap<>();
find(0, nums, 0, nums.length - 1, values, 0);
return ans;
}
private void find(int sum, int[] nums, int l, int r, int[] values, int x) {
if (l > r) {
ans = Math.max(ans, sum);
return;
}
String key = l + "--" + r;
Integer old = map.get(key);
if (old != null && old.intValue() >= sum) {
return;
} else {
map.put(key, sum);
}
find(sum + nums[l] * values[x], nums, l + 1, r, values, x + 1);
find(sum + nums[r] * values[x], nums, l, r - 1, values, x + 1);
}
}
后来经人指点采用区间dp。
package com.company.nothing;
import java.util.*;
public class Test {
public int getMaxValue2(int[] y, int[] x) {
// 动态转移方程 l和r代表当前构造方案的区间
// dp[l][r] = max( dp[l+1][r]+x[a]*y[l], dp[l][r-1]+x[a]*y[r]); a=l+n-r
int[][] dp = new int[x.length][x.length];
for (int i = 0; i < x.length; i++) {
for (int j = 0; j < dp[i].length; j++) {
dp[i][j] = 0;
}
}
//每个数只剩一个时,x数组一定是取到最后一个了
for (int i = 0; i < x.length; i++) {
dp[i][i] = y[i] * x[x.length - 1];
}
// len从1开始向外拓展,一开始没想到这层,只用了两个循环,猝。
// 向外扩展时需保证当前dp[l][r]已经是最优解,这样更新才有意义,所以要限制每次更新的长度。
for (int len = 1; len <= x.length; len++) {
for (int l = 0; l < x.length; l++) {
//这边可以优化成 r = l+len-1,因为重复计算了,这样时间优化为o(n*n)
for (int r = l; r-l+1 <= len && r < x.length; r++) {
if (l - 1 >= 0) {
dp[l - 1][r] = Math.max(dp[l - 1][r],
dp[l][r] + y[l - 1] * x[x.length - 1 - (r + 1 - l)]);
}
if (r + 1 < x.length) {
dp[l][r + 1] = Math.max(dp[l][r + 1],
dp[l][r] + y[r + 1] * x[x.length - 1 - (r + 1 - l)]);
}
}
}
}
return dp[0][x.length - 1];
}
}
本质上就是先算出区间长度为1的结果,之后合并成2,再合并成3,以此类推。