题意:
有两行蘑菇,每一个格速率不一样,在采摘蘑菇时每个格子只能走一遍,并且走完所有格子,求最后能够采摘到蘑菇的最大值。
解题思路:
根据题意很容易能知道只有三种走法:
1. 2.
第三种就是这两种方式的结合,现蛇形走位,再环形走位
很明显,这个题目的解题思路就是DP
1.我们可以先预处理处理来蛇形走位,判断出来是从那一列开始进行蛇形走位:
核心代码如下:
void init_b() {
ll s = 0;
for (int i = 1; i<= n; i += 2) {
s += a1[i] * (i * 2 - 2);
s += a2[i] * (i * 2 - 1);
b[i] = s;//走U形,后面就是从上往下开始环形走位
s += a2[i + 1] * (2 * i);
s += a1[i + 1] * (2 * i + 1);
b[i + 1] = s;//走S形,后面就是从下往上开始环形走位
}
}
2.预处理走环形的长度:
由于每走一步,就会花费一分钟时间,所以说右边格子里的蘑菇会比前面格子多生长1分钟,根据这种情况,我们就假设环形开始第一步是从一开始的,我们就预处理一个后缀和。
void init_s() {
ll s = 0;
for (int i = n; i >= 1; i--) {
s += a1[i] + a2[i];
S[n - i + 1] = s;
}
}
3.再次处理环形走位
(1)假设说我们从上往下开始绕,下面的格子是最后才走到的,所以说下面所等待的时间也是需要去加上的,假设当前格子到后面还有 l 个格子,那么到达它下面格子需要走的时间就是 l * 2 + 1 分钟。
void init_c1() {
ll s = 0;
for (int i = n, l = 0; i >= 1; i--, l++) {
s += S[l];
s += a2[n - l] * (l * 2 + 1); //需要等待的时间内多生长的蘑菇
c1[i] = s;
}
}
(2)假设从下往上走时,情况和上面差不多,只是上面等待时间变长:
void init_c1() {
ll s = 0;
for (int i = n, l = 0; i >= 1; i--, l++) {
s += S[l];
s += a1[n - l] * (l * 2 + 1); //上面需要等待时间
c2[i] = s;
}
}
4.环形走位和蛇形走位收集到的蘑菇总数进行相加求一个最大值 ans:
在我们前面处理环形走位的时候,前面是先经过蛇形走位的,但是我们在计算环形走位的时候是从1开始的,所以说,每个格子的蘑菇生长实际上还应加上前面蛇形走位的时间 2 * i;
for (int i = 0; i <= n; i++) {
if (i % 2 == 1) {
ans = max(ans, b[i] + c2[i + 1] + S[n - i] * (i * 2)); // 从下往上绕
} else {
ans = max(ans, b[i] + c1[i + 1] + S[n - i] * (i * 2)); //从上往下绕
}
}
总代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
constexpr int N = 3e5 + 10, M = 5;
ll a1[N], a2[N], b[N], S[N], c1[N], c2[N];
int n;
ll ans;
void init_b() {
ll s = 0;
for (int i = 1; i<= n; i += 2) {
s += a1[i] * (i * 2 - 2);
s += a2[i] * (i * 2 - 1);
b[i] = s;
s += a2[i + 1] * (2 * i);
s += a1[i + 1] * (2 * i + 1);
b[i + 1] = s;
}
}
void init_s() {
ll s = 0;
for (int i = n; i >= 1; i--) {
s += a1[i] + a2[i];
S[n - i + 1] = s;
}
}
void init_c1() {
ll s = 0;
for (int i = n, l = 0; i >= 1; i--, l++) {
s += S[l];
s += a2[n - l] * (l * 2 + 1);
c1[i] = s;
}
}
void init_c2() {
ll s = 0;
for (int i = n, l = 0; i >= 1; i--, l++) {
s += S[l];
s += a1[n - l] * (l * 2 + 1);
c2[i] = s;
}
}
void solve() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a1[i];
}
for (int i = 1; i <= n; i++) {
cin >> a2[i];
}
init_b();
init_s();
init_c1();
init_c2();
for (int i = 0; i <= n; i++) {
if (i % 2 == 1) {
ans = max(ans, b[i] + c2[i + 1] + S[n - i] * (i * 2));
} else {
ans = max(ans, b[i] + c1[i + 1] + S[n - i] * (i * 2));
}
}
cout << ans << "\n";
}
int main() {
int t = 1;
//cin >> t;
while (t--) {
solve();
}
}