这是一道比较经典的区间 DP。比较重要的思路是把未来的花费放到现在计算:每走一段距离,就把这段距离乘以此时尚未访问的点数,一并计入总花费。
一开始写了一个空间复杂度为 O(n^2) 的记忆化搜索,结果超出了内存限制;最后改成按区间长度递推的循环,并用滚动数组只保留相邻两层,空间降到了 O(n)。
/* Forgive me Not */
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Array capacity: points live at indices 1..n, plus one extra slot that
// main() uses for the start position.
const int maxn = 3001;
// Sentinel cost for states that cannot be reached.
const int inf = 0x3f3f3f3f;
int n, s;            // number of points; start coordinate (later its index)
int pos[maxn];       // point coordinates, 1-indexed
int dp[maxn][2][2];  // dp[l][length parity][side], rolling over interval length
// Reads the next integer from stdin, skipping every non-digit character.
// A '-' immediately before the first digit makes the value negative; any
// other non-digit character resets the sign to positive.
// Returns 0 if end of input is reached before any digit (previously this
// case looped forever, because EOF is below '0').
inline int iread() {
    int f = 1, x = 0;
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    for(; ch < '0' || ch > '9'; ch = getchar()) {
        if(ch == EOF) return 0;
        f = ch == '-' ? -1 : 1;
    }
    for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
    return f * x;
}
// Adds the cost of walking `dist` units while `waiting` points are still
// unvisited to the sub-state cost `from`. An unreachable sub-state (exactly
// inf) stays inf, instead of having costs piled onto the sentinel, which
// could overflow int (undefined behavior) and corrupt reachable states.
inline int extend(int from, int waiting, int dist) {
    return from == inf ? inf : from + waiting * dist;
}

// Reads n, the start coordinate s and n point coordinates, then prints the
// minimum total cost of visiting every point starting from s.
int main() {
    n = iread(); s = iread();
    for(int i = 1; i <= n; i++) pos[i] = iread();
    // Treat the start as one more (already visited) point, then replace s by
    // its index in the sorted coordinate array.
    pos[++n] = s;
    sort(pos + 1, pos + 1 + n);
    s = lower_bound(pos + 1, pos + 1 + n, s) - pos;
    // dp[l][len & 1][side]: minimum cost once exactly pos[l..l+len-1] have
    // been visited; side 1 = standing at the left end, side 0 = at the right
    // end. Only two length layers are kept (rolling array).
    for(int i = 1; i <= n; i++) dp[i][1][0] = dp[i][1][1] = inf;
    dp[s][1][0] = dp[s][1][1] = 0;
    for(int len = 2; len <= n; len++) {
        const int cur = len & 1, prev = cur ^ 1;
        // Points still unvisited while extending from len-1 to len points;
        // each of them is charged for every unit walked ("pay future cost now").
        const int waiting = n - len + 1;
        for(int l = 1; l + len - 1 <= n; l++) {
            const int r = l + len - 1;
            // Arrive at pos[l] from [l+1, r], coming from its left or right end.
            dp[l][cur][1] = min(extend(dp[l + 1][prev][1], waiting, pos[l + 1] - pos[l]),
                                extend(dp[l + 1][prev][0], waiting, pos[r] - pos[l]));
            // Arrive at pos[r] from [l, r-1], coming from its right or left end.
            dp[l][cur][0] = min(extend(dp[l][prev][0], waiting, pos[r] - pos[r - 1]),
                                extend(dp[l][prev][1], waiting, pos[r] - pos[l]));
        }
    }
    printf("%d\n", min(dp[1][n & 1][0], dp[1][n & 1][1]));
    return 0;
}
记忆化搜索版本(空间 O(n^2),即上文提到的被卡内存的写法):
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Array capacity: points live at indices 1..n, plus one extra slot that
// main() uses for the start position.
const int maxn = 3001;
// Sentinel cost for states that cannot be reached.
const int inf = 0x3f3f3f3f;
int n, s;                // number of points; start coordinate (later its index)
int pos[maxn];           // point coordinates, 1-indexed
int dp[maxn][maxn][2];   // memo for dfs(l, r, left); -1 means "not computed"
// Reads the next integer from stdin, skipping every non-digit character.
// A '-' immediately before the first digit makes the value negative; any
// other non-digit character resets the sign to positive.
// Returns 0 if end of input is reached before any digit (previously this
// case looped forever, because EOF is below '0').
inline int iread() {
    int f = 1, x = 0;
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    for(; ch < '0' || ch > '9'; ch = getchar()) {
        if(ch == EOF) return 0;
        f = ch == '-' ? -1 : 1;
    }
    for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
    return f * x;
}
// Absolute value of an int (-INT_MIN overflows, so INT_MIN is not supported).
inline int iabs(int x) {
    if(x >= 0) return x;
    return -x;
}
// Adds the cost of walking `dist` units while `waiting` points are still
// unvisited to the sub-state cost `from`. An unreachable sub-state (exactly
// inf) stays inf, instead of having costs piled onto the sentinel, which
// could overflow int (undefined behavior).
inline int extend(int from, int waiting, int dist) {
    return from == inf ? inf : from + waiting * dist;
}
// Minimum total cost once exactly pos[l..r] have been visited, standing at
// pos[l] when `left` is true and at pos[r] otherwise. Every step charges
// distance * (number of unvisited points), i.e. future cost is paid now.
// Results are memoized in dp[l][r][left] (-1 = not computed yet).
inline int dfs(int l, int r, bool left) {
    // The visited set always contains the start index s; any other interval
    // is unreachable. Returning inf here also prunes those subtrees.
    if(s < l || s > r) return inf;
    if(l == r) return 0;  // only the start itself has been visited
    if(~dp[l][r][left]) return dp[l][r][left];
    // Points still unvisited while growing from r-l visited points to r-l+1.
    const int waiting = n - r + l;
    int res;
    if(left) res = min(extend(dfs(l + 1, r, 1), waiting, iabs(pos[l + 1] - pos[l])),
                       extend(dfs(l + 1, r, 0), waiting, iabs(pos[r] - pos[l])));
    else res = min(extend(dfs(l, r - 1, 0), waiting, iabs(pos[r] - pos[r - 1])),
                   extend(dfs(l, r - 1, 1), waiting, iabs(pos[r] - pos[l])));
    return dp[l][r][left] = res;
}
// Reads n, the start coordinate s and n point coordinates, then prints the
// minimum total cost of visiting every point, via memoized dfs over [1, n].
int main() {
    n = iread();
    s = iread();
    for(int k = 1; k <= n; ++k) pos[k] = iread();
    // The start becomes an extra point; after sorting, s holds its index.
    ++n;
    pos[n] = s;
    sort(pos + 1, pos + n + 1);
    s = static_cast<int>(lower_bound(pos + 1, pos + n + 1, s) - pos);
    // -1 marks every memo entry as "not computed yet".
    memset(dp, -1, sizeof(dp));
    const int endRight = dfs(1, n, 0);
    const int endLeft = dfs(1, n, 1);
    printf("%d\n", min(endRight, endLeft));
    return 0;
}