CS101 2021Fall PA1,2 题解
关于PA (Programming Assignment)
鉴于PA的题目都没有在课上或讨论课上讲解过,piazza上也没有发过题解,GKxx打算简单整理一下这些题目的题解。对于其中的部分题目,我会尽可能多写几个能想到的算法,并提供一部分代码。
水平有限,这份题解仅供参考,如有错误欢迎指出。如果对某些题目有同学有更好的做法,请直接发,不要联系我,我懒得再看了。
建议以后的CS101可以在每次PA结束以后开放一个题解展示平台,并安排一两名助教审核提交的题解,可以参考洛谷的题解审核标准。大多数OIer都有写题解的习惯,我相信只要开放这个平台就会有人主动交题解的,都不用助教或老师自己干。
一些约定
若无特殊说明,所有下标均从
1
1
1开始;但如果是标准库容器或std::string
的下标,则从
0
0
0开始。
代码是C++14的。
PA1
1001
题意:给定二维坐标平面上的 N N N个点和两个圆心 O 1 , O 2 O_1,O_2 O1,O2,需要求出两个半径 R 1 , R 2 R_1,R_2 R1,R2,使得每一个点都至少被一个圆覆盖到,并最小化 R 1 2 + R 2 2 R_1^2+R_2^2 R12+R22。 N ⩽ 80000 N\leqslant 80000 N⩽80000,坐标数值的绝对值在 1000 1000 1000以内。
每个点都会被这两个圆中的某一个覆盖(同时被两个圆覆盖的情况,你可以认为它只归其中一个圆管,另一个圆“碰巧”覆盖到它而已),所以圆的半径就由它需要覆盖的那些点中离圆心最远的那个决定。我们考虑将所有点按照与 O 1 O_1 O1的距离从大到小排序,枚举圆 O 1 O_1 O1所覆盖的最远的那个点是谁。假设圆 O 1 O_1 O1覆盖的最远的点是 P i P_i Pi,这意味着 P i , ⋯ , P n P_i,\cdots,P_n Pi,⋯,Pn都已经被圆 O 1 O_1 O1覆盖了,而圆 O 2 O_2 O2的任务是覆盖 P 1 , ⋯ , P i − 1 P_1,\cdots,P_{i-1} P1,⋯,Pi−1,我们花 O ( n ) O(n) O(n)的时间求出这些点中与 O 2 O_2 O2的距离最大值,这样我们就得到了一个 O ( n 2 ) O(n^2) O(n2)的做法。
如何优化?注意,我们每次都需要求 f ( i ) = max { d ( j ) , 1 ⩽ j < i } f(i)=\max\{d(j),1\leqslant j<i\} f(i)=max{d(j),1⩽j<i},其中 d ( j ) = dist ( P j , O 2 ) d(j)=\operatorname{dist}(P_j,O_2) d(j)=dist(Pj,O2)为 P j P_j Pj与 O 2 O_2 O2的距离。但实际上 f ( i ) = max { f ( i − 1 ) , d ( i ) } f(i)=\max\{f(i-1),d(i)\} f(i)=max{f(i−1),d(i)},它可以根据上一次的答案 O ( 1 ) O(1) O(1)地得出,不需要花 O ( n ) O(n) O(n)的时间重新求。这个 f f f不需要开数组存,因为你每次只使用上一次的答案,反复更新一个变量就好了。
这样在排序之后时间复杂度就为 O ( n ) O(n) O(n),而排序是整个算法的时间复杂度瓶颈,用常见的归并排序或快速排序可以做到 O ( n log n ) O(n\log n) O(nlogn)。
#include <cmath>
#include <cstdio>
#include <functional> // for std::less
#include <iterator> // for std::iterator_traits, std::distance, std::next
namespace gkxx {
template <typename ForwardIterator, typename Less>
void __inplace_merge(
ForwardIterator begin, ForwardIterator mid, ForwardIterator end,
typename std::iterator_traits<ForwardIterator>::difference_type dist,
Less less) {
ForwardIterator i = begin, j = mid;
using value_type = typename std::iterator_traits<ForwardIterator>::value_type;
value_type *tmp = new value_type[dist](), *k = tmp;
while (i != mid && j != end) {
if (less(*i, *j))
*k++ = *i++;
else
*k++ = *j++;
}
while (i != mid)
*k++ = *i++;
while (j != end)
*k++ = *j++;
k = tmp;
while (begin != end)
*begin++ = *k++;
delete[] tmp;
}
template <typename ForwardIterator, typename Less>
void merge_sort(ForwardIterator begin, ForwardIterator end, Less less) {
auto dist = std::distance(begin, end);
if (dist <= 1)
return;
ForwardIterator mid = std::next(begin, dist / 2);
merge_sort(begin, mid, less);
merge_sort(mid, end, less);
__inplace_merge(begin, mid, end, dist, less);
}
template <typename ForwardIterator>
inline void merge_sort(ForwardIterator begin, ForwardIterator end) {
merge_sort(begin, end, std::less<void>());
}
} // namespace gkxx
constexpr int maxn = 80007;
struct Point {
int x, y;
};
Point O1, O2, p[maxn];
int n;
inline int sqr_dist(const Point &a, const Point &b) {
int dx = a.x - b.x, dy = a.y - b.y;
return dx * dx + dy * dy;
}
int main() {
scanf("%d%d%d%d", &O1.x, &O1.y, &O2.x, &O2.y);
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d%d", &p[i].x, &p[i].y);
gkxx::merge_sort(p + 1, p + n + 1,
[](const Point &a, const Point &b) -> bool {
return sqr_dist(a, O1) > sqr_dist(b, O1);
});
// First consider the case where R1 = 0.
int ans = 0;
for (int i = 1; i <= n; ++i) {
int sqr_r2 = sqr_dist(p[i], O2);
if (sqr_r2 > ans)
ans = sqr_r2;
}
int prev_max = 0;
for (int i = 1; i <= n; ++i) {
int sqr_r1 = sqr_dist(p[i], O1);
if (sqr_r1 + prev_max < ans)
ans = sqr_r1 + prev_max;
int d = sqr_dist(p[i], O2);
if (d > prev_max)
prev_max = d;
}
printf("%d\n", ans);
return 0;
}
提醒一下:写快速排序要随机选pivot,你写取三个数的中位数的那种选法也还好,但有些人每次都取最后一个数,就很容易被卡成worst case。
1002
题意:给定 n n n个字符串,求出最长公共前缀(LCP)和最长公共后缀(LCS)。“公共前缀”和“公共后缀”的长度必须至少为 2 2 2。长度相等的情况下找出现次数最多的,出现次数仍然相等的情况下找字典序最小的。 n ⩽ 50000 n\leqslant 50000 n⩽50000,单个字符串长度不超过 10 10 10,字符集为小写字母。
字符串长度不超过 10 10 10,意味着你可以肆无忌惮地遍历字符串,拷贝字符串,而不用担心时间开销,所以搞什么Trie树搞什么Hash table,直接暴力啊。
这题做法五花八门,我说一个我自己写的。先说LCP:把所有字符串按字典序从小到大排序,则答案一定是某两个相邻的字符串的LCP。求两个字符串的LCP只要暴力枚举前缀判断是否相等即可。我们为每一对相邻的字符串 s i , s i + 1 s_i,s_{i+1} si,si+1求出LCP,记为 t i t_i ti,则我们只要找出 t 1 , ⋯ , t n − 1 t_1,\cdots,t_{n-1} t1,⋯,tn−1这些串中最长的那个,长度相等的情况下出现次数最多的那个,出现次数仍相等的情况下字典序最小的那个。
实现中,我们可以记录每一个LCP出现的次数,如果当前
s
i
s_i
si与
s
i
+
1
s_{i+1}
si+1的LCP恰好和
s
i
−
1
s_{i-1}
si−1与
s
i
s_i
si的LCP相同(长度相等就可以判断相同),就递增上一个LCP的出现次数。我们可以用一个std::pair<std::string, int>
来保存一个LCP及其出现次数,那么最终只要在所有pair
中找一个最优的,这跟在给定数组中找最大值没有本质区别。
对于LCS,我们可以把所有串都反过来,然后求LCP。但务必注意,比较字典序的时候要用反转之前的串比较。
1003
题意:给定 n n n个数 a 1 , ⋯ , a n a_1,\cdots,a_n a1,⋯,an,计算有多少对 ( i , j ) (i,j) (i,j)满足 i , j ∈ [ 1 , n ] ∩ Z , i < j i,j\in[1,n]\cap\Z,i<j i,j∈[1,n]∩Z,i<j,且 ∀ k ∈ ( i , j ) ∩ Z , a k ⩽ min { a i , a j } \forall k\in(i,j)\cap\Z,a_k\leqslant\min\{a_i,a_j\} ∀k∈(i,j)∩Z,ak⩽min{ai,aj}。(即给出 n n n位同学的身高,请问有多少对同学能互相看见)
算法1:单调栈
考虑枚举右边的同学 j j j,计算左边有多少个同学 i i i能和 j j j互相看见。枚举 j j j的复杂度已经 O ( n ) O(n) O(n),所以计算 i i i的数量的过程必须足够快。有一个显然与此有关的问题:假设 f ( j ) f(j) f(j)表示 j j j左边第一个比 j j j高的同学的编号,那么显然在 f ( j ) f(j) f(j)左边的同学都没法看见 j j j。你能否快速求出 f ( j ) f(j) f(j)?你不能从 j − 1 j-1 j−1到 1 1 1一个一个检查,因为这会使得求单个 f ( j ) f(j) f(j)的复杂度为 O ( j ) O(j) O(j),总复杂度就为 O ( n 2 ) O(n^2) O(n2),太慢。
记
f
k
(
j
)
=
f
(
f
k
−
1
(
j
)
)
,
f
0
(
j
)
=
j
f^k(j)=f(f^{k-1}(j)),f^0(j)=j
fk(j)=f(fk−1(j)),f0(j)=j。假如已知
f
(
1
)
,
⋯
,
f
(
j
−
1
)
f(1),\cdots,f(j-1)
f(1),⋯,f(j−1),我们发现
f
(
j
)
∈
{
j
−
1
,
f
(
j
−
1
)
,
f
2
(
j
−
1
)
,
⋯
}
f(j)\in\left\{j-1,f(j-1),f^2(j-1),\cdots\right\}
f(j)∈{j−1,f(j−1),f2(j−1),⋯}(为什么?),并且
a
j
−
1
<
a
f
(
j
−
1
)
<
a
f
2
(
j
−
1
)
<
⋯
a_{j-1}<a_{f(j-1)}<a_{f^2(j-1)}<\cdots
aj−1<af(j−1)<af2(j−1)<⋯。我们可以用一个栈
s
s
s来存储这些位置:从栈顶到栈底依次为
j
−
1
,
f
(
j
−
1
)
,
f
2
(
j
−
1
)
,
⋯
j-1,f(j-1),f^2(j-1),\cdots
j−1,f(j−1),f2(j−1),⋯,那么查询
f
(
j
)
f(j)
f(j)就只要在
s
s
s中从栈顶开始逐个检查即可。下图是
j
=
11
j=11
j=11时的情况,蓝色的柱子就是每个同学,红色折线连接的点就是这个栈所保存的位置编号。
当你求出了
f
(
j
)
f(j)
f(j)之后,为了下一次求
f
(
j
+
1
)
f(j+1)
f(j+1),你还要正确地维护这个栈,也就是说得把这个栈变成
j
,
f
(
j
)
,
f
2
(
j
)
,
⋯
j,f(j),f^2(j),\cdots
j,f(j),f2(j),⋯。如何维护?我们可以逐个检查栈顶元素
f
k
(
j
−
1
)
,
k
=
0
,
1
,
⋯
f^k(j-1),k=0,1,\cdots
fk(j−1),k=0,1,⋯,只要
a
f
k
(
j
−
1
)
⩽
a
j
a_{f^k(j-1)}\leqslant a_j
afk(j−1)⩽aj就把
f
k
(
j
−
1
)
f^k(j-1)
fk(j−1)弹出,直到栈顶所标示的位置上的同学比
j
j
j高(这也恰好就是
f
(
j
)
f(j)
f(j))就停止,然后把
j
j
j放到栈顶即可。在上图的例子中,我们就要把
10
10
10和
8
8
8两位同学弹栈,然后将
11
11
11加入栈顶。
由于
a
j
<
a
f
(
j
)
<
a
f
2
(
j
)
<
⋯
a_j<a_{f(j)}<a_{f^2(j)}<\cdots
aj<af(j)<af2(j)<⋯,这个栈被称为单调栈。以下是求
f
f
f数组的代码:
top = 0;
for (int j = 1; j <= n; ++j) {
while (top > 0 && a[s[top]] <= a[j])
--top;
f[j] = s[top];
s[++top] = j;
}
这时有人就要问了:每次都这样暴力地从顶向下检查,最坏情况下岂不是要检查
O
(
n
)
O(n)
O(n)个元素,那总复杂度不就
O
(
n
2
)
O(n^2)
O(n2)了吗?但是请注意:你的算法的复杂度取决于执行次数最多的那条语句被执行了多少次,在这里其实也就是弹栈操作--top;
,然而
1
1
1到
n
n
n的每一个元素都只会进栈一次,也就至多出栈一次,所以弹栈不会超过
n
n
n次,总的复杂度就是
O
(
n
)
O(n)
O(n),虽然它看起来是个双重循环。
如果所有的
a
i
a_i
ai都不相等,我们惊奇地发现:那些能和
j
j
j互相看见的同学正是
j
−
1
,
f
(
j
−
1
)
,
f
2
(
j
−
1
)
,
⋯
j-1,f(j-1),f^2(j-1),\cdots
j−1,f(j−1),f2(j−1),⋯,也就是在求
f
(
j
)
f(j)
f(j)的时候被弹栈的那些位置,所以我们在弹栈的时候顺带统计一下即可。但是题目并不保证所有
a
i
a_i
ai都不相等,这时会出现一个问题:假设
f
k
(
j
−
1
)
f^k(j-1)
fk(j−1)和
j
j
j能看见,而在
f
k
+
1
(
j
−
1
)
f^{k+1}(j-1)
fk+1(j−1)与
f
k
(
j
−
1
)
f^k(j-1)
fk(j−1)之间存在一个和
f
k
(
j
−
1
)
f^k(j-1)
fk(j−1)一样高的同学
t
t
t,那么
t
t
t不会存在于栈中,但是
t
t
t也能和
j
j
j看见。为了让这样的
t
t
t也存在于栈中,你得把弹栈的条件由
a
f
k
(
j
−
1
)
⩽
a
j
a_{f^k(j-1)}\leqslant a_j
afk(j−1)⩽aj改成
a
f
k
(
j
−
1
)
<
a
j
a_{f^k(j-1)}<a_j
afk(j−1)<aj,也就是a[s[top]] < a[i]
。但是在统计答案的时候,仍然是找到严格高于
j
j
j的同学为止,如果你仍然从栈顶开始一个一个找,就会让复杂度变成
O
(
n
2
)
O(n^2)
O(n2),因为最坏情况下所有同学的高度都相等,所有同学入栈了之后都不会出栈,每次查找都得一直找到第一个。(实测这种做法可以通过,说明数据比较水)
一种解决方法是使用二分查找。二分是一个极其重要的思想,但整个CS101课程中始终没有提及。二分查找的复杂度为 O ( log n ) O(\log n) O(logn),那么总复杂度即为 O ( n log n ) O(n\log n) O(nlogn)。查找过程中要小心 f ( j ) f(j) f(j)不存在的情况,即 j j j左边的所有同学都不比 j j j高,我的处理方式是假定 a 0 = + ∞ a_0=+\infty a0=+∞。代码如下:
#include <climits>
#include <cstdio>
constexpr int maxn = 5e5 + 7;
constexpr int inf = 1e9;
int a[maxn], n;
int stk[maxn], top;
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", a + i);
int ans = 0;
a[0] = inf;
stk[++top] = 1;
for (int i = 2; i <= n; ++i) {
int left = 0, right = top, pos;
while (left <= right) {
int mid = (left + right) / 2;
if (a[stk[mid]] <= a[i])
right = mid - 1;
else {
pos = mid;
left = mid + 1;
}
}
if (!pos)
++pos;
int delta = top - pos + 1;
ans += delta;
while (a[stk[top]] < a[i] && top)
--top;
stk[++top] = i;
}
printf("%d\n", ans);
return 0;
}
第二种解决方法是考虑给栈里高度相等的同学记录重复次数,在入栈的时候如果发现与栈顶一样高就不要入栈,而是递增其重数。但实际上你需要记录的是重数的前缀和,因为查询答案时需要计算几个同学的重数之和,可以用两个前缀和相减得出。这样可以做到 O ( n ) O(n) O(n),代码如下:
#include <climits>
#include <cstdio>
constexpr int maxn = 5e5 + 7;
constexpr int inf = 1e9;
int a[maxn], n;
int stk[maxn], top;
int sum[maxn];
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", a + i);
int ans = 0;
a[0] = inf;
stk[++top] = 1;
sum[1] = 1;
for (int i = 2; i <= n; ++i) {
int pos = top;
while (a[stk[pos]] <= a[i] && pos)
--pos;
int delta = pos ? sum[top] - sum[pos] + 1 : sum[top];
ans += delta;
while (a[stk[top]] < a[i] && top)
--top;
if (a[stk[top]] == a[i])
++sum[top];
else {
stk[++top] = i;
sum[top] = sum[top - 1] + 1;
}
}
printf("%d\n", ans);
return 0;
}
算法2:分治
假设我们要统计从第 ℓ \ell ℓ个同学到第 r r r个同学这一段上能互相看见的同学的数量,我们将区间 [ l , r ] [l,r] [l,r]拆分成 [ l , m ] [l,m] [l,m]和 [ m + 1 , r ] [m+1,r] [m+1,r],其中 m = ⌊ ℓ + r 2 ⌋ m=\left\lfloor\frac{\ell+r}{2}\right\rfloor m=⌊2ℓ+r⌋。能够互相看见的两人 ( i , j ) (i,j) (i,j)要么都来自 [ l , m ] [l,m] [l,m],要么都来自 [ m + 1 , r ] [m+1,r] [m+1,r],要么 i ∈ [ l , m ] , j ∈ [ m + 1 , r ] i\in[l,m],j\in[m+1,r] i∈[l,m],j∈[m+1,r]。对于前两种情况只要递归调用即可。对于第三种情况,使用两个指针 i , j i,j i,j,一开始分别指向 m m m和 m + 1 m+1 m+1(这里的“指针”仅仅是一种形象的表述,在实现中就是两个下标,并不是C/C++中的那种指针)。
- 假如 a i < a j a_i<a_j ai<aj,并且 i i i和 j j j能互相看见,考虑将指针 i i i左移,找到 i i i左边第一个比 i i i高的同学 k k k,那么在 ( k , i ] (k,i] (k,i]中与 i i i一样高的同学都能与 j j j看见,将它们计入答案,并将 i i i赋值为 k k k。
- 假如 a i > a j a_i>a_j ai>aj并且 i i i和 j j j能互相看见,情况是类似的。
- 如果 a i = a j a_i=a_j ai=aj并且 i i i和 j j j能互相看见,我们需要同时左移 i i i和右移 j j j,找到 i i i左边第一个比 i i i高的同学 k 1 k_1 k1,以及 j j j右边第一个比 j j j高的同学 k 2 k_2 k2,那么在 ( k 1 , i ] (k_1,i] (k1,i]中和在 ( j , k 2 ] (j,k_2] (j,k2]中那些与 i , j i,j i,j一样高的同学都可以互相看见,所以我们分别统计出这两个区间中身高等于 a i a_i ai的同学数量 n 1 , n 2 n_1,n_2 n1,n2,那么就有 n 1 × n 2 n_1\times n_2 n1×n2对同学能互相看见,要计入答案。但同时, k 2 k_2 k2也能和左边的这 n 1 n_1 n1位同学看见, k 1 k_1 k1也能和右边的这 n 2 n_2 n2位同学看见,因此还要再加上 n 1 + n 2 n_1+n_2 n1+n2。注意, k 1 k_1 k1与 k 2 k_2 k2也能互相看见,但这会在下一轮中被统计,不要重复统计。
因此统计 i ∈ [ l , m ] , j ∈ [ m + 1 , r ] i\in[l,m],j\in[m+1,r] i∈[l,m],j∈[m+1,r]的情况,只需将两个指针分别从中点移到两端,时间复杂度 O ( n ) O(n) O(n),总的时间复杂度就是 O ( n log n ) O(n\log n) O(nlogn)。在实现中,我在左端点左边和右端点右边各设了一个巨人(身高为无穷大的同学),这样可以一定程度上减少边界考虑的负担。
#include <cstdio>
constexpr int maxn = 5e5 + 7;
constexpr int inf = 1e9;
int a[maxn], n;
int solve(int l, int r) {
if (r - l + 1 <= 1)
return 0;
int mid = (l + r) / 2;
int ans = 0;
int i = mid, j = mid + 1;
int tmpleft = a[l - 1], tmpright = a[r + 1];
a[l - 1] = a[r + 1] = inf;
while (i >= l && j <= r) {
if (a[i] < a[j]) {
int k = i;
while (a[k] <= a[i]) {
if (a[k] == a[i])
++ans;
--k;
}
i = k;
} else if (a[i] > a[j]) {
int k = j;
while (a[k] <= a[j]) {
if (a[k] == a[j])
++ans;
++k;
}
j = k;
} else {
int n1 = 0, n2 = 0, k = i;
while (a[k] <= a[i]) {
if (a[k] == a[i])
++n1;
--k;
}
i = k;
k = j;
while (a[k] <= a[j]) {
if (a[k] == a[j])
++n2;
++k;
}
j = k;
ans += n1 * n2;
if (i >= l)
ans += n2;
if (j <= r)
ans += n1;
}
}
a[l - 1] = tmpleft;
a[r + 1] = tmpright;
return ans + solve(l, mid) + solve(mid + 1, r);
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", a + i);
printf("%d\n", solve(1, n));
return 0;
}
PA2
2001
题意:给一个 n n n个结点的树,结点 i i i有一个权值 a i ∈ { 0 , 1 , 2 } a_i\in\{0,1,2\} ai∈{0,1,2}。如果一棵树中的所有非零权值都相等(都为 1 1 1或都为 2 2 2),则称这棵树为和平的。一条边被称为和平的,当且仅当将这条边删除之后得到的两棵树都是和平的。问和平边的数量。
首先说一下怎么存一棵树或者一张图。在绝大多数图算法或树算法中,最基本的操作都是遍历某个结点的所有neighbors,所以你的存储方式必须支持高效地访问每个点的neighbors。像邻接矩阵就不是一个好的存图方式,因为它必须花 O ( n ) O(n) O(n)的时间枚举点,再判断它是否与当前点相邻。通常的准则是:使用Floyd-Warshall算法时用邻接矩阵,使用Kruskal算法时就直接把边放进一个数组(因为需要给边排序),其它情况几乎总是使用邻接表。
邻接表怎么写?如果允许使用std::vector
,只要给每个结点开一个std::vector
存储它的邻居。下面给出存储、加边和dfs
的代码,注意这里的加边是加有向边,无向边可以被视为一正一反两条有向边。
std::vector<int> G[maxn];
inline void add_edge(int x, int y) {
G[x].push_back(y);
}
bool vis[maxn];
void dfs(int x) {
printf("visit %d\n", x);
vis[x] = true;
for (auto v : G[x])
if (!vis[v])
dfs(v);
}
但是CS101 PA不允许使用STL。当然你可以自己造一个std::vector
,它并不是很难,还能加深你对STL和其它一些C++语法知识的理解。在C++ Primer的第12、13章有介绍如何造std::vector
。
不允许使用STL的话,我们可以给每个结点开一个链表存放它的所有邻居,开一个数组存储每个链表的表头指针。加入一条边
(
x
,
y
)
(x,y)
(x,y)的时候,就是给第
x
x
x个链表增加一个元素,显然应该加在表头,所以就令新结点的next
指向head[x]
,再令head[x]
指向新结点即可。使用动态内存一定要记得释放!!!使用动态内存一定要记得释放!!!使用动态内存一定要记得释放!!!(据我所知80%的人写了new
都不知道要delete
,我不知道你们的CS100谁教的)
struct Node {
int to;
Node *next;
Node(int x, Node *n) : to(x), next(n) {}
};
Node *head[maxn];
inline void add_edge(int x, int y) {
Node *newnode = new Node(y, head[x]);
head[x] = newnode;
}
bool vis[maxn];
void dfs(int x) {
printf("visit %d\n", x);
vis[x] = true;
for (Node *p = head[x]; p; p = p->next) {
int v = p->to;
if (!vis[v])
dfs(v);
}
}
void clear_list(Node *p) {
if (!p)
return;
clear_list(p->next);
delete p;
}
inline void clear() {
for (int i = 1; i <= n; ++i)
clear_list(head[i]);
}
如果你嫌动态内存麻烦,不熟悉,容易出错,那么你可以采用下面的数组模拟链表的方式(这其实是一些人所说的“前向星”)。我们需要开一个大数组G
来存放链表的结点,每次新建结点的时候就从G
中拿一个没被用过的。用一个变量total
来表示现在G
里有多少个结点已被使用。原来的指针值变成了结点在G
中的下标。由于不涉及动态内存,我们也不需要去delete
。注意,G
的大小是你总共需要添加的边的数量,如果是有向图就是
∣
E
∣
|E|
∣E∣,无向图则是
2
∣
E
∣
2|E|
2∣E∣,树是
2
∣
V
∣
2|V|
2∣V∣。代码中用maxe
表示总边数。
struct Node {
int to;
int next;
};
Node G[maxe];
int head[maxn];
int total;
inline void add_edge(int x, int y) {
G[++total].to = y;
G[total].next = head[x];
head[x] = total;
}
bool vis[maxn];
void dfs(int x) {
printf("visit %d\n", x);
vis[x] = true;
for (int i = head[x]; i; i = G[i].next) {
int v = G[i].to;
if (!vis[v])
dfs(v);
}
}
现在来看题。首先我们将这棵树任意挑一个点(比如
1
1
1)为根,变成一棵有根树。令
f
1
(
x
)
f_1(x)
f1(x)表示以结点
x
x
x为根的子树中有多少个点的权值为
1
1
1,令
f
2
(
x
)
f_2(x)
f2(x)表示以结点
x
x
x为根的子树中有多少个点的权值为
2
2
2。我们可以通过一遍dfs来计算出所有的
f
1
(
x
)
f_1(x)
f1(x)和
f
2
(
x
)
f_2(x)
f2(x),因为
f
1
(
x
)
=
[
a
x
=
1
]
+
∑
v
∈
child
(
x
)
f
1
(
v
)
,
f
2
(
x
)
=
[
a
x
=
2
]
+
∑
v
∈
child
(
x
)
f
2
(
v
)
.
f_1(x)=[a_x=1]+\sum_{v\in\operatorname{child}(x)}f_1(v),\quad f_2(x)=[a_x=2]+\sum_{v\in\operatorname{child}(x)}f_2(v).
f1(x)=[ax=1]+v∈child(x)∑f1(v),f2(x)=[ax=2]+v∈child(x)∑f2(v).
计算出这些信息有什么用呢?我们发现,假设我们要删除一条边
(
x
,
v
)
(x,v)
(x,v),其中
v
v
v是
x
x
x的儿子,删除这条边之后得到的两棵树分别是以
v
v
v为根的子树和除了以
v
v
v为根的子树之外剩下的部分。如果这两棵树都是和平的,这条边就是和平的。而一棵树是和平的,就意味着它要么不含权值为
1
1
1的点,要么不含权值为
2
2
2的点,所以只要判断f1[v] == 0 || f2[v] == 0
就可以判断以
v
v
v为根的子树是否是和平的。剩下的部分中权值为
1
1
1的点的数量就是总数减去f1[v]
,所以我们在输入的时候就计算出sum1
和sum2
,分别表示权值为
1
1
1和
2
2
2的点一共有多少个,那么sum1 - f1[v] == 0 || sum2 - f2[v] == 0
就判断了剩下的部分是否是和平的。整个算法只要做一遍dfs,时间复杂度
O
(
n
)
O(n)
O(n)。
#include <cstdio>
constexpr int maxn = 3e5 + 7;
struct Node {
int to, next;
};
Node G[maxn * 2];
int total, n;
int head[maxn];
int sum1, sum2, f1[maxn], f2[maxn];
int a[maxn];
int ans;
inline void add_edge(int x, int y) {
G[++total].to = y;
G[total].next = head[x];
head[x] = total;
}
void dfs(int x, int fa) {
if (a[x] == 1)
f1[x] = 1;
if (a[x] == 2)
f2[x] = 1;
for (int i = head[x]; i; i = G[i].next) {
int v = G[i].to;
if (v != fa) {
dfs(v, x);
f1[x] += f1[v];
f2[x] += f2[v];
if ((f1[v] == 0 || f2[v] == 0) && (sum1 - f1[v] == 0 || sum2 - f2[v] == 0))
++ans;
}
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", a + i);
if (a[i] == 1)
++sum1;
if (a[i] == 2)
++sum2;
}
for (int i = 1; i < n; ++i) {
int x, y;
scanf("%d%d", &x, &y);
add_edge(x, y);
}
dfs(1, 0);
printf("%d\n", ans);
return 0;
}
2002
题意:给定
n
n
n个正整数
s
1
,
⋯
,
s
n
s_1,\cdots,s_n
s1,⋯,sn,称一个整数区间
[
ℓ
,
r
]
[\ell,r]
[ℓ,r]为“和谐的”,指的是存在整数
t
⩾
2
t\geqslant 2
t⩾2使得
s
ℓ
≡
s
ℓ
+
1
≡
⋯
≡
s
r
(
m
o
d
t
)
.
s_\ell\equiv s_{\ell+1}\equiv\cdots\equiv s_r\pmod t.
sℓ≡sℓ+1≡⋯≡sr(modt).
求长度最长的和谐区间。
如果区间
[
ℓ
,
r
]
[\ell,r]
[ℓ,r]是和谐的,意味着
∃
t
⩾
2
,
∣
s
ℓ
−
s
ℓ
+
1
∣
≡
∣
s
ℓ
+
1
−
s
ℓ
+
2
∣
≡
⋯
≡
∣
s
r
−
1
−
s
r
∣
≡
0
(
m
o
d
t
)
.
\exists t\geqslant 2,\quad \left|s_\ell-s_{\ell+1}\right|\equiv\left|s_{\ell+1}-s_{\ell+2}\right|\equiv\cdots\equiv\left|s_{r-1}-s_r\right|\equiv 0\pmod t.
∃t⩾2,∣sℓ−sℓ+1∣≡∣sℓ+1−sℓ+2∣≡⋯≡∣sr−1−sr∣≡0(modt).
令
d
i
=
∣
s
i
−
1
−
s
i
∣
d_i=\left|s_{i-1}-s_i\right|
di=∣si−1−si∣,则
∃
t
⩾
2
,
d
ℓ
+
1
≡
d
ℓ
+
2
≡
⋯
≡
d
r
≡
0
(
m
o
d
t
)
,
\exists t\geqslant 2,\quad d_{\ell+1}\equiv d_{\ell+2}\equiv\cdots\equiv d_r\equiv 0\pmod t,
∃t⩾2,dℓ+1≡dℓ+2≡⋯≡dr≡0(modt),
所以
∃
t
⩾
2
,
∀
i
∈
[
ℓ
+
1
,
r
]
,
t
∣
d
i
,
\exists t\geqslant 2,\forall i\in[\ell+1,r],t\mid d_i,
∃t⩾2,∀i∈[ℓ+1,r],t∣di,
那么
∃
t
⩾
2
,
t
∣
gcd
(
d
ℓ
+
1
,
d
ℓ
+
2
,
⋯
,
d
r
)
,
⟺
gcd
(
d
ℓ
+
1
,
d
ℓ
+
2
,
⋯
,
d
r
)
≠
1.
\exists t\geqslant 2,t\mid\gcd(d_{\ell+1},d_{\ell+2},\cdots,d_r),\iff \gcd(d_{\ell+1},d_{\ell+2},\cdots,d_r)\neq1.
∃t⩾2,t∣gcd(dℓ+1,dℓ+2,⋯,dr),⟺gcd(dℓ+1,dℓ+2,⋯,dr)=1.
注意,我这里写的是
≠
1
\neq 1
=1而不是
⩾
2
\geqslant 2
⩾2,因为可能会出现
gcd
=
0
\gcd=0
gcd=0的情况,这也是符合的。 于是我们要做的就是在
d
2
,
⋯
,
d
n
d_2,\cdots,d_n
d2,⋯,dn上找到最长的区间,使得这个区间里的元素不互素。注意一个细节:当“和谐区间”长度为
2
2
2的时候,比方说是
[
i
,
i
+
1
]
[i,i+1]
[i,i+1],对应的
d
d
d序列上的不互素区间长度为
1
1
1,相当于是要判断
gcd
(
d
i
)
≠
1
\gcd(d_i)\neq 1
gcd(di)=1,而一个数的“
gcd
\gcd
gcd”该怎么定义呢?稍加思索,发现定义
gcd
(
x
)
=
x
\gcd(x)=x
gcd(x)=x是合理的(为什么?)。此外,
d
d
d序列中的元素可能会有
0
0
0,但数学一般不考虑
0
0
0与其它数的
gcd
\gcd
gcd。我们可以定义
gcd
(
x
,
0
)
=
x
\gcd(x,0)=x
gcd(x,0)=x(为什么?)。
我们考虑枚举左端点
ℓ
\ell
ℓ,尝试求出最远的右端点
r
r
r,使得
gcd
(
d
ℓ
,
⋯
,
d
r
)
≠
1
\gcd(d_\ell,\cdots,d_r)\neq 1
gcd(dℓ,⋯,dr)=1(自然
gcd
(
d
ℓ
,
⋯
,
d
r
+
1
)
=
1
\gcd(d_\ell,\cdots,d_{r+1})=1
gcd(dℓ,⋯,dr+1)=1)。注意,随着右端点
r
r
r向右移动,区间里的元素增加,
gcd
\gcd
gcd是单调不增的。对于一个固定的左端点
ℓ
\ell
ℓ来说,我们自然可以将右端点
r
r
r从
ℓ
\ell
ℓ开始逐步右移,并注意到
gcd
(
d
ℓ
,
⋯
,
d
r
)
=
gcd
(
gcd
(
d
ℓ
,
⋯
,
d
r
−
1
)
,
d
r
)
,
(2.1)
\gcd(d_\ell,\cdots,d_r)=\gcd(\gcd(d_\ell,\cdots,d_{r-1}),d_r),\tag{2.1}
gcd(dℓ,⋯,dr)=gcd(gcd(dℓ,⋯,dr−1),dr),(2.1)
于是花
O
(
n
log
A
)
O(n\log A)
O(nlogA)的时间求出最远的右端点
r
r
r,这里
O
(
log
A
)
O(\log A)
O(logA)是求
gcd
\gcd
gcd的时间复杂度,其中
A
A
A是
d
d
d的最大值。但是由于枚举左端点已经
O
(
n
)
O(n)
O(n),总的复杂度就达到了
O
(
n
2
log
A
)
O(n^2\log A)
O(n2logA),这不能通过。
假如对于某个左端点 ℓ \ell ℓ,我们已经知道它对应的最远右端点是 r = r ( ℓ ) r=r(\ell) r=r(ℓ),我们有没有办法快一点求出下一个左端点 ℓ + 1 \ell+1 ℓ+1对应的右端点 r ( ℓ + 1 ) r(\ell+1) r(ℓ+1)呢?我们发现当区间从 [ ℓ , r ] [\ell,r] [ℓ,r]变成 [ ℓ + 1 , r ] [\ell+1,r] [ℓ+1,r]时,由于区间长度缩短, gcd \gcd gcd单调不减,所以必然有 r ( ℓ + 1 ) ⩾ r ( ℓ ) r(\ell+1)\geqslant r(\ell) r(ℓ+1)⩾r(ℓ),也就是说我们没必要再从 ℓ + 1 \ell+1 ℓ+1开始往右找,而是直接从 r ( ℓ ) r(\ell) r(ℓ)开始。但这时问题来了,递推式 ( 2.1 ) (2.1) (2.1)给出了当我们向区间中增加一个元素时快速求出 gcd \gcd gcd的方式,但是左端点从 ℓ \ell ℓ变成 ℓ + 1 \ell+1 ℓ+1时是从区间中删除一个数,这时该如何快速求出 [ ℓ + 1 , r ] [\ell+1,r] [ℓ+1,r]的 gcd \gcd gcd?
算法1:查询区间 gcd \gcd gcd
这个算法需要一个超出本课程范围的知识,如果你不想学可以跳到算法2。
现在有这样一个问题:给定一个非负整数序列 d 1 , ⋯ , d n d_1,\cdots,d_n d1,⋯,dn,现在有 q q q次询问,第 i i i次询问是一个区间 [ ℓ i , r i ] [\ell_i,r_i] [ℓi,ri],请你快速回答 gcd ( d ℓ i , d ℓ i + 1 , ⋯ , d r i ) \gcd\left(d_{\ell_i},d_{\ell_i+1},\cdots,d_{r_i}\right) gcd(dℓi,dℓi+1,⋯,dri),即这个区间上的 gcd \gcd gcd。 n ⩽ 1 0 5 , q ⩽ 1 0 5 n\leqslant 10^5,q\leqslant 10^5 n⩽105,q⩽105。
显然你不能花 O ( n log A ) O(n\log A) O(nlogA)的时间去遍历区间来回答单次询问,我们得事先预处理一些答案。如果开一个二维数组 f ( ℓ , r ) = gcd ( d ℓ , d ℓ + 1 , ⋯ , d r ) f(\ell,r)=\gcd(d_\ell,d_{\ell+1},\cdots,d_r) f(ℓ,r)=gcd(dℓ,dℓ+1,⋯,dr),在回答所有询问之前就把这些答案都预处理出来,那我们就能做到 O ( 1 ) O(1) O(1)回答每次询问了。但这样肯定不行,因为这个二维数组的空间复杂度达到了 O ( n 2 ) O(n^2) O(n2),并且预处理这个二维数组的时间复杂度也达到了 O ( n 2 log A ) O(n^2\log A) O(n2logA)。这个预处理之所以失败,是因为我们把所有的工作量都压在了预处理阶段;而如果每次询问都花 O ( n log A ) O(n\log A) O(nlogA)的时间遍历区间,相当于把所有的工作量都放在了询问阶段。这是两个极端。
那么折中的思想就是:预处理一部分答案,想办法在查询的时候利用这一部分答案来获得想要的答案。ST表(稀疏表,Sparse Table)正是基于这一思想:我们令
f
(
i
,
j
)
f(i,j)
f(i,j)表示左端点为
i
i
i,长度为
2
j
2^j
2j的区间上的
gcd
\gcd
gcd。由于区间长度是
2
j
⩽
n
2^j\leqslant n
2j⩽n,所以这个二维数组的空间复杂度是
O
(
n
log
n
)
O(n\log n)
O(nlogn),可以接受。那怎么计算呢?我们把区间
[
i
,
i
+
2
j
−
1
]
\left[i,i+2^j-1\right]
[i,i+2j−1]切成
[
i
,
i
+
2
j
−
1
−
1
]
\left[i,i+2^{j-1}-1\right]
[i,i+2j−1−1]和
[
i
+
2
j
−
1
,
i
+
2
j
−
1
]
\left[i+2^{j-1},i+2^j-1\right]
[i+2j−1,i+2j−1]两段长度均为
2
j
−
1
2^{j-1}
2j−1的区间,那么
[
i
,
i
+
2
j
−
1
]
\left[i,i+2^j-1\right]
[i,i+2j−1]的
gcd
\gcd
gcd就是这两个子区间的
gcd
\gcd
gcd的
gcd
\gcd
gcd,这其实可以看做一种动态规划,其转移方程是
f
(
i
,
j
)
=
gcd
(
f
(
i
,
j
−
1
)
,
f
(
i
+
2
j
−
1
,
j
−
1
)
)
.
f(i,j)=\gcd\left(f(i,j-1),f\left(i+2^{j-1},j-1\right)\right).
f(i,j)=gcd(f(i,j−1),f(i+2j−1,j−1)).
所以我们可以花
O
(
log
A
)
O(\log A)
O(logA)的时间由
f
(
i
,
j
−
1
)
f(i,j-1)
f(i,j−1)和
f
(
i
+
2
j
−
1
,
j
−
1
)
f\left(i+2^{j-1},j-1\right)
f(i+2j−1,j−1)递推出
f
(
i
,
j
)
f(i,j)
f(i,j)。注意,计算
2
j
2^j
2j是
O
(
1
)
O(1)
O(1)的,只要用位运算1 << j
即可。还要注意计算的顺序,你必须保证在计算
f
(
i
,
j
)
f(i,j)
f(i,j)的时候
f
(
∗
,
j
−
1
)
f(*,j-1)
f(∗,j−1)已经算过,下面是一个典型的错误写法:
for (int i = 1; i <= n; ++i)
for (int j = 1; i + (1 << j) - 1 <= n; ++j)
f[i][j] = gcd(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
你必须这样写:
for (int j = 1; (1 << j) <= n; ++j)
for (int i = 1; i + (1 << j) - 1 <= n; ++j)
f[i][j] = gcd(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
并且在此之前先计算base case
for (int i = 1; i <= n; ++i)
f[i][0] = d[i];
这样我们就在
O
(
n
log
n
log
A
)
O(n\log n\log A)
O(nlognlogA)的时间内预处理出了这张ST表。现在考虑如何处理一个询问
[
ℓ
,
r
]
[\ell,r]
[ℓ,r]。我们找到最大的
k
k
k,满足
2
k
⩽
r
−
ℓ
+
1
2^k\leqslant r-\ell+1
2k⩽r−ℓ+1,那么
[
ℓ
,
ℓ
+
2
k
−
1
]
⊆
[
ℓ
,
r
]
,
[
r
−
2
k
+
1
,
r
]
⊆
[
ℓ
,
r
]
\left[\ell,\ell+2^k-1\right]\subseteq[\ell,r],\left[r-2^k+1,r\right]\subseteq[\ell,r]
[ℓ,ℓ+2k−1]⊆[ℓ,r],[r−2k+1,r]⊆[ℓ,r],并且
[
ℓ
,
ℓ
+
2
k
−
1
]
∪
[
r
−
2
k
+
1
,
r
]
=
[
ℓ
,
r
]
\left[\ell,\ell+2^k-1\right]\cup\left[r-2^k+1,r\right]=[\ell,r]
[ℓ,ℓ+2k−1]∪[r−2k+1,r]=[ℓ,r],所以区间
[
ℓ
,
r
]
[\ell,r]
[ℓ,r]上的
gcd
\gcd
gcd就是
gcd
(
f
(
ℓ
,
k
)
,
f
(
r
−
2
k
+
1
,
k
)
)
.
\gcd\left(f(\ell,k),f(r-2^k+1,k)\right).
gcd(f(ℓ,k),f(r−2k+1,k)).
这里的
k
k
k显然等于
⌊
log
2
(
r
−
ℓ
+
1
)
⌋
\left\lfloor\log_2(r-\ell+1)\right\rfloor
⌊log2(r−ℓ+1)⌋。这个计算可以调用cmath
里的相关函数来完成,但我们也可以预处理这些数:设
g
(
i
)
=
⌊
log
2
i
⌋
g(i)=\left\lfloor\log_2i\right\rfloor
g(i)=⌊log2i⌋,则
g
(
i
)
=
g
(
⌊
i
/
2
⌋
)
+
1
g(i)=g\left(\lfloor i/2\rfloor\right)+1
g(i)=g(⌊i/2⌋)+1,可以
O
(
n
)
O(n)
O(n)预处理出来。于是我们做到了
O
(
log
A
)
O(\log A)
O(logA)回答单次询问。
inline int query(int l, int r) {
int k = g[r - l + 1];
return gcd(f[l][k], f[r - (1 << k) + 1][k]);
}
有了这个核武器,前面的算法就可以实现了:设置一个左端点 ℓ \ell ℓ和它对应的最远右端点 r r r,每次将左端点 ℓ \ell ℓ右移的时候,尝试右移右端点 r r r直到 gcd ( d ℓ , ⋯ , d r ) ≠ 1 \gcd(d_\ell,\cdots,d_r)\neq1 gcd(dℓ,⋯,dr)=1而 gcd ( d ℓ , ⋯ , d r + 1 ) = 1 \gcd(d_\ell,\cdots,d_{r+1})=1 gcd(dℓ,⋯,dr+1)=1,这个过程中每次求 gcd \gcd gcd都只要 O ( log A ) O(\log A) O(logA)的时间。由于 ℓ \ell ℓ和 r r r都是单向移动的,每个位置都至多被 ℓ \ell ℓ经过一次、被 r r r经过一次,所以时间复杂度 O ( n log A ) O(n\log A) O(nlogA)。再加上预处理ST表的复杂度,总的复杂度 O ( n log n log A + n log A ) = O ( n log n log A ) O(n\log n\log A+n\log A)=O(n\log n\log A) O(nlognlogA+nlogA)=O(nlognlogA)。实现的时候,移动左右端点的部分要格外小心,细节容易写错。另外, d d d上的不互素区间长度 + 1 +1 +1才是 s s s上的和谐区间长度。
#include <cstdio>
#include <cmath>
constexpr int maxn = 2e3 + 7;
int s[maxn], d[maxn], n, T;
int f[maxn][1 << 14], g[maxn];
int gcd(int a, int b) {
return b ? gcd(b, a % b) : a;
}
inline void init() {
for (int i = 2; i <= n; ++i)
g[i] = g[i / 2] + 1;
for (int i = 1; i <= n; ++i)
f[i][0] = d[i];
for (int j = 1; 1 << j <= n; ++j)
for (int i = 1; i + (1 << j) - 1 <= n; ++i)
f[i][j] = gcd(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
inline int query(int l, int r) {
int k = g[r - l + 1];
return gcd(f[l][k], f[r - (1 << k) + 1][k]);
}
int main() {
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", s + i);
d[i] = abs(s[i] - s[i - 1]);
}
init();
int ans = 1;
for (int l = 2, r = 1; l <= n; ++l) {
if (r < l - 1)
r = l - 1;
while (r < n && query(l, r + 1) != 1)
++r;
if (r - l + 2 > ans)
ans = r - l + 2;
}
printf("%d\n", ans);
}
return 0;
}
此外,这个区间查询问题还可以用线段树实现,可以做到 O ( n log A ) O(n\log A) O(nlogA)预处理, O ( log n log A ) O(\log n\log A) O(lognlogA)回答单次询问。有兴趣的同学可以自行查找资料学习。线段树是一个非常强大的数据结构,我会在后面的题目中再次提到。
算法2:分治
考虑求出 [ ℓ , r ] [\ell,r] [ℓ,r]上的最长不互素区间,老规矩设中点 m = ⌊ ℓ + r 2 ⌋ m=\left\lfloor\frac{\ell+r}2\right\rfloor m=⌊2ℓ+r⌋,将区间分成 [ ℓ , m ] [\ell,m] [ℓ,m]和 [ m + 1 , r ] [m+1,r] [m+1,r]两部分。不互素的区间要么是 [ ℓ , m ] [\ell,m] [ℓ,m]的子区间,要么是 [ m + 1 , r ] [m+1,r] [m+1,r]的子区间,要么左端点在 [ ℓ , m ] [\ell,m] [ℓ,m]中而右端点在 [ m + 1 , r ] [m+1,r] [m+1,r]中。对于前两种情况,直接递归调用,我们主要考虑第三种情况。
仍然取一个左端点
i
∈
[
ℓ
,
m
]
i\in[\ell,m]
i∈[ℓ,m],考虑找到
[
m
+
1
,
r
]
[m+1,r]
[m+1,r]中最远的右端点
j
j
j,使得区间
[
i
,
j
]
[i,j]
[i,j]上的元素不互素而区间
[
i
,
j
+
1
]
[i,j+1]
[i,j+1]上的元素互素。假如我们已经求出了这个
j
=
j
(
i
)
j=j(i)
j=j(i),当左端点
i
i
i变成
i
+
1
i+1
i+1时,能否快速求出
j
(
i
+
1
)
j(i+1)
j(i+1)呢?基于之前的讨论,
j
(
i
+
1
)
⩾
j
(
i
)
j(i+1)\geqslant j(i)
j(i+1)⩾j(i)仍然是成立的,我们每次只要将右端点
j
j
j继续右移即可,不需要从
i
+
1
i+1
i+1开始找。现在我们仍然面对快速查询
gcd
\gcd
gcd的问题,但是我们发现要查询
gcd
\gcd
gcd的区间
[
i
,
j
]
[i,j]
[i,j]一定满足
i
⩽
m
,
m
+
1
⩽
j
i\leqslant m,m+1\leqslant j
i⩽m,m+1⩽j,这个区间其实是
[
ℓ
,
m
]
[\ell,m]
[ℓ,m]的一段后缀和
[
m
+
1
,
r
]
[m+1,r]
[m+1,r]的一段前缀。如果设
p
(
i
)
=
gcd
(
d
i
,
d
i
+
1
,
⋯
,
d
m
)
,
p(i)=\gcd(d_i,d_{i+1},\cdots,d_m),
p(i)=gcd(di,di+1,⋯,dm),
q ( j ) = gcd ( d m + 1 , d m + 2 , ⋯ , d j ) , q(j)=\gcd(d_{m+1},d_{m+2},\cdots,d_j), q(j)=gcd(dm+1,dm+2,⋯,dj),
则区间
[
i
,
j
]
[i,j]
[i,j]上的
gcd
\gcd
gcd就是
gcd
(
p
(
i
)
,
q
(
j
)
)
\gcd(p(i),q(j))
gcd(p(i),q(j))。而根据式
(
2.1
)
(2.1)
(2.1)的思想,
p
(
i
)
p(i)
p(i)和
q
(
j
)
q(j)
q(j)又满足如下递推关系
p
(
i
)
=
gcd
(
p
(
i
+
1
)
,
d
i
)
,
p(i)=\gcd(p(i+1),d_i),
p(i)=gcd(p(i+1),di),
q ( j ) = gcd ( q ( j − 1 ) , d j ) . q(j)=\gcd(q(j-1),d_j). q(j)=gcd(q(j−1),dj).
因此,假设 r − ℓ + 1 = n ′ r-\ell+1=n^\prime r−ℓ+1=n′,我们可以花 O ( n ′ log A ) O(n^\prime\log A) O(n′logA)的时间先预处理所有 p ( i ) , q ( j ) p(i),q(j) p(i),q(j),然后设置一个左端点 i i i和右端点 j j j,在右移左端点 i i i的时候不断尝试右移右端点,这个过程中需要不断检查 gcd \gcd gcd是否为 1 1 1,而每次检查都只需要 O ( log A ) O(\log A) O(logA)的时间。由于 i i i和 j j j都是单向移动的,每个位置都至多被 i i i和 j j j各经过一次,所以总的时间复杂度是 O ( n ′ log A ) O(n^\prime\log A) O(n′logA)。于是整个分治算法的时间复杂度为 O ( n log n log A ) O(n\log n\log A) O(nlognlogA)。
实现的时候要极其小心,这个移动双端点的过程有一些细节容易写错。
#include <cstdio>
#include <cmath>
constexpr int maxn = 2e3 + 7;
int s[maxn], d[maxn], n, T;
int p[maxn], q[maxn];
int ans;
int gcd(int a, int b) {
return b ? gcd(b, a % b) : a;
}
void solve(int l, int r) {
if (l > r)
return;
if (r - l + 1 < ans)
return;
if (l == r) {
if (d[l] != 1 && 2 > ans)
ans = 2;
return;
}
int mid = (l + r) >> 1;
solve(l, mid);
solve(mid + 1, r);
if (gcd(d[mid], d[mid + 1]) == 1)
return;
p[mid] = d[mid];
for (int i = mid - 1; i >= l; --i)
p[i] = gcd(p[i + 1], d[i]);
q[mid + 1] = d[mid + 1];
for (int i = mid + 2; i <= r; ++i)
q[i] = gcd(q[i - 1], d[i]);
for (int i = l, j = mid; i <= mid; ++i) {
while (j < r && gcd(p[i], q[j + 1]) != 1)
++j;
if (j > mid && j - i + 2 > ans)
ans = j - i + 2;
}
}
int main() {
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", s + i);
d[i] = abs(s[i] - s[i - 1]);
}
ans = 1;
solve(2, n);
printf("%d\n", ans);
}
return 0;
}
算法3:二分答案
前两种算法都是基于所谓的“尺取法”,即使用两个单调移动的指针来解决问题。但如果你学会了算法1中快速查询区间 gcd \gcd gcd的方式,我们就能得到下面这种写起来细节更少,不容易写错的算法。我们发现,如果有一个长度为 5 5 5的不互素区间,那么一定存在一个长度为 3 3 3的不互素区间。也就是说,这里的答案蕴含了某种单调性在里面。
二分是我们都玩过的一个游戏:Alice心里想了一个 [ 1 , 100 ] [1,100] [1,100]之间的整数,Bob来猜,每次Alice可以告诉Bob他猜大了还是猜小了。Bob每次都猜当前可行区间的中点,就保证每次都将可行区间缩短至少一半,于是至多 7 = ⌈ log 2 100 ⌉ 7=\lceil\log_2100\rceil 7=⌈log2100⌉次就一定能猜出答案。
现在我们用这种二分的方式来猜最长的不互素区间长度是多少。比方说我猜 m m m,那接下来就要检查是否存在一个长度为 m m m的不互素区间。检查的方式非常简单,我们枚举区间的左端点 i i i,看看 [ i , i + m − 1 ] [i,i+m-1] [i,i+m−1]这段区间上的 gcd \gcd gcd是否等于 1 1 1就可以了。使用ST表可以做到 O ( log A ) O(\log A) O(logA)查询,那么检查的复杂度是 O ( n log A ) O(n\log A) O(nlogA),二分的复杂度 O ( log n ) O(\log n) O(logn),总复杂度 O ( n log n log A ) O(n\log n\log A) O(nlognlogA),ST表预处理的复杂度也是 O ( n log n log A ) O(n\log n\log A) O(nlognlogA),所以整个算法的复杂度 O ( n log n log A ) O(n\log n\log A) O(nlognlogA)。
#include <cstdio>
#include <cmath>
constexpr int maxn = 2e3 + 7;
int s[maxn], d[maxn], n, T;
int f[maxn][1 << 14], g[maxn];
int gcd(int a, int b) {
return b ? gcd(b, a % b) : a;
}
inline void init() {
for (int i = 2; i <= n; ++i)
g[i] = g[i / 2] + 1;
for (int i = 1; i <= n; ++i)
f[i][0] = d[i];
for (int j = 1; 1 << j <= n; ++j)
for (int i = 1; i + (1 << j) - 1 <= n; ++i)
f[i][j] = gcd(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
inline int query(int l, int r) {
int k = g[r - l + 1];
return gcd(f[l][k], f[r - (1 << k) + 1][k]);
}
inline bool check(int mid) {
for (int i = 2; i + mid - 1 <= n; ++i)
if (query(i, i + mid - 1) != 1)
return true;
return false;
}
int main() {
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", s + i);
d[i] = abs(s[i] - s[i - 1]);
}
init();
// 二分答案的写法有很多种,有的人没有这个 ans 变量,有的人判断条件是 left < right,
// 有的人修改的时候是 left = mid 或 right = mid,等等。
// 最重要的是选定一种写法然后一直坚持,这样不容易写错。下面是我自己一直使用的写法。
int left = 1, right = n - 1, ans = 0;
while (left <= right) {
int mid = (left + right) / 2;
if (check(mid)) {
ans = mid;
left = mid + 1;
} else
right = mid - 1;
}
printf("%d\n", ans + 1);
}
return 0;
}
此外,我们还可以枚举左端点 i i i,这时它所对应的最远右端点 j j j也是单调的,也可以二分出来。时间复杂度是一样的。
2003
题意:有 c c c只鸽子,每只鸽子有一个饥饿程度 a i a_i ai和需要的食物数量 b i b_i bi。现在要选出 n n n只来喂,但你拥有的食物只有 f f f,也就是说所有选中的鸽子的 b b b值之和不能超过 f f f。在这种情况下最大化选中的鸽子的 a a a值的中位数。保证 n n n是奇数。无解输出 − 1 -1 −1。 n ⩽ 1 0 5 , n ⩽ c ⩽ 2 × 1 0 5 , 0 ⩽ f ⩽ 2 × 1 0 9 , 0 ⩽ a i ⩽ 2 × 1 0 9 , 0 ⩽ b i ⩽ 1 0 5 n\leqslant 10^5,n\leqslant c\leqslant 2\times 10^5,0\leqslant f\leqslant 2\times 10^9,0\leqslant a_i\leqslant 2\times 10^9,0\leqslant b_i\leqslant 10^5 n⩽105,n⩽c⩽2×105,0⩽f⩽2×109,0⩽ai⩽2×109,0⩽bi⩽105。
我特意把全部的数据范围放上来了,因为稍加计算会发现,此题的程序中某些变量的值可能超过int
所能表示的最大值。在OI赛事中不乏这样的例子,一些正确的程序因为没有开long long
而丢掉了部分分数,非常可惜,大家一定要小心。本题中会涉及到将若干个
b
b
b值相加,最坏情况下则是将
1
0
5
10^5
105个
b
b
b值相加,会达到
1
0
10
10^{10}
1010,而int
的最大值是
2
31
−
1
=
2147483647
2^{31}-1=2147483647
231−1=2147483647,所以某些变量需要用long long
。但我们不建议非常简单粗暴地将所有整型都开成long long
,更不建议效仿某些人#define int long long
的做法,因为那会额外增加很多空间开销,也会降低程序可读性。
我们考虑将所有鸽子按 a a a值从大到小排序,然后枚举中位数来自哪只鸽子。假设中位数是 a i a_i ai,意味着我们要在 [ 1 , i − 1 ] [1,i-1] [1,i−1]和 [ i + 1 , c ] [i+1,c] [i+1,c]中分别选取 k = ( n − 1 ) / 2 k=(n-1)/2 k=(n−1)/2只鸽子,要求选出的所有鸽子的 b b b值之和不超过 f f f,显然这些鸽子的 b b b值应该越小越好。因此我们需要在 [ 1 , i − 1 ] [1,i-1] [1,i−1]这些鸽子中找出前 k k k小的 b b b值之和,在 [ i + 1 , c ] [i+1,c] [i+1,c]这些鸽子中找出前 k k k小的 b b b值之和,然后跟 b i b_i bi加在一起判断是否小于等于 f f f。
现在问题就在于,如何快速查询这两段区间中前 k k k小的 b b b值。你当然可以直接用主席树、平衡树秒杀,正所谓“智商不够,数据结构来凑”,不过我们还是说一说课程范围以内的做法。
记 g ( i ) g(i) g(i)表示区间 [ 1 , i − 1 ] [1,i-1] [1,i−1]上的鸽子中最小的 k k k个 b b b值之和,这里 i ⩾ k + 1 i\geqslant k+1 i⩾k+1。维护一个始终都恰好有 k k k个元素的最大堆,一开始先把 b 1 , ⋯ , b k b_1,\cdots,b_k b1,⋯,bk加入这个堆,同时记录当前堆中的元素之和 s u m sum sum。在 i i i从 k + 1 k+1 k+1起逐渐递增的过程中,每次 g ( i ) g(i) g(i)就等于当前的 s u m sum sum,然后判断 b i b_i bi是否应该加入堆。注意,这时的堆中的元素应为 b 1 , ⋯ , b i − 1 b_1,\cdots,b_{i-1} b1,⋯,bi−1中最小的 k k k个,而堆顶正好是这 k k k个中的最大值,也就是第 k k k小值。如果 b i b_i bi小于堆顶,就令 s u m sum sum减去堆顶,把堆顶弹掉,把 b i b_i bi加入堆,并令 s u m sum sum增加 b i b_i bi;否则就不将 b i b_i bi加入堆。这样一来,堆中的元素就变成了 b 1 , ⋯ , b i b_1,\cdots,b_i b1,⋯,bi中最小的 k k k个,下一轮迭代就可以用了。记 h ( i ) h(i) h(i)表示区间 [ i + 1 , c ] [i+1,c] [i+1,c]上的鸽子中最小的 k k k个 b b b值之和,其中 i ⩽ c − k i\leqslant c-k i⩽c−k。将以上过程倒着做一遍,就求出了所有的 h ( i ) h(i) h(i)。
在实现中,因为我们最终枚举中位数的时候 i i i是从小到大枚举的,所以 g ( i ) g(i) g(i)可以不用开数组预处理,一边枚举中位数一边算即可。
#include <cstdio>
#include <functional>
#include <iterator>
namespace gkxx {
template <typename ForwardIterator, typename Less>
void __inplace_merge(
ForwardIterator begin, ForwardIterator mid, ForwardIterator end,
typename std::iterator_traits<ForwardIterator>::difference_type dist,
Less less) {
ForwardIterator i = begin, j = mid;
using value_type = typename std::iterator_traits<ForwardIterator>::value_type;
value_type *tmp = new value_type[dist](), *k = tmp;
while (i != mid && j != end) {
if (less(*i, *j))
*k++ = *i++;
else
*k++ = *j++;
}
while (i != mid)
*k++ = *i++;
while (j != end)
*k++ = *j++;
k = tmp;
while (begin != end)
*begin++ = *k++;
delete[] tmp;
}
template <typename ForwardIterator, typename Less>
void merge_sort(ForwardIterator begin, ForwardIterator end, Less less) {
auto dist = std::distance(begin, end);
if (dist <= 1)
return;
ForwardIterator mid = std::next(begin, dist / 2);
merge_sort(begin, mid, less);
merge_sort(mid, end, less);
__inplace_merge(begin, mid, end, dist, less);
}
template <typename ForwardIterator>
inline void merge_sort(ForwardIterator begin, ForwardIterator end) {
merge_sort(begin, end, std::less<void>());
}
} // namespace gkxx
constexpr int maxn = 2e5 + 7;
struct Pigeon {
int a, b;
};
Pigeon pg[maxn];
int f, c, n;
int heap[maxn], size;
void push(int x) {
int i = ++size;
while (i > 1) {
int fa = i / 2;
if (heap[fa] < x)
heap[i] = heap[fa];
else
break;
i = fa;
}
heap[i] = x;
}
void pop() {
int x = heap[size--];
int i = 1;
while (i * 2 <= size) {
int ch = i * 2;
if (ch + 1 <= size && heap[ch + 1] > heap[ch])
++ch;
if (heap[ch] > x) {
heap[i] = heap[ch];
i = ch;
} else
break;
}
heap[i] = x;
}
long long h[maxn];
int main() {
scanf("%d%d%d", &n, &c, &f);
for (int i = 1; i <= c; ++i)
scanf("%d%d", &pg[i].a, &pg[i].b);
gkxx::merge_sort(
pg + 1, pg + c + 1,
[](const Pigeon &lhs, const Pigeon &rhs) -> bool { return lhs.a > rhs.a; });
long long sum = 0;
int k = (n - 1) / 2;
for (int i = c; i >= c - k + 1; --i) {
sum += pg[i].b;
push(pg[i].b);
}
for (int i = c - k; i >= k + 1; --i) {
h[i] = sum;
if (pg[i].b < heap[1]) {
sum -= heap[1];
pop();
sum += pg[i].b;
push(pg[i].b);
}
}
size = 0;
sum = 0;
for (int i = 1; i <= k; ++i) {
sum += pg[i].b;
push(pg[i].b);
}
for (int i = k + 1; i <= c - k; ++i) {
if (h[i] + sum + pg[i].b <= f) {
printf("%d\n", pg[i].a);
return 0;
}
if (pg[i].b < heap[1]) {
sum -= heap[1];
pop();
sum += pg[i].b;
push(pg[i].b);
}
}
puts("-1");
return 0;
}
关于这题还有一个有意思的事:2018年我在省常中集训的时候做到了这道题,但当时题目并不保证 n n n是奇数,并且定义当 n n n是偶数时中位数为中间两个数的平均数。事实上偶数的情况要比奇数复杂一些,当时一些人就想:要不我就写个奇数的情况吧,应该能混到一部分分数。最后只写了奇数的人得了零分,因为测试数据全是偶数…