题意:给定一个长度为 n n n的括号序列 S S S,一个长度为 n n n的序列 a i a_i ai。定义一段括号序列的价值:若括号序列 [ l , r ] [l,r] [l,r]为合法序列,其价值为 a r − a l a_r-a_l ar−al;若不合法,价值为 0 0 0.可将 S S S划分为若干非空子段,定义美丽度为每个子段价值之和,求最大美丽度。 n ≤ 3000000 n\le3000000 n≤3000000.
不难想到dp,设
f
i
f_i
fi表示原括号序列的前i位所能得到的最大美丽度。转移方程为
f
i
=
m
a
x
(
f
i
−
1
,
f
j
+
a
i
−
a
j
+
1
)
f_i = max(f_{i-1}, f_j+a_i-a_{j+1})
fi=max(fi−1,fj+ai−aj+1)其中
j
∈
[
1
,
i
−
2
]
j\in[1, i-2]
j∈[1,i−2],并且
[
j
+
1
,
i
]
[j+1,i]
[j+1,i]是一个合法的括号序列。
这样可以得到
O
(
n
2
)
O(n^2)
O(n2)的算法。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define For(i, a, b) for (int i = (a); i <= (b); i++)
#define FOR(i, a, b) for (int i = (a); i >= (b); i--)
const int N = 3e6 + 5;
int n, a[N];
ll dp[N];
char s[N];
int main() {
scanf("%d", &n);
scanf("%s", s + 1);
For (i, 1, n) scanf("%d", &a[i]);
memset(dp, 0xcf, sizeof(dp));
dp[0] = 0;
For (i, 1, n) {
dp[i] = dp[i - 1];
if (s[i] == '(') continue;
else {
int res = 1;
FOR(j, i - 1, 1) {
if (s[j] == '(') res -= 1;
else res += 1;
if (res < 0) break;
if (!res) dp[i] = max(dp[i], dp[j - 1] + (ll)a[i] - a[j]);
}
}
}
printf("%lld\n", dp[n]);
return 0;
}
考虑优化。
注意
a
i
a_i
ai为定值,所以找
m
a
x
{
f
j
+
a
i
−
a
j
+
1
}
max\{ f_j+a_i-a_{j+1}\}
max{fj+ai−aj+1}就是找
m
a
x
{
f
j
−
a
j
+
1
}
max\{f_j-a_{j+1}\}
max{fj−aj+1}
记
c
i
=
m
a
x
{
f
j
−
a
j
+
1
}
c_i=max\{f_j-a_{j+1}\}
ci=max{fj−aj+1},
[
j
+
1
,
i
]
[j+1,i]
[j+1,i]为合法括号序列。
那么
f
i
=
m
a
x
{
c
i
+
a
i
,
f
i
−
1
}
f_i=max\{c_i+a_i,f_{i-1}\}
fi=max{ci+ai,fi−1},只要快速计算
c
i
c_i
ci即可。
求
c
i
c_i
ci的方法比较精彩,对于一个位置在
i
i
i的右括号,用
l
s
t
i
lst_i
lsti记录能与其匹配的最近左括号位置。这步操作可以用一个类似栈的方法做到。则
c
i
=
m
a
x
{
c
l
s
t
i
−
1
,
f
l
s
t
i
−
1
−
a
l
s
t
i
}
c_i=max\{c_{lst_i-1, f_{lst_i-1} -a_{lst_i}}\}
ci=max{clsti−1,flsti−1−alsti}
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define For(i, a, b) for (int i = (a); i <= (b); i++)
#define FOR(i, a, b) for (int i = (a); i >= (b); i--)
const int N = 3e6 + 5;
int n, a[N], lst[N], stk[N];
ll f[N], c[N];
char s[N];
int main() {
scanf("%d", &n);
scanf("%s", s + 1);
For (i, 1, n) scanf("%d", &a[i]);
memset(f, 0xcf, sizeof(f));
f[0] = 0;
memset(c, 0xcf, sizeof(c));
int top = 0;
For (i, 1, n) {
if (s[i] == ')' && top) lst[i] = stk[top--];
else if (s[i] == '(') stk[++top] = i;
}
For (i, 1, n) {
f[i] = f[i - 1];
if (s[i] == ')' && lst[i]) {
c[i] = max(c[lst[i] - 1], f[lst[i] - 1] - a[lst[i]]);
f[i] = max(c[i] + a[i], f[i]);
}
}
printf("%lld\n", f[n]);
return 0;
}