单纯地讲思路有些难懂,这里结合一个实际例子来说明:
我们假设A = {0,1,3},n = 2,N = 10,r = 10 / (2 + 1) = 3。
i | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
f(i) | 0 | 1 | 1 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |
g(i) | 0 | 0 | 0 | 1 | 1 | 1 | 2 | 2 | 2 | 3 |
|g(i) - f(i)| | 0 | 1 | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 1 |
肯定不能蠢到遍历所有的i,求|g(i) - f(i)|的和,这样做铁定超时。在做上一题(202112-1,序列查询)的时候,应该要想到需要遍历的是所有的f(i),而不是i。
那么, 对于一个给定的f(i)而言,我们怎么求这个区间里|g(i) - f(i)|之和?
以上表中f(i) = 1时为例,区间内的g(i) = {0,0}。这一段区间是比较特殊的,因为区间内的所有g(i)都小于等于这个f(i)。这时问题就比较好办了,这个区间里|g(i) - f(i)|之和等于区间里f(i)之和减去区间里g(i)之和。
如何快速求区间里f(i)之和,不用我再解释。关键是如何快速求区间里g(i)之和。我们令g(i)的前缀和为h(i),给定的f(i)区间里i的范围是[lft,rgt],那么求区间里g(i)之和其实就是求h(rgt) - h(lft - 1)。例如,求f(i) = 1时g(i)之和,就是求h(2) - h(1 - 1)。
i | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
g(i) | 0 | 0 | 0 | 1 | 1 | 1 | 2 | 2 | 2 | 3 |
h(i) | 0 | 0 | 0 | 1 | 2 | 3 | 5 | 7 | 9 | 12 |
如何计算g(i)的前缀和h(i),有简便的方法。g(i)的前k(这里k从0开始)项,是由r个首项为0、项数为(k + 1) / r、公差为1的等差数列,以及(k + 1) % r个k / r构成的。举例来说,g(i)的前7项{0, 0, 0, 1, 1, 1, 2,2}是由3个首项为0、项数为2、公差为1的等差数列{0, 1},以及2个2构成的。那么我们可以知道h(i) = r * ((i + 1) / r - 1) * ((i + 1) / r) / 2 + (i + 1) % r * (i / r)。
现在,我们已经知道了给定一个f(i),如果对应区间里g(i)全部小于等于该f(i)时,怎么求这个区间里|g(i) - f(i)|之和。如果对应区间里g(i)全部大于等于该f(i)呢?那么只是改成用区间里g(i)之和减去区间里f(i)之和罢了。但是还有一种特殊情况,那就是区间里的g(i)既有小于f(i)的,也有大于f(i)的。这个时候我们直接把区间从g(i) = f(i)处一分为二,左半边的g(i)全部小于f(i),右半边的g(i)全部大于f(i),分别计算区间里|g(i) - f(i)|之和——以f(i) = 2时为例,把f(i) = 2的区间从i = 6处分为两半,那么左半边的g(i) = {1,1,1,2},右半边的g(i) = {2,2,3},分别计算这两半就可以了。
#include <bits/stdc++.h>
using namespace std;
long long h(long long i, long long r) { //求g(i)的前缀和h(i)
if (i < 0) return 0;
else return r * ((i + 1) / r - 1) * ((i + 1) / r) / 2 + (i + 1) % r * (i / r);
}
long long cal(long long fi, long long lft, long long rgt, long long r) { //给定一个f(i),计算区间里的|g(i)-f(i)|之和,前提是g(i)全部小于等于或者全部大于等于f(i)
return abs(h(rgt, r) - h(lft - 1, r) - fi * (rgt - lft + 1));
}
int main()
{
long long n, N, t;
cin >> n >> N;
vector<long long> a = {0};
for (int i = 0; i < n; ++i) {
cin >> t;
a.push_back(t);
}
a.push_back(N);
long long r = N / (n + 1), ans = 0;
for (long long fi = 0; fi <= n; ++fi) { //遍历每个f(i)
long long lft = a[fi], rgt = a[fi + 1] - 1;
if (lft / r >= fi || rgt / r <= fi) ans += cal(fi, lft, rgt, r); //如果区间内g(i)全部小于等于或者全部大于等于f(i),直接使用cal函数
else ans += cal(fi, lft, r * fi, r) + cal(fi, r * fi + 1, rgt, r); //否则将区间分成两半,分别使用cal函数
}
cout << ans << endl;
return 0;
}