题目:
注:这道题洛谷3648有SPJ,要求输出方案。BZOJ3675数据组数较多但不要求输出方案。
分析:
这可能是我第三次重学斜率优化了……好菜啊
这道题首先一看就是个DP。稍微推一推类似下面这种式子就会发现事实上结果和切的顺序无关
a ( b + c ) + b c = a b + c ( a + b ) = a b + a c + b c a(b+c)+bc=ab+c(a+b)=ab+ac+bc a(b+c)+bc=ab+c(a+b)=ab+ac+bc
那么就可以用 f [ i ] [ j ] f[i][j] f[i][j]表示切了 j j j次,最右一次在 i i i后面切的最大值。用 s u m [ i ] sum[i] sum[i]表示原序列前 i i i个数之和,那么就有了这个DP方程(假设在 i i i后面是第一次切割):
f [ i ] [ j ] = m a x ( f [ k ] [ j − 1 ] + s u m [ k ] × ( s u m [ i ] − s u m [ k ] ) ) ( 0 < k < i ) f[i][j]=max(f[k][j-1]+sum[k]\times (sum[i]-sum[k]))(0<k<i) f[i][j]=max(f[k][j−1]+sum[k]×(sum[i]−sum[k]))(0<k<i)
其中 j j j这一维可以滚动存储。转移的时候顺便记一下从什么地方转移过来就行。这个时间复杂度是 O ( n 2 k ) O(n^2k) O(n2k)。丢一部分代码。
int work()
{
read(n), read(k);
for (int i = 1; i <= n; i++)
read(sum[i]), sum[i] += sum[i - 1];
for (int i = 2; i <= k; i++)
for (int j = n; j > 0; j--)
for (int g = 1; g < j; g++)
{
if (dp[j] < dp[g] + sum[g] * (sum[j] - sum[g]))
{
dp[j] = dp[g] + sum[g] * (sum[j] - sum[g]);
pre[i][j] = g;
}
}
int st = 0;
ll ans = 0;
for (int i = 1; i <= n; i++)
if (ans < dp[i] + sum[i] * (sum[n] - sum[i]))
{
ans = dp[i] + sum[i] * (sum[n] - sum[i]);
st = i;
}
write(ans), putchar('\n'), write(st);
for (int i = k; i > 1; i--)
putchar(' '), write(st = pre[i][st]);
return 0;
}
实测洛谷上有
64
64
64分。64分够了,本文结束。
考虑斜率优化。以下的讨论均为在 j j j(切的次数)固定的情况下,为方便说明,用 f [ i ] f[i] f[i]表示 f [ i ] [ j ] f[i][j] f[i][j], g [ i ] g[i] g[i]表示 f [ i ] [ j − 1 ] f[i][j-1] f[i][j−1]。再来看这个式子。
f [ i ] = m a x ( g [ j ] + s u m [ j ] × ( s u m [ i ] − s u m [ j ] ) ) ( 0 < j < i ) f[i]=max(g[j]+sum[j]\times (sum[i]-sum[j]))(0<j<i) f[i]=max(g[j]+sum[j]×(sum[i]−sum[j]))(0<j<i)
考虑对于 g [ j ] g[j] g[j]和 g [ k ] ( j < k < i ) g[k](j<k<i) g[k](j<k<i),如果从 g [ j ] g[j] g[j]转移到 f [ i ] f[i] f[i]比从 g [ k ] g[k] g[k]转移更优,那么一定满足:
g [ j ] + s u m [ j ] × ( s u m [ i ] − s u m [ j ] ) > g [ k ] + s u m [ k ] × ( s u m [ i ] − s u m [ k ] ) g[j]+sum[j]\times (sum[i]-sum[j])>g[k]+sum[k]\times (sum[i]-sum[k]) g[j]+sum[j]×(sum[i]−sum[j])>g[k]+sum[k]×(sum[i]−sum[k])
进行一些变换,得到:
g [ j ] − s u m [ j ] 2 > g [ k ] − s u m [ k ] 2 + s u m [ i ] × ( − s u m [ j ] + s u m [ k ] ) g[j]-sum[j]^2>g[k]-sum[k]^2+sum[i]\times (-sum[j]+sum[k]) g[j]−sum[j]2>g[k]−sum[k]2+sum[i]×(−sum[j]+sum[k])
由于前缀和单调不降, ( − s u m [ j ] + s u m [ k ] ) (-sum[j]+sum[k]) (−sum[j]+sum[k])是非负的,除到右边得到(暂时不考虑为 − s u m [ j ] + s u m [ k ] = 0 -sum[j]+sum[k]=0 −sum[j]+sum[k]=0的特殊情况):
( g [ j ] − s u m [ j ] 2 ) − ( g [ k ] − s u m [ k ] 2 ) − s u m [ j ] + s u m [ k ] > s u m [ i ] \frac{(g[j]-sum[j]^2)-(g[k]-sum[k]^2)}{-sum[j]+sum[k]}>sum[i] −sum[j]+sum[k](g[j]−sum[j]2)−(g[k]−sum[k]2)>sum[i]
可以看出左侧的式子很“工整”。把 g [ j ] − s u m [ j ] 2 g[j]-sum[j]^2 g[j]−sum[j]2看作点 j j j的纵坐标, − s u m [ j ] -sum[j] −sum[j]看作点 j j j的横坐标,则左侧就是 j j j和 k k k两点之间的斜率。对于横坐标相等的两点,斜率根据纵坐标的符号视作正无穷或负无穷。
接下来阅读前,请时刻牢记:对于 j < k < i j<k<i j<k<i,如果 j j j和 k k k之间的斜率大于 s u m [ i ] sum[i] sum[i],则 j j j比 k k k优
同时还有这句话的反面:对于 j < k < i j<k<i j<k<i,如果 j j j和 k k k之间的斜率不大于 s u m [ i ] sum[i] sum[i],则 j j j比 k k k劣
由于 s u m [ i ] sum[i] sum[i]是单调不降的,所以满足决策单调性,即:如果对于 f [ i ] f[i] f[i],从 g [ k ] g[k] g[k]转移比从 g [ j ] g[j] g[j]优 ( j < k ) (j<k) (j<k),则对于 f [ i ′ ] ( i < i ′ ≤ n ) f[i'](i<i'\leq n) f[i′](i<i′≤n), g [ j ] g[j] g[j]不可能是最优的(显然, j j j和 k k k间的斜率是不受 i i i影响的,而如果此时斜率已经小于等于 s u m [ i ] sum[i] sum[i]了,则以后也不可能大于 s u m [ i ] sum[i] sum[i],所以 j j j以后永远不可能比 k k k优) 。
那么可以维护一个斜率递增的单调队列。如果感到这部分难以理解,请再反复看上面三句加粗的话。当决策 f [ i ] f[i] f[i]时,弹出单调队列首的若干元素,直到只剩一个元素或前两个元素的斜率大于 s u m [ i ] sum[i] sum[i]。当尝试插入 i i i点时,弹出队尾的若干元素,直到只剩一个元素或队尾与 i i i的斜率比队尾后两个元素大。在几何意义上,这是一个下凸包。读者可以画图理解。
代码:
#include <cstdio>
#include <algorithm>
#include <cctype>
#include <cstring>
using namespace std;
namespace zyt
{
template<typename T>
inline void read(T &x)
{
bool f = false;
char c;
x = 0;
do
c = getchar();
while (c != '-' && !isdigit(c));
if (c == '-')
f = true, c = getchar();
do
x = x * 10 + c - '0', c = getchar();
while (isdigit(c));
if (f)
x = -x;
}
template<typename T>
inline void write(T x)
{
static char buf[20];
char *pos = buf;
if (x < 0)
putchar('-'), x = -x;
do
*pos++ = x % 10 + '0';
while (x /= 10);
while (pos > buf)
putchar(*--pos);
}
typedef long long ll;
typedef long double ld;
const int N = 1e5 + 10, K = 210;
const ll LINF = 0x3f3f3f3f3f3f3f3fLL;
int n, k, now;
ll sum[N], dp[2][N];
int pre[K][N];
inline ll sq(const ll x)
{
return x * x;
}
inline ll y(const int i)
{
return dp[now ^ 1][i] - sq(sum[i]);
}
inline ll x(const int i)
{
return -sum[i];
}
inline ld ratio(const int i, const int j)
{
if (x(i) == x(j))
return (y(i) - y(j) > 0) ? LINF : -LINF;
else
return (ld)(y(i) - y(j)) / (x(i) - x(j));
}
int work()
{
static int q[N];
read(n), read(k);
for (int i = 1; i <= n; i++)
read(sum[i]), sum[i] += sum[i - 1];
for (int i = 2; i <= k; i++)
{
now = i & 1;
int h = 0, t = 1;
q[0] = 0;
for (int j = 1; j <= n; j++)
{
while (h + 1 < t && ratio(q[h], q[h + 1]) <= sum[j])
++h;
dp[now][j] = dp[now ^ 1][q[h]] + sum[q[h]] * (sum[j] - sum[q[h]]);
pre[i][j] = q[h];
while (h + 1 < t && ratio(q[t - 2], q[t - 1]) >= ratio(q[t - 1], j))
--t;
q[t++] = j;
}
}
int st = 0;
ll ans = 0;
for (int i = 1; i <= n; i++)
if (ans < dp[now][i] + sum[i] * (sum[n] - sum[i]))
{
ans = dp[now][i] + sum[i] * (sum[n] - sum[i]);
st = i;
}
write(ans), putchar('\n'), write(st);
for (int i = k; i > 1; i--)
putchar(' '), write(st = pre[i][st]);
return 0;
}
}
int main()
{
return zyt::work();
}