代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
// Array capacity. Positions are 1-indexed, and main() also reads flag[b[i] + 1],
// so the small slack above 1e5 keeps that neighbour lookup in range.
const int N = 100000 + 10;

int n, m;  // problem sizes read from input (only n is used by main)

// Per-position state, 1-indexed:
//   a[i]    - input value; at a DSU root it holds the sum of that root's set
//   b[i]    - input ordering of positions; main re-activates b[n], ..., b[2]
//   fa[i]   - DSU parent pointer
//   flag[i] - set to 1 once position i has been re-activated
int flag[N], a[N], b[N], fa[N];

int ans;   // accumulated total that main prints
int maxn;  // largest set sum observed so far
// 并查集
int find(int x)
{
return x == fa[x] ? x : fa[x] = find(fa[x]);
}
void merge(int x, int y)
{
x = find(x);
y = find(y);
if (x != y)
{
fa[x] = y;
a[y] += a[x];
}
}
signed main()
{
cin >> n;
for (int i = 1; i <= n; i++)
{
cin >> a[i];
fa[i] = i;
}
for (int i = 1; i <= n; i++)
{
cin >> b[i];
}
for (int i = n; i >= 2; i--)
{
flag[b[i]] = 1;
if (flag[b[i] - 1] && !flag[b[i] + 1])
merge(b[i], b[i] - 1);
if (!flag[b[i] - 1] && flag[b[i] + 1])
merge(b[i], b[i] + 1);
if (flag[b[i] - 1] && flag[b[i] + 1])
merge(b[i], b[i] - 1), merge(b[i], b[i] + 1);
maxn = max(maxn, a[find(b[i])]);
ans += maxn;
}
cout << ans;
return 0;
}