CS101 2021Fall PA1,2 题解

CS101 2021Fall PA1,2 题解

关于PA (Programming Assignment)

鉴于PA的题目都没有在课上或讨论课上讲解过,piazza上也没有发过题解,GKxx打算简单整理一下这些题目的题解。对于其中的部分题目,我会尽可能多写几个能想到的算法,并提供一部分代码。

水平有限,这份题解仅供参考,如有错误欢迎指出。如果对某些题目有同学有更好的做法,请直接发,不要联系我,我懒得再看了。

建议以后的CS101可以在每次PA结束以后开放一个题解展示平台,并安排一两名助教审核提交的题解,可以参考洛谷的题解审核标准。大多数OIer都有写题解的习惯,我相信只要开放这个平台就会有人主动交题解的,都不用助教或老师自己干。

一些约定

若无特殊说明,所有下标均从 1 1 1开始;但如果是标准库容器或std::string的下标,则从 0 0 0开始。

代码是C++14的。

PA1

1001

题意:给定二维坐标平面上的 N N N个点和两个圆心 O 1 , O 2 O_1,O_2 O1,O2,需要求出两个半径 R 1 , R 2 R_1,R_2 R1,R2,使得每一个点都至少被一个圆覆盖到,并最小化 R 1 2 + R 2 2 R_1^2+R_2^2 R12+R22 N ⩽ 80000 N\leqslant 80000 N80000,坐标数值的绝对值在 1000 1000 1000以内。

每个点都会被这两个圆中的某一个覆盖(同时被两个圆覆盖的情况,你可以认为它只归其中一个圆管,另一个圆“碰巧”覆盖到它而已),所以圆的半径就由它需要覆盖的那些点中离圆心最远的那个决定。我们考虑将所有点按照与 O 1 O_1 O1的距离从大到小排序,枚举圆 O 1 O_1 O1所覆盖的最远的那个点是谁。假设圆 O 1 O_1 O1覆盖的最远的点是 P i P_i Pi,这意味着 P i , ⋯   , P n P_i,\cdots,P_n Pi,,Pn都已经被圆 O 1 O_1 O1覆盖了,而圆 O 2 O_2 O2的任务是覆盖 P 1 , ⋯   , P i − 1 P_1,\cdots,P_{i-1} P1,,Pi1,我们花 O ( n ) O(n) O(n)的时间求出这些点中与 O 2 O_2 O2的距离最大值,这样我们就得到了一个 O ( n 2 ) O(n^2) O(n2)的做法。

如何优化?注意,我们每次都需要求 f ( i ) = max ⁡ { d ( j ) , 1 ⩽ j < i } f(i)=\max\{d(j),1\leqslant j<i\} f(i)=max{d(j),1j<i},其中 d ( j ) = dist ⁡ ( P j , O 2 ) d(j)=\operatorname{dist}(P_j,O_2) d(j)=dist(Pj,O2) P j P_j Pj O 2 O_2 O2的距离。但实际上 f ( i ) = max ⁡ { f ( i − 1 ) , d ( i ) } f(i)=\max\{f(i-1),d(i)\} f(i)=max{f(i1),d(i)},它可以根据上一次的答案 O ( 1 ) O(1) O(1)地得出,不需要花 O ( n ) O(n) O(n)的时间重新求。这个 f f f不需要开数组存,因为你每次只使用上一次的答案,反复更新一个变量就好了。

这样在排序之后时间复杂度就为 O ( n ) O(n) O(n),而排序是整个算法的时间复杂度瓶颈,用常见的归并排序或快速排序可以做到 O ( n log ⁡ n ) O(n\log n) O(nlogn)

#include <cmath>
#include <cstdio>
#include <functional> // for std::less
#include <iterator>   // for std::iterator_traits, std::distance, std::next

namespace gkxx {

template <typename ForwardIterator, typename Less>
void __inplace_merge(
    ForwardIterator begin, ForwardIterator mid, ForwardIterator end,
    typename std::iterator_traits<ForwardIterator>::difference_type dist,
    Less less) {
  ForwardIterator i = begin, j = mid;
  using value_type = typename std::iterator_traits<ForwardIterator>::value_type;
  value_type *tmp = new value_type[dist](), *k = tmp;
  while (i != mid && j != end) {
    if (less(*i, *j))
      *k++ = *i++;
    else
      *k++ = *j++;
  }
  while (i != mid)
    *k++ = *i++;
  while (j != end)
    *k++ = *j++;
  k = tmp;
  while (begin != end)
    *begin++ = *k++;
  delete[] tmp;
}

template <typename ForwardIterator, typename Less>
void merge_sort(ForwardIterator begin, ForwardIterator end, Less less) {
  auto dist = std::distance(begin, end);
  if (dist <= 1)
    return;
  ForwardIterator mid = std::next(begin, dist / 2);
  merge_sort(begin, mid, less);
  merge_sort(mid, end, less);
  __inplace_merge(begin, mid, end, dist, less);
}

template <typename ForwardIterator>
inline void merge_sort(ForwardIterator begin, ForwardIterator end) {
  merge_sort(begin, end, std::less<void>());
}

} // namespace gkxx

constexpr int maxn = 80007;

struct Point {
  int x, y;
};

Point O1, O2, p[maxn];
int n;

inline int sqr_dist(const Point &a, const Point &b) {
  int dx = a.x - b.x, dy = a.y - b.y;
  return dx * dx + dy * dy;
}

int main() {
  scanf("%d%d%d%d", &O1.x, &O1.y, &O2.x, &O2.y);
  scanf("%d", &n);
  for (int i = 1; i <= n; ++i)
    scanf("%d%d", &p[i].x, &p[i].y);
  gkxx::merge_sort(p + 1, p + n + 1,
                   [](const Point &a, const Point &b) -> bool {
                     return sqr_dist(a, O1) > sqr_dist(b, O1);
                   });
  // First consider the case where R1 = 0.
  int ans = 0;
  for (int i = 1; i <= n; ++i) {
    int sqr_r2 = sqr_dist(p[i], O2);
    if (sqr_r2 > ans)
      ans = sqr_r2;
  }
  int prev_max = 0;
  for (int i = 1; i <= n; ++i) {
    int sqr_r1 = sqr_dist(p[i], O1);
    if (sqr_r1 + prev_max < ans)
      ans = sqr_r1 + prev_max;
    int d = sqr_dist(p[i], O2);
    if (d > prev_max)
      prev_max = d;
  }
  printf("%d\n", ans);
  return 0;
}

提醒一下:写快速排序要随机选pivot,你写取三个数的中位数的那种选法也还好,但有些人每次都取最后一个数,就很容易被卡成worst case。

1002

题意:给定 n n n个字符串,求出最长公共前缀(LCP)和最长公共后缀(LCS)。“公共前缀”和“公共后缀”的长度必须至少为 2 2 2。长度相等的情况下找出现次数最多的,出现次数仍然相等的情况下找字典序最小的。 n ⩽ 50000 n\leqslant 50000 n50000,单个字符串长度不超过 10 10 10,字符集为小写字母。

字符串长度不超过 10 10 10,意味着你可以肆无忌惮地遍历字符串,拷贝字符串,而不用担心时间开销,所以搞什么Trie树搞什么Hash table,直接暴力啊。

这题做法五花八门,我说一个我自己写的。先说LCP:把所有字符串按字典序从小到大排序,则答案一定是某两个相邻的字符串的LCP。求两个字符串的LCP只要暴力枚举前缀判断是否相等即可。我们为每一对相邻的字符串 s i , s i + 1 s_i,s_{i+1} si,si+1求出LCP,记为 t i t_i ti,则我们只要找出 t 1 , ⋯   , t n − 1 t_1,\cdots,t_{n-1} t1,,tn1这些串中最长的那个,长度相等的情况下出现次数最多的那个,出现次数仍相等的情况下字典序最小的那个。

实现中,我们可以记录每一个LCP出现的次数,如果当前 s i s_i si s i + 1 s_{i+1} si+1的LCP恰好和 s i − 1 s_{i-1} si1 s i s_i si的LCP相同(长度相等就可以判断相同),就递增上一个LCP的出现次数。我们可以用一个std::pair<std::string, int>来保存一个LCP及其出现次数,那么最终只要在所有pair中找一个最优的,这跟在给定数组中找最大值没有本质区别。

对于LCS,我们可以把所有串都反过来,然后求LCP。但务必注意,比较字典序的时候要用反转之前的串比较。

1003

题意:给定 n n n个数 a 1 , ⋯   , a n a_1,\cdots,a_n a1,,an,计算有多少对 ( i , j ) (i,j) (i,j)满足 i , j ∈ [ 1 , n ] ∩ Z , i < j i,j\in[1,n]\cap\Z,i<j i,j[1,n]Z,i<j,且 ∀ k ∈ ( i , j ) ∩ Z , a k ⩽ min ⁡ { a i , a j } \forall k\in(i,j)\cap\Z,a_k\leqslant\min\{a_i,a_j\} k(i,j)Z,akmin{ai,aj}。(即给出 n n n位同学的身高,请问有多少对同学能互相看见)

算法1:单调栈

考虑枚举右边的同学 j j j,计算左边有多少个同学 i i i能和 j j j互相看见。枚举 j j j的复杂度已经 O ( n ) O(n) O(n),所以计算 i i i的数量的过程必须足够快。有一个显然与此有关的问题:假设 f ( j ) f(j) f(j)表示 j j j左边第一个比 j j j高的同学的编号,那么显然在 f ( j ) f(j) f(j)左边的同学都没法看见 j j j。你能否快速求出 f ( j ) f(j) f(j)?你不能从 j − 1 j-1 j1 1 1 1一个一个检查,因为这会使得求单个 f ( j ) f(j) f(j)的复杂度为 O ( j ) O(j) O(j),总复杂度就为 O ( n 2 ) O(n^2) O(n2),太慢。

f k ( j ) = f ( f k − 1 ( j ) ) , f 0 ( j ) = j f^k(j)=f(f^{k-1}(j)),f^0(j)=j fk(j)=f(fk1(j)),f0(j)=j。假如已知 f ( 1 ) , ⋯   , f ( j − 1 ) f(1),\cdots,f(j-1) f(1),,f(j1),我们发现 f ( j ) ∈ { j − 1 , f ( j − 1 ) , f 2 ( j − 1 ) , ⋯   } f(j)\in\left\{j-1,f(j-1),f^2(j-1),\cdots\right\} f(j){j1,f(j1),f2(j1),}(为什么?),并且 a j − 1 < a f ( j − 1 ) < a f 2 ( j − 1 ) < ⋯ a_{j-1}<a_{f(j-1)}<a_{f^2(j-1)}<\cdots aj1<af(j1)<af2(j1)<。我们可以用一个栈 s s s来存储这些位置:从栈顶到栈底依次为 j − 1 , f ( j − 1 ) , f 2 ( j − 1 ) , ⋯ j-1,f(j-1),f^2(j-1),\cdots j1,f(j1),f2(j1),,那么查询 f ( j ) f(j) f(j)就只要在 s s s中从栈顶开始逐个检查即可。下图是 j = 11 j=11 j=11时的情况,蓝色的柱子就是每个同学,红色折线连接的点就是这个栈所保存的位置编号。

当你求出了 f ( j ) f(j) f(j)之后,为了下一次求 f ( j + 1 ) f(j+1) f(j+1),你还要正确地维护这个栈,也就是说得把这个栈变成 j , f ( j ) , f 2 ( j ) , ⋯ j,f(j),f^2(j),\cdots j,f(j),f2(j),。如何维护?我们可以逐个检查栈顶元素 f k ( j − 1 ) , k = 0 , 1 , ⋯ f^k(j-1),k=0,1,\cdots fk(j1),k=0,1,,只要 a f k ( j − 1 ) ⩽ a j a_{f^k(j-1)}\leqslant a_j afk(j1)aj就把 f k ( j − 1 ) f^k(j-1) fk(j1)弹出,直到栈顶所标示的位置上的同学比 j j j高(这也恰好就是 f ( j ) f(j) f(j))就停止,然后把 j j j放到栈顶即可。在上图的例子中,我们就要把 10 10 10 8 8 8两位同学弹栈,然后将 11 11 11加入栈顶。

由于 a j < a f ( j ) < a f 2 ( j ) < ⋯ a_j<a_{f(j)}<a_{f^2(j)}<\cdots aj<af(j)<af2(j)<,这个栈被称为单调栈。以下是求 f f f数组的代码:

top = 0;
for (int j = 1; j <= n; ++j) {
  while (top > 0 && a[s[top]] <= a[j])
    --top;
  f[j] = s[top];
  s[++top] = j;
}

这时有人就要问了:每次都这样暴力地从顶向下检查,最坏情况下岂不是要检查 O ( n ) O(n) O(n)个元素,那总复杂度不就 O ( n 2 ) O(n^2) O(n2)了吗?但是请注意:你的算法的复杂度取决于执行次数最多的那条语句被执行了多少次,在这里其实也就是弹栈操作--top;,然而 1 1 1 n n n的每一个元素都只会进栈一次,也就至多出栈一次,所以弹栈不会超过 n n n次,总的复杂度就是 O ( n ) O(n) O(n),虽然它看起来是个双重循环。

如果所有的 a i a_i ai都不相等,我们惊奇地发现:那些能和 j j j互相看见的同学正是 j − 1 , f ( j − 1 ) , f 2 ( j − 1 ) , ⋯ j-1,f(j-1),f^2(j-1),\cdots j1,f(j1),f2(j1),,也就是在求 f ( j ) f(j) f(j)的时候被弹栈的那些位置,所以我们在弹栈的时候顺带统计一下即可。但是题目并不保证所有 a i a_i ai都不相等,这时会出现一个问题:假设 f k ( j − 1 ) f^k(j-1) fk(j1) j j j能看见,而在 f k + 1 ( j − 1 ) f^{k+1}(j-1) fk+1(j1) f k ( j − 1 ) f^k(j-1) fk(j1)之间存在一个和 f k ( j − 1 ) f^k(j-1) fk(j1)一样高的同学 t t t,那么 t t t不会存在于栈中,但是 t t t也能和 j j j看见。为了让这样的 t t t也存在于栈中,你得把弹栈的条件由 a f k ( j − 1 ) ⩽ a j a_{f^k(j-1)}\leqslant a_j afk(j1)aj改成 a f k ( j − 1 ) < a j a_{f^k(j-1)}<a_j afk(j1)<aj,也就是a[s[top]] < a[i]。但是在统计答案的时候,仍然是找到严格高于 j j j的同学为止,如果你仍然从栈顶开始一个一个找,就会让复杂度变成 O ( n 2 ) O(n^2) O(n2),因为最坏情况下所有同学的高度都相等,所有同学入栈了之后都不会出栈,每次查找都得一直找到第一个。(实测这种做法可以通过,说明数据比较水)

一种解决方法是使用二分查找。二分是一个极其重要的思想,但整个CS101课程中始终没有提及。二分查找的复杂度为 O ( log ⁡ n ) O(\log n) O(logn),那么总复杂度即为 O ( n log ⁡ n ) O(n\log n) O(nlogn)。查找过程中要小心 f ( j ) f(j) f(j)不存在的情况,即 j j j左边的所有同学都不比 j j j高,我的处理方式是假定 a 0 = + ∞ a_0=+\infty a0=+。代码如下:

#include <climits>
#include <cstdio>

constexpr int maxn = 5e5 + 7;
constexpr int inf = 1e9;

int a[maxn], n;
int stk[maxn], top;

int main() {
  scanf("%d", &n);
  for (int i = 1; i <= n; ++i)
    scanf("%d", a + i);
  int ans = 0;
  a[0] = inf;
  stk[++top] = 1;
  for (int i = 2; i <= n; ++i) {
    int left = 0, right = top, pos;
    while (left <= right) {
      int mid = (left + right) / 2;
      if (a[stk[mid]] <= a[i])
        right = mid - 1;
      else {
        pos = mid;
        left = mid + 1;
      }
    }
    if (!pos)
      ++pos;
    int delta = top - pos + 1;
    ans += delta;
    while (a[stk[top]] < a[i] && top)
      --top;
    stk[++top] = i;
  }
  printf("%d\n", ans);
  return 0;
}

第二种解决方法是考虑给栈里高度相等的同学记录重复次数,在入栈的时候如果发现与栈顶一样高就不要入栈,而是递增其重数。但实际上你需要记录的是重数的前缀和,因为查询答案时需要计算几个同学的重数之和,可以用两个前缀和相减得出。这样可以做到 O ( n ) O(n) O(n),代码如下:

#include <climits>
#include <cstdio>

constexpr int maxn = 5e5 + 7;
constexpr int inf = 1e9;

int a[maxn], n;
int stk[maxn], top;
int sum[maxn];

int main() {
  scanf("%d", &n);
  for (int i = 1; i <= n; ++i)
    scanf("%d", a + i);
  int ans = 0;
  a[0] = inf;
  stk[++top] = 1;
  sum[1] = 1;
  for (int i = 2; i <= n; ++i) {
    int pos = top;
    while (a[stk[pos]] <= a[i] && pos)
      --pos;
    int delta = pos ? sum[top] - sum[pos] + 1 : sum[top];
    ans += delta;
    while (a[stk[top]] < a[i] && top)
      --top;
    if (a[stk[top]] == a[i])
      ++sum[top];
    else {
      stk[++top] = i;
      sum[top] = sum[top - 1] + 1;
    }
  }
  printf("%d\n", ans);
  return 0;
}
算法2:分治

假设我们要统计从第 ℓ \ell 个同学到第 r r r个同学这一段上能互相看见的同学的数量,我们将区间 [ l , r ] [l,r] [l,r]拆分成 [ l , m ] [l,m] [l,m] [ m + 1 , r ] [m+1,r] [m+1,r],其中 m = ⌊ ℓ + r 2 ⌋ m=\left\lfloor\frac{\ell+r}{2}\right\rfloor m=2+r。能够互相看见的两人 ( i , j ) (i,j) (i,j)要么都来自 [ l , m ] [l,m] [l,m],要么都来自 [ m + 1 , r ] [m+1,r] [m+1,r],要么 i ∈ [ l , m ] , j ∈ [ m + 1 , r ] i\in[l,m],j\in[m+1,r] i[l,m],j[m+1,r]。对于前两种情况只要递归调用即可。对于第三种情况,使用两个指针 i , j i,j i,j,一开始分别指向 m m m m + 1 m+1 m+1(这里的“指针”仅仅是一种形象的表述,在实现中就是两个下标,并不是C/C++中的那种指针)。

  • 假如 a i < a j a_i<a_j ai<aj,并且 i i i j j j能互相看见,考虑将指针 i i i左移,找到 i i i左边第一个比 i i i高的同学 k k k,那么在 ( k , i ] (k,i] (k,i]中与 i i i一样高的同学都能与 j j j看见,将它们计入答案,并将 i i i赋值为 k k k
  • 假如 a i > a j a_i>a_j ai>aj并且 i i i j j j能互相看见,情况是类似的。
  • 如果 a i = a j a_i=a_j ai=aj并且 i i i j j j能互相看见,我们需要同时左移 i i i和右移 j j j,找到 i i i左边第一个比 i i i高的同学 k 1 k_1 k1,以及 j j j右边第一个比 j j j高的同学 k 2 k_2 k2,那么在 ( k 1 , i ] (k_1,i] (k1,i]中和在 ( j , k 2 ] (j,k_2] (j,k2]中那些与 i , j i,j i,j一样高的同学都可以互相看见,所以我们分别统计出这两个区间中身高等于 a i a_i ai的同学数量 n 1 , n 2 n_1,n_2 n1,n2,那么就有 n 1 × n 2 n_1\times n_2 n1×n2对同学能互相看见,要计入答案。但同时, k 2 k_2 k2也能和左边的这 n 1 n_1 n1位同学看见, k 1 k_1 k1也能和右边的这 n 2 n_2 n2位同学看见,因此还要再加上 n 1 + n 2 n_1+n_2 n1+n2。注意, k 1 k_1 k1 k 2 k_2 k2也能互相看见,但这会在下一轮中被统计,不要重复统计。

因此统计 i ∈ [ l , m ] , j ∈ [ m + 1 , r ] i\in[l,m],j\in[m+1,r] i[l,m],j[m+1,r]的情况,只需将两个指针分别从中点移到两端,时间复杂度 O ( n ) O(n) O(n),总的时间复杂度就是 O ( n log ⁡ n ) O(n\log n) O(nlogn)。在实现中,我在左端点左边和右端点右边各设了一个巨人(身高为无穷大的同学),这样可以一定程度上减少边界考虑的负担。

#include <cstdio>

constexpr int maxn = 5e5 + 7;
constexpr int inf = 1e9;

int a[maxn], n;

int solve(int l, int r) {
  if (r - l + 1 <= 1)
    return 0;
  int mid = (l + r) / 2;
  int ans = 0;
  int i = mid, j = mid + 1;
  int tmpleft = a[l - 1], tmpright = a[r + 1];
  a[l - 1] = a[r + 1] = inf;
  while (i >= l && j <= r) {
    if (a[i] < a[j]) {
      int k = i;
      while (a[k] <= a[i]) {
        if (a[k] == a[i])
          ++ans;
        --k;
      }
      i = k;
    } else if (a[i] > a[j]) {
      int k = j;
      while (a[k] <= a[j]) {
        if (a[k] == a[j])
          ++ans;
        ++k;
      }
      j = k;
    } else {
      int n1 = 0, n2 = 0, k = i;
      while (a[k] <= a[i]) {
        if (a[k] == a[i])
          ++n1;
        --k;
      }
      i = k;
      k = j;
      while (a[k] <= a[j]) {
        if (a[k] == a[j])
          ++n2;
        ++k;
      }
      j = k;
      ans += n1 * n2;
      if (i >= l)
        ans += n2;
      if (j <= r)
        ans += n1;
    }
  }

  a[l - 1] = tmpleft;
  a[r + 1] = tmpright;
  return ans + solve(l, mid) + solve(mid + 1, r);
}

int main() {
  scanf("%d", &n);
  for (int i = 1; i <= n; ++i)
    scanf("%d", a + i);
  printf("%d\n", solve(1, n));
  return 0;
}

PA2

2001

题意:给一个 n n n个结点的树,结点 i i i有一个权值 a i ∈ { 0 , 1 , 2 } a_i\in\{0,1,2\} ai{0,1,2}。如果一棵树中的所有非零权值都相等(都为 1 1 1或都为 2 2 2),则称这棵树为和平的。一条边被称为和平的,当且仅当将这条边删除之后得到的两棵树都是和平的。问和平边的数量。

首先说一下怎么存一棵树或者一张图。在绝大多数图算法或树算法中,最基本的操作都是遍历某个结点的所有neighbors,所以你的存储方式必须支持高效地访问每个点的neighbors。像邻接矩阵就不是一个好的存图方式,因为它必须花 O ( n ) O(n) O(n)的时间枚举点,再判断它是否与当前点相邻。通常的准则是:使用Floyd-Warshall算法时用邻接矩阵,使用Kruskal算法时就直接把边放进一个数组(因为需要给边排序),其它情况几乎总是使用邻接表。

邻接表怎么写?如果允许使用std::vector,只要给每个结点开一个std::vector存储它的邻居。下面给出存储、加边和dfs的代码,注意这里的加边是加有向边,无向边可以被视为一正一反两条有向边。

std::vector<int> G[maxn];

inline void add_edge(int x, int y) {
  G[x].push_back(y);
}

bool vis[maxn];

void dfs(int x) {
  printf("visit %d\n", x);
  vis[x] = true;
  for (auto v : G[x])
    if (!vis[v])
      dfs(v);
}

但是CS101 PA不允许使用STL。当然你可以自己造一个std::vector,它并不是很难,还能加深你对STL和其它一些C++语法知识的理解。在C++ Primer的第12、13章有介绍如何造std::vector

不允许使用STL的话,我们可以给每个结点开一个链表存放它的所有邻居,开一个数组存储每个链表的表头指针。加入一条边 ( x , y ) (x,y) (x,y)的时候,就是给第 x x x个链表增加一个元素,显然应该加在表头,所以就令新结点的next指向head[x],再令head[x]指向新结点即可。使用动态内存一定要记得释放!!!使用动态内存一定要记得释放!!!使用动态内存一定要记得释放!!!(据我所知80%的人写了new都不知道要delete,我不知道你们的CS100谁教的)

struct Node {
  int to;
  Node *next;
  Node(int x, Node *n) : to(x), next(n) {}
};
Node *head[maxn];

inline void add_edge(int x, int y) {
  Node *newnode = new Node(y, head[x]);
  head[x] = newnode;
}

bool vis[maxn];

void dfs(int x) {
  printf("visit %d\n", x);
  vis[x] = true;
  for (Node *p = head[x]; p; p = p->next) {
    int v = p->to;
    if (!vis[v])
      dfs(v);
  }
}

void clear_list(Node *p) {
  if (!p)
    return;
  clear_list(p->next);
  delete p;
}

inline void clear() {
  for (int i = 1; i <= n; ++i)
    clear_list(head[i]);
}

如果你嫌动态内存麻烦,不熟悉,容易出错,那么你可以采用下面的数组模拟链表的方式(这其实是一些人所说的“前向星”)。我们需要开一个大数组G来存放链表的结点,每次新建结点的时候就从G中拿一个没被用过的。用一个变量total来表示现在G里有多少个结点已被使用。原来的指针值变成了结点在G中的下标。由于不涉及动态内存,我们也不需要去delete。注意,G的大小是你总共需要添加的边的数量,如果是有向图就是 ∣ E ∣ |E| E,无向图则是 2 ∣ E ∣ 2|E| 2E,树是 2 ∣ V ∣ 2|V| 2V。代码中用maxe表示总边数。

struct Node {
  int to;
  int next;
};
Node G[maxe];
int head[maxn];
int total;

inline void add_edge(int x, int y) {
  G[++total].to = y;
  G[total].next = head[x];
  head[x] = total;
}

bool vis[maxn];

void dfs(int x) {
  printf("visit %d\n", x);
  vis[x] = true;
  for (int i = head[x]; i; i = G[i].next) {
    int v = G[i].to;
    if (!vis[v])
      dfs(v);
  }
}

现在来看题。首先我们将这棵树任意挑一个点(比如 1 1 1)为根,变成一棵有根树。令 f 1 ( x ) f_1(x) f1(x)表示以结点 x x x为根的子树中有多少个点的权值为 1 1 1,令 f 2 ( x ) f_2(x) f2(x)表示以结点 x x x为根的子树中有多少个点的权值为 2 2 2。我们可以通过一遍dfs来计算出所有的 f 1 ( x ) f_1(x) f1(x) f 2 ( x ) f_2(x) f2(x),因为
f 1 ( x ) = [ a x = 1 ] + ∑ v ∈ child ⁡ ( x ) f 1 ( v ) , f 2 ( x ) = [ a x = 2 ] + ∑ v ∈ child ⁡ ( x ) f 2 ( v ) . f_1(x)=[a_x=1]+\sum_{v\in\operatorname{child}(x)}f_1(v),\quad f_2(x)=[a_x=2]+\sum_{v\in\operatorname{child}(x)}f_2(v). f1(x)=[ax=1]+vchild(x)f1(v),f2(x)=[ax=2]+vchild(x)f2(v).
计算出这些信息有什么用呢?我们发现,假设我们要删除一条边 ( x , v ) (x,v) (x,v),其中 v v v x x x的儿子,删除这条边之后得到的两棵树分别是以 v v v为根的子树和除了以 v v v为根的子树之外剩下的部分。如果这两棵树都是和平的,这条边就是和平的。而一棵树是和平的,就意味着它要么不含权值为 1 1 1的点,要么不含权值为 2 2 2的点,所以只要判断f1[v] == 0 || f2[v] == 0就可以判断以 v v v为根的子树是否是和平的。剩下的部分中权值为 1 1 1的点的数量就是总数减去f1[v],所以我们在输入的时候就计算出sum1sum2,分别表示权值为 1 1 1 2 2 2的点一共有多少个,那么sum1 - f1[v] == 0 || sum2 - f2[v] == 0就判断了剩下的部分是否是和平的。整个算法只要做一遍dfs,时间复杂度 O ( n ) O(n) O(n)

#include <cstdio>

constexpr int maxn = 3e5 + 7;

struct Node {
  int to, next;
};
Node G[maxn * 2];
int total, n;
int head[maxn];
int sum1, sum2, f1[maxn], f2[maxn];
int a[maxn];
int ans;

inline void add_edge(int x, int y) {
  G[++total].to = y;
  G[total].next = head[x];
  head[x] = total;
}

void dfs(int x, int fa) {
  if (a[x] == 1)
    f1[x] = 1;
  if (a[x] == 2)
    f2[x] = 1;
  for (int i = head[x]; i; i = G[i].next) {
    int v = G[i].to;
    if (v != fa) {
      dfs(v, x);
      f1[x] += f1[v];
      f2[x] += f2[v];
      if ((f1[v] == 0 || f2[v] == 0) && (sum1 - f1[v] == 0 || sum2 - f2[v] == 0))
        ++ans;
    }
  }
}

int main() {
  scanf("%d", &n);
  for (int i = 1; i <= n; ++i) {
    scanf("%d", a + i);
    if (a[i] == 1)
      ++sum1;
    if (a[i] == 2)
      ++sum2;
  }
  for (int i = 1; i < n; ++i) {
    int x, y;
    scanf("%d%d", &x, &y);
    add_edge(x, y);
  }
  dfs(1, 0);
  printf("%d\n", ans);
  return 0;
}

2002

题意:给定 n n n个正整数 s 1 , ⋯   , s n s_1,\cdots,s_n s1,,sn,称一个整数区间 [ ℓ , r ] [\ell,r] [,r]为“和谐的”,指的是存在整数 t ⩾ 2 t\geqslant 2 t2使得
s ℓ ≡ s ℓ + 1 ≡ ⋯ ≡ s r ( m o d t ) . s_\ell\equiv s_{\ell+1}\equiv\cdots\equiv s_r\pmod t. ss+1sr(modt).
求长度最长的和谐区间。

如果区间 [ ℓ , r ] [\ell,r] [,r]是和谐的,意味着
∃ t ⩾ 2 , ∣ s ℓ − s ℓ + 1 ∣ ≡ ∣ s ℓ + 1 − s ℓ + 2 ∣ ≡ ⋯ ≡ ∣ s r − 1 − s r ∣ ≡ 0 ( m o d t ) . \exists t\geqslant 2,\quad \left|s_\ell-s_{\ell+1}\right|\equiv\left|s_{\ell+1}-s_{\ell+2}\right|\equiv\cdots\equiv\left|s_{r-1}-s_r\right|\equiv 0\pmod t. t2,ss+1s+1s+2sr1sr0(modt).
d i = ∣ s i − 1 − s i ∣ d_i=\left|s_{i-1}-s_i\right| di=si1si,则
∃ t ⩾ 2 , d ℓ + 1 ≡ d ℓ + 2 ≡ ⋯ ≡ d r ≡ 0 ( m o d t ) , \exists t\geqslant 2,\quad d_{\ell+1}\equiv d_{\ell+2}\equiv\cdots\equiv d_r\equiv 0\pmod t, t2,d+1d+2dr0(modt),
所以
∃ t ⩾ 2 , ∀ i ∈ [ ℓ + 1 , r ] , t ∣ d i , \exists t\geqslant 2,\forall i\in[\ell+1,r],t\mid d_i, t2,i[+1,r],tdi,
那么
∃ t ⩾ 2 , t ∣ gcd ⁡ ( d ℓ + 1 , d ℓ + 2 , ⋯   , d r ) ,    ⟺    gcd ⁡ ( d ℓ + 1 , d ℓ + 2 , ⋯   , d r ) ≠ 1. \exists t\geqslant 2,t\mid\gcd(d_{\ell+1},d_{\ell+2},\cdots,d_r),\iff \gcd(d_{\ell+1},d_{\ell+2},\cdots,d_r)\neq1. t2,tgcd(d+1,d+2,,dr),gcd(d+1,d+2,,dr)=1.
注意,我这里写的是 ≠ 1 \neq 1 =1而不是 ⩾ 2 \geqslant 2 2,因为可能会出现 gcd ⁡ = 0 \gcd=0 gcd=0的情况,这也是符合的。 于是我们要做的就是在 d 2 , ⋯   , d n d_2,\cdots,d_n d2,,dn上找到最长的区间,使得这个区间里的元素不互素。注意一个细节:当“和谐区间”长度为 2 2 2的时候,比方说是 [ i , i + 1 ] [i,i+1] [i,i+1],对应的 d d d序列上的不互素区间长度为 1 1 1,相当于是要判断 gcd ⁡ ( d i ) ≠ 1 \gcd(d_i)\neq 1 gcd(di)=1,而一个数的“ gcd ⁡ \gcd gcd”该怎么定义呢?稍加思索,发现定义 gcd ⁡ ( x ) = x \gcd(x)=x gcd(x)=x是合理的(为什么?)。此外, d d d序列中的元素可能会有 0 0 0,但数学一般不考虑 0 0 0与其它数的 gcd ⁡ \gcd gcd。我们可以定义 gcd ⁡ ( x , 0 ) = x \gcd(x,0)=x gcd(x,0)=x(为什么?)。

我们考虑枚举左端点 ℓ \ell ,尝试求出最远的右端点 r r r,使得 gcd ⁡ ( d ℓ , ⋯   , d r ) ≠ 1 \gcd(d_\ell,\cdots,d_r)\neq 1 gcd(d,,dr)=1(自然 gcd ⁡ ( d ℓ , ⋯   , d r + 1 ) = 1 \gcd(d_\ell,\cdots,d_{r+1})=1 gcd(d,,dr+1)=1)。注意,随着右端点 r r r向右移动,区间里的元素增加, gcd ⁡ \gcd gcd是单调不增的。对于一个固定的左端点 ℓ \ell 来说,我们自然可以将右端点 r r r ℓ \ell 开始逐步右移,并注意到
gcd ⁡ ( d ℓ , ⋯   , d r ) = gcd ⁡ ( gcd ⁡ ( d ℓ , ⋯   , d r − 1 ) , d r ) , (2.1) \gcd(d_\ell,\cdots,d_r)=\gcd(\gcd(d_\ell,\cdots,d_{r-1}),d_r),\tag{2.1} gcd(d,,dr)=gcd(gcd(d,,dr1),dr),(2.1)
于是花 O ( n log ⁡ A ) O(n\log A) O(nlogA)的时间求出最远的右端点 r r r,这里 O ( log ⁡ A ) O(\log A) O(logA)是求 gcd ⁡ \gcd gcd的时间复杂度,其中 A A A d d d的最大值。但是由于枚举左端点已经 O ( n ) O(n) O(n),总的复杂度就达到了 O ( n 2 log ⁡ A ) O(n^2\log A) O(n2logA),这不能通过。

假如对于某个左端点 ℓ \ell ,我们已经知道它对应的最远右端点是 r = r ( ℓ ) r=r(\ell) r=r(),我们有没有办法快一点求出下一个左端点 ℓ + 1 \ell+1 +1对应的右端点 r ( ℓ + 1 ) r(\ell+1) r(+1)呢?我们发现当区间从 [ ℓ , r ] [\ell,r] [,r]变成 [ ℓ + 1 , r ] [\ell+1,r] [+1,r]时,由于区间长度缩短, gcd ⁡ \gcd gcd单调不减,所以必然有 r ( ℓ + 1 ) ⩾ r ( ℓ ) r(\ell+1)\geqslant r(\ell) r(+1)r(),也就是说我们没必要再从 ℓ + 1 \ell+1 +1开始往右找,而是直接从 r ( ℓ ) r(\ell) r()开始。但这时问题来了,递推式 ( 2.1 ) (2.1) (2.1)给出了当我们向区间中增加一个元素时快速求出 gcd ⁡ \gcd gcd的方式,但是左端点从 ℓ \ell 变成 ℓ + 1 \ell+1 +1时是从区间中删除一个数,这时该如何快速求出 [ ℓ + 1 , r ] [\ell+1,r] [+1,r] gcd ⁡ \gcd gcd

算法1:查询区间 gcd ⁡ \gcd gcd

这个算法需要一个超出本课程范围的知识,如果你不想学可以跳到算法2。

现在有这样一个问题:给定一个非负整数序列 d 1 , ⋯   , d n d_1,\cdots,d_n d1,,dn,现在有 q q q次询问,第 i i i次询问是一个区间 [ ℓ i , r i ] [\ell_i,r_i] [i,ri],请你快速回答 gcd ⁡ ( d ℓ i , d ℓ i + 1 , ⋯   , d r i ) \gcd\left(d_{\ell_i},d_{\ell_i+1},\cdots,d_{r_i}\right) gcd(di,di+1,,dri),即这个区间上的 gcd ⁡ \gcd gcd n ⩽ 1 0 5 , q ⩽ 1 0 5 n\leqslant 10^5,q\leqslant 10^5 n105,q105

显然你不能花 O ( n log ⁡ A ) O(n\log A) O(nlogA)的时间去遍历区间来回答单次询问,我们得事先预处理一些答案。如果开一个二维数组 f ( ℓ , r ) = gcd ⁡ ( d ℓ , d ℓ + 1 , ⋯   , d r ) f(\ell,r)=\gcd(d_\ell,d_{\ell+1},\cdots,d_r) f(,r)=gcd(d,d+1,,dr),在回答所有询问之前就把这些答案都预处理出来,那我们就能做到 O ( 1 ) O(1) O(1)回答每次询问了。但这样肯定不行,因为这个二维数组的空间复杂度达到了 O ( n 2 ) O(n^2) O(n2),并且预处理这个二维数组的时间复杂度也达到了 O ( n 2 log ⁡ A ) O(n^2\log A) O(n2logA)。这个预处理之所以失败,是因为我们把所有的工作量都压在了预处理阶段;而如果每次询问都花 O ( n log ⁡ A ) O(n\log A) O(nlogA)的时间遍历区间,相当于把所有的工作量都放在了询问阶段。这是两个极端。

那么折中的思想就是:预处理一部分答案,想办法在查询的时候利用这一部分答案来获得想要的答案。ST表(稀疏表,Sparse Table)正是基于这一思想:我们令 f ( i , j ) f(i,j) f(i,j)表示左端点为 i i i,长度为 2 j 2^j 2j的区间上的 gcd ⁡ \gcd gcd。由于区间长度是 2 j ⩽ n 2^j\leqslant n 2jn,所以这个二维数组的空间复杂度是 O ( n log ⁡ n ) O(n\log n) O(nlogn),可以接受。那怎么计算呢?我们把区间 [ i , i + 2 j − 1 ] \left[i,i+2^j-1\right] [i,i+2j1]切成 [ i , i + 2 j − 1 − 1 ] \left[i,i+2^{j-1}-1\right] [i,i+2j11] [ i + 2 j − 1 , i + 2 j − 1 ] \left[i+2^{j-1},i+2^j-1\right] [i+2j1,i+2j1]两段长度均为 2 j − 1 2^{j-1} 2j1的区间,那么 [ i , i + 2 j − 1 ] \left[i,i+2^j-1\right] [i,i+2j1] gcd ⁡ \gcd gcd就是这两个子区间的 gcd ⁡ \gcd gcd gcd ⁡ \gcd gcd,这其实可以看做一种动态规划,其转移方程是
f ( i , j ) = gcd ⁡ ( f ( i , j − 1 ) , f ( i + 2 j − 1 , j − 1 ) ) . f(i,j)=\gcd\left(f(i,j-1),f\left(i+2^{j-1},j-1\right)\right). f(i,j)=gcd(f(i,j1),f(i+2j1,j1)).
所以我们可以花 O ( log ⁡ A ) O(\log A) O(logA)的时间由 f ( i , j − 1 ) f(i,j-1) f(i,j1) f ( i + 2 j − 1 , j − 1 ) f\left(i+2^{j-1},j-1\right) f(i+2j1,j1)递推出 f ( i , j ) f(i,j) f(i,j)。注意,计算 2 j 2^j 2j O ( 1 ) O(1) O(1)的,只要用位运算1 << j即可。还要注意计算的顺序,你必须保证在计算 f ( i , j ) f(i,j) f(i,j)的时候 f ( ∗ , j − 1 ) f(*,j-1) f(,j1)已经算过,下面是一个典型的错误写法:

for (int i = 1; i <= n; ++i)
  for (int j = 1; i + (1 << j) - 1 <= n; ++j)
    f[i][j] = gcd(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);

你必须这样写:

for (int j = 1; (1 << j) <= n; ++j)
  for (int i = 1; i + (1 << j) - 1 <= n; ++j)
    f[i][j] = gcd(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);

并且在此之前先计算base case

for (int i = 1; i <= n; ++i)
  f[i][0] = d[i];

这样我们就在 O ( n log ⁡ n log ⁡ A ) O(n\log n\log A) O(nlognlogA)的时间内预处理出了这张ST表。现在考虑如何处理一个询问 [ ℓ , r ] [\ell,r] [,r]。我们找到最大的 k k k,满足 2 k ⩽ r − ℓ + 1 2^k\leqslant r-\ell+1 2kr+1,那么 [ ℓ , ℓ + 2 k − 1 ] ⊆ [ ℓ , r ] , [ r − 2 k + 1 , r ] ⊆ [ ℓ , r ] \left[\ell,\ell+2^k-1\right]\subseteq[\ell,r],\left[r-2^k+1,r\right]\subseteq[\ell,r] [,+2k1][,r],[r2k+1,r][,r],并且 [ ℓ , ℓ + 2 k − 1 ] ∪ [ r − 2 k + 1 , r ] = [ ℓ , r ] \left[\ell,\ell+2^k-1\right]\cup\left[r-2^k+1,r\right]=[\ell,r] [,+2k1][r2k+1,r]=[,r],所以区间 [ ℓ , r ] [\ell,r] [,r]上的 gcd ⁡ \gcd gcd就是
gcd ⁡ ( f ( ℓ , k ) , f ( r − 2 k + 1 , k ) ) . \gcd\left(f(\ell,k),f(r-2^k+1,k)\right). gcd(f(,k),f(r2k+1,k)).
这里的 k k k显然等于 ⌊ log ⁡ 2 ( r − ℓ + 1 ) ⌋ \left\lfloor\log_2(r-\ell+1)\right\rfloor log2(r+1)。这个计算可以调用cmath里的相关函数来完成,但我们也可以预处理这些数:设 g ( i ) = ⌊ log ⁡ 2 i ⌋ g(i)=\left\lfloor\log_2i\right\rfloor g(i)=log2i,则 g ( i ) = g ( ⌊ i / 2 ⌋ ) + 1 g(i)=g\left(\lfloor i/2\rfloor\right)+1 g(i)=g(i/2)+1,可以 O ( n ) O(n) O(n)预处理出来。于是我们做到了 O ( log ⁡ A ) O(\log A) O(logA)回答单次询问。

inline int query(int l, int r) {
  int k = g[r - l + 1];
  return gcd(f[l][k], f[r - (1 << k) + 1][k]);
}

有了这个核武器,前面的算法就可以实现了:设置一个左端点 ℓ \ell 和它对应的最远右端点 r r r,每次将左端点 ℓ \ell 右移的时候,尝试右移右端点 r r r直到 gcd ⁡ ( d ℓ , ⋯   , d r ) ≠ 1 \gcd(d_\ell,\cdots,d_r)\neq1 gcd(d,,dr)=1 gcd ⁡ ( d ℓ , ⋯   , d r + 1 ) = 1 \gcd(d_\ell,\cdots,d_{r+1})=1 gcd(d,,dr+1)=1,这个过程中每次求 gcd ⁡ \gcd gcd都只要 O ( log ⁡ A ) O(\log A) O(logA)的时间。由于 ℓ \ell r r r都是单向移动的,每个位置都至多被 ℓ \ell 经过一次、被 r r r经过一次,所以时间复杂度 O ( n log ⁡ A ) O(n\log A) O(nlogA)。再加上预处理ST表的复杂度,总的复杂度 O ( n log ⁡ n log ⁡ A + n log ⁡ A ) = O ( n log ⁡ n log ⁡ A ) O(n\log n\log A+n\log A)=O(n\log n\log A) O(nlognlogA+nlogA)=O(nlognlogA)。实现的时候,移动左右端点的部分要格外小心,细节容易写错。另外, d d d上的不互素区间长度 + 1 +1 +1才是 s s s上的和谐区间长度。

#include <cstdio>
#include <cmath>

constexpr int maxn = 2e3 + 7;

int s[maxn], d[maxn], n, T;
int f[maxn][1 << 14], g[maxn];

int gcd(int a, int b) {
  return b ? gcd(b, a % b) : a;
}

inline void init() {
  for (int i = 2; i <= n; ++i)
    g[i] = g[i / 2] + 1;
  for (int i = 1; i <= n; ++i)
    f[i][0] = d[i];
  for (int j = 1; 1 << j <= n; ++j)
    for (int i = 1; i + (1 << j) - 1 <= n; ++i)
      f[i][j] = gcd(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}

inline int query(int l, int r) {
  int k = g[r - l + 1];
  return gcd(f[l][k], f[r - (1 << k) + 1][k]);
}

int main() {
  scanf("%d", &T);
  while (T--) {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) {
      scanf("%d", s + i);
      d[i] = abs(s[i] - s[i - 1]);
    }
    init();
    int ans = 1;
    for (int l = 2, r = 1; l <= n; ++l) {
      if (r < l - 1)
        r = l - 1;
      while (r < n && query(l, r + 1) != 1)
        ++r;
      if (r - l + 2 > ans)
        ans = r - l + 2;
    }
    printf("%d\n", ans);
  }
  return 0;
}

此外,这个区间查询问题还可以用线段树实现,可以做到 O ( n log ⁡ A ) O(n\log A) O(nlogA)预处理, O ( log ⁡ n log ⁡ A ) O(\log n\log A) O(lognlogA)回答单次询问。有兴趣的同学可以自行查找资料学习。线段树是一个非常强大的数据结构,我会在后面的题目中再次提到。

算法2:分治

考虑求出 [ ℓ , r ] [\ell,r] [,r]上的最长不互素区间,老规矩设中点 m = ⌊ ℓ + r 2 ⌋ m=\left\lfloor\frac{\ell+r}2\right\rfloor m=2+r,将区间分成 [ ℓ , m ] [\ell,m] [,m] [ m + 1 , r ] [m+1,r] [m+1,r]两部分。不互素的区间要么是 [ ℓ , m ] [\ell,m] [,m]的子区间,要么是 [ m + 1 , r ] [m+1,r] [m+1,r]的子区间,要么左端点在 [ ℓ , m ] [\ell,m] [,m]中而右端点在 [ m + 1 , r ] [m+1,r] [m+1,r]中。对于前两种情况,直接递归调用,我们主要考虑第三种情况。

仍然取一个左端点 i ∈ [ ℓ , m ] i\in[\ell,m] i[,m],考虑找到 [ m + 1 , r ] [m+1,r] [m+1,r]中最远的右端点 j j j,使得区间 [ i , j ] [i,j] [i,j]上的元素不互素而区间 [ i , j + 1 ] [i,j+1] [i,j+1]上的元素互素。假如我们已经求出了这个 j = j ( i ) j=j(i) j=j(i),当左端点 i i i变成 i + 1 i+1 i+1时,能否快速求出 j ( i + 1 ) j(i+1) j(i+1)呢?基于之前的讨论, j ( i + 1 ) ⩾ j ( i ) j(i+1)\geqslant j(i) j(i+1)j(i)仍然是成立的,我们每次只要将右端点 j j j继续右移即可,不需要从 i + 1 i+1 i+1开始找。现在我们仍然面对快速查询 gcd ⁡ \gcd gcd的问题,但是我们发现要查询 gcd ⁡ \gcd gcd的区间 [ i , j ] [i,j] [i,j]一定满足 i ⩽ m , m + 1 ⩽ j i\leqslant m,m+1\leqslant j im,m+1j,这个区间其实是 [ ℓ , m ] [\ell,m] [,m]的一段后缀和 [ m + 1 , r ] [m+1,r] [m+1,r]的一段前缀。如果设
p ( i ) = gcd ⁡ ( d i , d i + 1 , ⋯   , d m ) , p(i)=\gcd(d_i,d_{i+1},\cdots,d_m), p(i)=gcd(di,di+1,,dm),

q ( j ) = gcd ⁡ ( d m + 1 , d m + 2 , ⋯   , d j ) , q(j)=\gcd(d_{m+1},d_{m+2},\cdots,d_j), q(j)=gcd(dm+1,dm+2,,dj),

则区间 [ i , j ] [i,j] [i,j]上的 gcd ⁡ \gcd gcd就是 gcd ⁡ ( p ( i ) , q ( j ) ) \gcd(p(i),q(j)) gcd(p(i),q(j))。而根据式 ( 2.1 ) (2.1) (2.1)的思想, p ( i ) p(i) p(i) q ( j ) q(j) q(j)又满足如下递推关系
p ( i ) = gcd ⁡ ( p ( i + 1 ) , d i ) , p(i)=\gcd(p(i+1),d_i), p(i)=gcd(p(i+1),di),

q ( j ) = gcd ⁡ ( q ( j − 1 ) , d j ) . q(j)=\gcd(q(j-1),d_j). q(j)=gcd(q(j1),dj).

因此,假设 r − ℓ + 1 = n ′ r-\ell+1=n^\prime r+1=n,我们可以花 O ( n ′ log ⁡ A ) O(n^\prime\log A) O(nlogA)的时间先预处理所有 p ( i ) , q ( j ) p(i),q(j) p(i),q(j),然后设置一个左端点 i i i和右端点 j j j,在右移左端点 i i i的时候不断尝试右移右端点,这个过程中需要不断检查 gcd ⁡ \gcd gcd是否为 1 1 1,而每次检查都只需要 O ( log ⁡ A ) O(\log A) O(logA)的时间。由于 i i i j j j都是单向移动的,每个位置都至多被 i i i j j j各经过一次,所以总的时间复杂度是 O ( n ′ log ⁡ A ) O(n^\prime\log A) O(nlogA)。于是整个分治算法的时间复杂度为 O ( n log ⁡ n log ⁡ A ) O(n\log n\log A) O(nlognlogA)

实现的时候要极其小心,这个移动双端点的过程有一些细节容易写错。

#include <cstdio>
#include <cmath>

constexpr int maxn = 2e3 + 7;

int s[maxn], d[maxn], n, T;
int p[maxn], q[maxn];
int ans;

int gcd(int a, int b) {
  return b ? gcd(b, a % b) : a;
}

void solve(int l, int r) {
  if (l > r)
    return;
  if (r - l + 1 < ans)
    return;
  if (l == r) {
    if (d[l] != 1 && 2 > ans)
      ans = 2;
    return;
  }
  int mid = (l + r) >> 1;
  solve(l, mid);
  solve(mid + 1, r);
  if (gcd(d[mid], d[mid + 1]) == 1)
    return;
  p[mid] = d[mid];
  for (int i = mid - 1; i >= l; --i)
    p[i] = gcd(p[i + 1], d[i]);
  q[mid + 1] = d[mid + 1];
  for (int i = mid + 2; i <= r; ++i)
    q[i] = gcd(q[i - 1], d[i]);
  for (int i = l, j = mid; i <= mid; ++i) {
    while (j < r && gcd(p[i], q[j + 1]) != 1)
      ++j;
    if (j > mid && j - i + 2 > ans)
      ans = j - i + 2;
  }
}

int main() {
  scanf("%d", &T);
  while (T--) {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) {
      scanf("%d", s + i);
      d[i] = abs(s[i] - s[i - 1]);
    }
    ans = 1;
    solve(2, n);
    printf("%d\n", ans);
  }
  return 0;
}
算法3:二分答案

前两种算法都是基于所谓的“尺取法”,即使用两个单调移动的指针来解决问题。但如果你学会了算法1中快速查询区间 gcd ⁡ \gcd gcd的方式,我们就能得到下面这种写起来细节更少,不容易写错的算法。我们发现,如果有一个长度为 5 5 5的不互素区间,那么一定存在一个长度为 3 3 3的不互素区间。也就是说,这里的答案蕴含了某种单调性在里面。

二分是我们都玩过的一个游戏:Alice心里想了一个 [ 1 , 100 ] [1,100] [1,100]之间的整数,Bob来猜,每次Alice可以告诉Bob他猜大了还是猜小了。Bob每次都猜当前可行区间的中点,就保证每次都将可行区间缩短至少一半,于是至多 7 = ⌈ log ⁡ 2 100 ⌉ 7=\lceil\log_2100\rceil 7=log2100次就一定能猜出答案。

现在我们用这种二分的方式来猜最长的不互素区间长度是多少。比方说我猜 m m m,那接下来就要检查是否存在一个长度为 m m m的不互素区间。检查的方式非常简单,我们枚举区间的左端点 i i i,看看 [ i , i + m − 1 ] [i,i+m-1] [i,i+m1]这段区间上的 gcd ⁡ \gcd gcd是否等于 1 1 1就可以了。使用ST表可以做到 O ( log ⁡ A ) O(\log A) O(logA)查询,那么检查的复杂度是 O ( n log ⁡ A ) O(n\log A) O(nlogA),二分的复杂度 O ( log ⁡ n ) O(\log n) O(logn),总复杂度 O ( n log ⁡ n log ⁡ A ) O(n\log n\log A) O(nlognlogA),ST表预处理的复杂度也是 O ( n log ⁡ n log ⁡ A ) O(n\log n\log A) O(nlognlogA),所以整个算法的复杂度 O ( n log ⁡ n log ⁡ A ) O(n\log n\log A) O(nlognlogA)

#include <cstdio>
#include <cmath>

constexpr int maxn = 2e3 + 7;

int s[maxn], d[maxn], n, T;
int f[maxn][1 << 14], g[maxn];

int gcd(int a, int b) {
  return b ? gcd(b, a % b) : a;
}

inline void init() {
  for (int i = 2; i <= n; ++i)
    g[i] = g[i / 2] + 1;
  for (int i = 1; i <= n; ++i)
    f[i][0] = d[i];
  for (int j = 1; 1 << j <= n; ++j)
    for (int i = 1; i + (1 << j) - 1 <= n; ++i)
      f[i][j] = gcd(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}

inline int query(int l, int r) {
  int k = g[r - l + 1];
  return gcd(f[l][k], f[r - (1 << k) + 1][k]);
}

inline bool check(int mid) {
  for (int i = 2; i + mid - 1 <= n; ++i)
    if (query(i, i + mid - 1) != 1)
      return true;
  return false;
}

int main() {
  scanf("%d", &T);
  while (T--) {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) {
      scanf("%d", s + i);
      d[i] = abs(s[i] - s[i - 1]);
    }
    init();
    // 二分答案的写法有很多种,有的人没有这个 ans 变量,有的人判断条件是 left < right,
    // 有的人修改的时候是 left = mid 或 right = mid,等等。
    // 最重要的是选定一种写法然后一直坚持,这样不容易写错。下面是我自己一直使用的写法。
    int left = 1, right = n - 1, ans = 0;
    while (left <= right) {
      int mid = (left + right) / 2;
      if (check(mid)) {
        ans = mid;
        left = mid + 1;
      } else
        right = mid - 1;
    }
    printf("%d\n", ans + 1);
  }
  return 0;
}

此外,我们还可以枚举左端点 i i i,这时它所对应的最远右端点 j j j也是单调的,也可以二分出来。时间复杂度是一样的。

2003

题意:有 c c c只鸽子,每只鸽子有一个饥饿程度 a i a_i ai和需要的食物数量 b i b_i bi。现在要选出 n n n只来喂,但你拥有的食物只有 f f f,也就是说所有选中的鸽子的 b b b值之和不能超过 f f f。在这种情况下最大化选中的鸽子的 a a a值的中位数。保证 n n n是奇数。无解输出 − 1 -1 1 n ⩽ 1 0 5 , n ⩽ c ⩽ 2 × 1 0 5 , 0 ⩽ f ⩽ 2 × 1 0 9 , 0 ⩽ a i ⩽ 2 × 1 0 9 , 0 ⩽ b i ⩽ 1 0 5 n\leqslant 10^5,n\leqslant c\leqslant 2\times 10^5,0\leqslant f\leqslant 2\times 10^9,0\leqslant a_i\leqslant 2\times 10^9,0\leqslant b_i\leqslant 10^5 n105,nc2×105,0f2×109,0ai2×109,0bi105

我特意把全部的数据范围放上来了,因为稍加计算会发现,此题的程序中某些变量的值可能超过int所能表示的最大值。在OI赛事中不乏这样的例子,一些正确的程序因为没有开long long而丢掉了部分分数,非常可惜,大家一定要小心。本题中会涉及到将若干个 b b b值相加,最坏情况下则是将 1 0 5 10^5 105 b b b值相加,会达到 1 0 10 10^{10} 1010,而int的最大值是 2 31 − 1 = 2147483647 2^{31}-1=2147483647 2311=2147483647,所以某些变量需要用long long。但我们不建议非常简单粗暴地将所有整型都开成long long,更不建议效仿某些人#define int long long的做法,因为那会额外增加很多空间开销,也会降低程序可读性。

我们考虑将所有鸽子按 a a a值从大到小排序,然后枚举中位数来自哪只鸽子。假设中位数是 a i a_i ai,意味着我们要在 [ 1 , i − 1 ] [1,i-1] [1,i1] [ i + 1 , c ] [i+1,c] [i+1,c]中分别选取 k = ( n − 1 ) / 2 k=(n-1)/2 k=(n1)/2只鸽子,要求选出的所有鸽子的 b b b值之和不超过 f f f,显然这些鸽子的 b b b值应该越小越好。因此我们需要在 [ 1 , i − 1 ] [1,i-1] [1,i1]这些鸽子中找出前 k k k小的 b b b值之和,在 [ i + 1 , c ] [i+1,c] [i+1,c]这些鸽子中找出前 k k k小的 b b b值之和,然后跟 b i b_i bi加在一起判断是否小于等于 f f f

现在问题就在于,如何快速查询这两段区间中前 k k k小的 b b b值。你当然可以直接用主席树、平衡树秒杀,正所谓“智商不够,数据结构来凑”,不过我们还是说一说课程范围以内的做法。

g ( i ) g(i) g(i)表示区间 [ 1 , i − 1 ] [1,i-1] [1,i1]上的鸽子中最小的 k k k b b b值之和,这里 i ⩾ k + 1 i\geqslant k+1 ik+1。维护一个始终都恰好有 k k k个元素的最大堆,一开始先把 b 1 , ⋯   , b k b_1,\cdots,b_k b1,,bk加入这个堆,同时记录当前堆中的元素之和 s u m sum sum。在 i i i k + 1 k+1 k+1起逐渐递增的过程中,每次 g ( i ) g(i) g(i)就等于当前的 s u m sum sum,然后判断 b i b_i bi是否应该加入堆。注意,这时的堆中的元素应为 b 1 , ⋯   , b i − 1 b_1,\cdots,b_{i-1} b1,,bi1中最小的 k k k个,而堆顶正好是这 k k k个中的最大值,也就是第 k k k小值。如果 b i b_i bi小于堆顶,就令 s u m sum sum减去堆顶,把堆顶弹掉,把 b i b_i bi加入堆,并令 s u m sum sum增加 b i b_i bi;否则就不将 b i b_i bi加入堆。这样一来,堆中的元素就变成了 b 1 , ⋯   , b i b_1,\cdots,b_i b1,,bi中最小的 k k k个,下一轮迭代就可以用了。记 h ( i ) h(i) h(i)表示区间 [ i + 1 , c ] [i+1,c] [i+1,c]上的鸽子中最小的 k k k b b b值之和,其中 i ⩽ c − k i\leqslant c-k ick。将以上过程倒着做一遍,就求出了所有的 h ( i ) h(i) h(i)

在实现中,因为我们最终枚举中位数的时候 i i i是从小到大枚举的,所以 g ( i ) g(i) g(i)可以不用开数组预处理,一边枚举中位数一边算即可。

#include <cstdio>
#include <functional>
#include <iterator>

namespace gkxx {

template <typename ForwardIterator, typename Less>
void __inplace_merge(
    ForwardIterator begin, ForwardIterator mid, ForwardIterator end,
    typename std::iterator_traits<ForwardIterator>::difference_type dist,
    Less less) {
  ForwardIterator i = begin, j = mid;
  using value_type = typename std::iterator_traits<ForwardIterator>::value_type;
  value_type *tmp = new value_type[dist](), *k = tmp;
  while (i != mid && j != end) {
    if (less(*i, *j))
      *k++ = *i++;
    else
      *k++ = *j++;
  }
  while (i != mid)
    *k++ = *i++;
  while (j != end)
    *k++ = *j++;
  k = tmp;
  while (begin != end)
    *begin++ = *k++;
  delete[] tmp;
}

template <typename ForwardIterator, typename Less>
void merge_sort(ForwardIterator begin, ForwardIterator end, Less less) {
  auto dist = std::distance(begin, end);
  if (dist <= 1)
    return;
  ForwardIterator mid = std::next(begin, dist / 2);
  merge_sort(begin, mid, less);
  merge_sort(mid, end, less);
  __inplace_merge(begin, mid, end, dist, less);
}

template <typename ForwardIterator>
inline void merge_sort(ForwardIterator begin, ForwardIterator end) {
  merge_sort(begin, end, std::less<void>());
}

} // namespace gkxx

constexpr int maxn = 2e5 + 7;

struct Pigeon {
  int a, b;
};

Pigeon pg[maxn];
int f, c, n;

int heap[maxn], size;

void push(int x) {
  int i = ++size;
  while (i > 1) {
    int fa = i / 2;
    if (heap[fa] < x)
      heap[i] = heap[fa];
    else
      break;
    i = fa;
  }
  heap[i] = x;
}

void pop() {
  int x = heap[size--];
  int i = 1;
  while (i * 2 <= size) {
    int ch = i * 2;
    if (ch + 1 <= size && heap[ch + 1] > heap[ch])
      ++ch;
    if (heap[ch] > x) {
      heap[i] = heap[ch];
      i = ch;
    } else
      break;
  }
  heap[i] = x;
}

long long h[maxn];

int main() {
  scanf("%d%d%d", &n, &c, &f);
  for (int i = 1; i <= c; ++i)
    scanf("%d%d", &pg[i].a, &pg[i].b);
  gkxx::merge_sort(
      pg + 1, pg + c + 1,
      [](const Pigeon &lhs, const Pigeon &rhs) -> bool { return lhs.a > rhs.a; });
  long long sum = 0;
  int k = (n - 1) / 2;
  for (int i = c; i >= c - k + 1; --i) {
    sum += pg[i].b;
    push(pg[i].b);
  }
  for (int i = c - k; i >= k + 1; --i) {
    h[i] = sum;
    if (pg[i].b < heap[1]) {
      sum -= heap[1];
      pop();
      sum += pg[i].b;
      push(pg[i].b);
    }
  }
  size = 0;
  sum = 0;
  for (int i = 1; i <= k; ++i) {
    sum += pg[i].b;
    push(pg[i].b);
  }
  for (int i = k + 1; i <= c - k; ++i) {
    if (h[i] + sum + pg[i].b <= f) {
      printf("%d\n", pg[i].a);
      return 0;
    }
    if (pg[i].b < heap[1]) {
      sum -= heap[1];
      pop();
      sum += pg[i].b;
      push(pg[i].b);
    }
  }
  puts("-1");
  return 0;
}

关于这题还有一个有意思的事:2018年我在省常中集训的时候做到了这道题,但当时题目并不保证 n n n是奇数,并且定义当 n n n是偶数时中位数为中间两个数的平均数。事实上偶数的情况要比奇数复杂一些,当时一些人就想:要不我就写个奇数的情况吧,应该能混到一部分分数。最后只写了奇数的人得了零分,因为测试数据全是偶数…

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值