一道简单DP优化调了好久qwq
首先分析题目,发现每次从一边取贝壳是完全没用的,此题本质就是将区间分成数个区间,使区间价值和最大。
可以发现一个性质,那就是最优解的每个区间的两端点一定相同且为选取的\(s_0\)。因为如果区间两端点的值不同,那么完全可以将多余的值分为另一个区间使价值和更大。
所以可以写出简单的dp式:
\(dp[i] = max(dp[j-1] + s[i] * (sum[i] - sum[j]+1)^2) \quad (s[j] == s[i])\)
其中\(sum[i]\)为1…i中\(s[i]\)的个数,可以简单的\(O(1)\)维护,所以总复杂度\(O(n^2)\)
观察单调性,发现对于决策a,b\((a<b)\)如果在k处a比b优,那么在k之后a也一定比b优,而k可以通过二分\(O(log_n)\)求出
所以可以使用单调栈维护最优决策,对于每次决策,如果栈顶不优了就弹出栈顶。同时,为了维护栈的单调性,每次入栈z时,如果z与栈顶元素的分界点\(k_1\)比栈顶与栈顶的下一个元素的分界点\(k_2\)靠后,那么便可以弹出栈顶元素。
/**************************************************************
Problem: 4709
User: liuxinyuan
Language: C++
Result: Accepted
Time:328 ms
Memory:3144 kb
****************************************************************/
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <stack>
#include <vector>
#define gc getchar
#define il inline
#define re register
#define LL long long
#define mid(l, r) (((l) + (r)) >> 1)
#define sqr(x) (1ll * (x) * 1ll * (x))
#define m_p(x, y) make_pair(x, y)
using namespace std;
template <typename T>
void rd(T &s)
{
s = 0;
bool p = 0;
char ch;
while (ch = gc(), p |= ch == '-', ch < '0' || ch > '9')
;
while (s = s * 10 + ch - '0', ch = gc(), ch >= '0' && ch <= '9')
;
s *= (p ? -1 : 1);
}
template <typename T, typename... Args>
void rd(T &s, Args &... args)
{
rd(s);
rd(args...);
}
const int MAXM = 10005;
const int MAXN = 100050;
vector<int> sta[MAXM];
LL dp[MAXN];
int s[MAXN];
int cnt[MAXM], sum[MAXN];
int n;
il LL cal(int x, int y)
{
return dp[x - 1] + s[x] * 1ll * y * 1ll * y;
}
int lower(int a, int b)
{
int v = s[a];
int l = 1, r = cnt[v], ans = cnt[v] + 1;
while (l <= r)
{
int m = mid(l, r);
if (cal(a, m - sum[a] + 1) >= cal(b, m - sum[b] + 1))
ans = m,
r = m - 1;
else
l = m + 1;
}
return ans;
}
int main()
{
int v;
rd(n);
for (int i = 1; i <= n; ++i)
{
rd(s[i]);
sum[i] = ++cnt[s[i]];
}
for (int i = 1; i <= n; ++i)
{
v = s[i];
while (sta[v].size() >= 2 && lower(sta[v][sta[v].size() - 2], sta[v][sta[v].size() - 1]) < lower(sta[v][sta[v].size() - 1], i))
sta[v].pop_back();
sta[v].push_back(i);
while (sta[v].size() >= 2 && lower(sta[v][sta[v].size() - 2], sta[v][sta[v].size() - 1]) <= sum[i])
sta[v].pop_back();
dp[i] = cal(sta[v][sta[v].size() - 1], sum[i] - sum[sta[v][sta[v].size() - 1]] + 1);
// cout << i << " " << sta[v][sta[v].size() - 1] << endl;
}
// for (int i = 1; i <= n; ++i)
printf("%lld", dp[n]);
return 0;
}