【题目链接】
【思路要点】
- 枚举周期 p p p ,考虑如何判断其是否合法。
- 记 N = a + b , r = ⌊ N p ⌋ N=a+b,r=\lfloor\frac{N}{p}\rfloor N=a+b,r=⌊pN⌋ ,那么最后一段有 N % r N\% r N%r 个字符,其中至少有 a % r a\% r a%r 个 A A A , b % r b\%r b%r 个 B B B ,因此一个必要条件为 N % r ≥ a % r + b % r N\%r\geq a\%r+b\%r N%r≥a%r+b%r ,记 d e l t a = N % r − a % r − b % r delta=N\%r-a\%r-b\%r delta=N%r−a%r−b%r。
- 记在整段中的 A , B A,B A,B 的个数为 a 0 , b 0 a_0,b_0 a0,b0 ,在最后一段的 A , B A,B A,B 的个数为 a 1 , b 1 a_1,b_1 a1,b1 ,使得字符串存在长度为 p p p 的周期的充要条件是 a 0 ≥ a 1 , b 0 ≥ b 1 a_0\geq a_1,b_0\geq b_1 a0≥a1,b0≥b1 ,其必要性显然,并且不难构造。
- 因此,我们需要 ⌊ a r ⌋ ≥ a % r , ⌊ b r ⌋ ≥ b % r \lfloor\frac{a}{r}\rfloor\geq a\%r,\lfloor\frac{b}{r}\rfloor\geq b\%r ⌊ra⌋≥a%r,⌊rb⌋≥b%r ,且 ⌊ ⌊ a r ⌋ − a % r r + 1 ⌋ + ⌊ ⌊ b r ⌋ − b % r r + 1 ⌋ ≥ d e l t a r \lfloor\frac{\lfloor\frac{a}{r}\rfloor- a\%r}{r+1}\rfloor+\lfloor\frac{\lfloor\frac{b}{r}\rfloor- b\%r}{r+1}\rfloor\geq \frac{delta}{r} ⌊r+1⌊ra⌋−a%r⌋+⌊r+1⌊rb⌋−b%r⌋≥rdelta ,注意这里 d e l t a delta delta 一定是 r r r 的倍数。
- 至此,我们有了一个 O ( a + b ) O(a+b) O(a+b) 的做法。
- 不难注意到多数条件中只出现了 r r r ,枚举 r r r ,则可以得到一个 p p p 的范围,据此计算答案即可。
- 时间复杂度 O ( a + b ) O(\sqrt{a+b}) O(a+b) 。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 2e5 + 5; typedef long long ll; typedef long double ld; typedef unsigned long long ull; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } void force() { int a, b; read(a), read(b); int n = a + b, ans = 0; for (int i = 1; i <= n; i++) { int r = n / i, delta = (n % i - (a % r + b % r)) / r; if (delta >= 0 && a / r >= a % r && b / r >= b % r && delta - (a / r - a % r) / (r + 1) - (b / r - b % r) / (r + 1) <= 0) ans++; } writeln(ans); } int main() { int a, b; read(a), read(b); int n = a + b, ans = 0; for (int i = 1, nxt; i <= n; i = nxt + 1) { int r = n / i; nxt = n / r; if (a / r >= a % r && b / r >= b % r) { int Min = (a % r + b % r); int Max = ((a / r - a % r) / (r + 1) + (b / r - b % r) / (r + 1)) * r + (a % r + b % r); int rMin = (n - Max) / r + ((n - Max) % r != 0); int rMax = (n - Min) / r; chkmax(rMin, i); chkmin(rMax, nxt); if (rMin <= rMax) ans += rMax - rMin + 1; } } writeln(ans); return 0; }