【题目链接】
【思路要点】
- 一个字符串存在 b o r d e r border border i i i 等价于其存在周期 N − i N-i N−i 。
- 记 p ( i ) p(i) p(i) 表示是否存在周期 i i i ,由期望的线性性,答案即为 ∑ i , j E ( p ( i ) p ( j ) ) \sum_{i,j}E(p(i)p(j)) ∑i,jE(p(i)p(j)) 。
- 考虑枚举 i , j i,j i,j ,如何计算 E ( p ( i ) p ( j ) ) E(p(i)p(j)) E(p(i)p(j)) ,首先给出结论, E ( p ( i ) p ( j ) ) = k max ( g c d ( i , j ) , i + j − N ) − N E(p(i)p(j))=k^{\max(gcd(i,j),i+j-N)-N} E(p(i)p(j))=kmax(gcd(i,j),i+j−N)−N
- 结论的证明如下:
记 c n t cnt cnt 表示在周期 i , j i,j i,j 的描述下字符串中连通块的个数,不难发现 E ( p ( i ) p ( j ) ) E(p(i)p(j)) E(p(i)p(j)) 即为 k c n t − N k^{cnt-N} kcnt−N 。
不失一般性地,我们假设 i < j i<j i<j 。
对于 i + j ≤ N i+j\leq N i+j≤N 的情况,一个字符 s x s_x sx 一定等于 s x − i ( x > i ) s_{x-i}\ (x>i) sx−i (x>i) 或 s x + j − i ( x ≤ i ) s_{x+j-i}\ (x\leq i) sx+j−i (x≤i) ,从而连通块的个数一定为 g c d ( i , j ) gcd(i,j) gcd(i,j) 。
对于 i + j > N i+j>N i+j>N 的情况,首先考虑 j j j 个元素 t 1 , 2 , … , j t_{1,2,\dots,j} t1,2,…,j 排成一个圆环,接着考虑字符串中每一对 s x = s x − i s_x=s_{x-i} sx=sx−i 的相等关系,它们分别描述了 t 1 = t i + 1 , t 2 = t i + 2 , … t_{1}=t_{i+1},t_{2}=t_{i+2},\dots t1=ti+1,t2=ti+2,… 。可以从在这个环上加边的过程中看出连通块的个数为 max ( g c d ( i , j ) , i + j − N ) \max(gcd(i,j),i+j-N) max(gcd(i,j),i+j−N) 。- 考虑加速计算,枚举 i + j = s i+j=s i+j=s 和其因数 g c d ( i , j ) = g gcd(i,j)=g gcd(i,j)=g ,我们需要计算合法的 ( i , j ) (i,j) (i,j) 的个数,也即和为 s g \frac{s}{g} gs ,各自大小在 N − 1 g \frac{N-1}{g} gN−1 以内的互质数对 ( i , j ) (i,j) (i,j) 的个数,可以用莫比乌斯反演解决。
- 时间复杂度 O ( ∑ i = 1 N ∑ j ∣ i d ( j ) ) O(\sum_{i=1}^{N}\sum_{j\mid i}d(j)) O(∑i=1N∑j∣id(j)) ,其中 d ( x ) d(x) d(x) 表示 x x x 的约数个数。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 2e5 + 5; const int P = 1e9 + 7; typedef long long ll; typedef long double ld; typedef unsigned long long ull; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } int power(int x, int y) { if (y == 0) return 1; int tmp = power(x, y / 2); if (y % 2 == 0) return 1ll * tmp * tmp % P; else return 1ll * tmp * tmp % P * x % P; } void update(int &x, int y) { x += y; if (x >= P) x -= P; } int miu[MAXN], powk[MAXN]; vector <int> factors[MAXN]; int tot, prime[MAXN], f[MAXN]; void sieve(int n) { miu[1] = 1; for (int i = 2; i <= n; i++) { if (f[i] == 0) prime[++tot] = f[i] = i, miu[i] = P - 1; for (int j = 1; j <= tot && prime[j] <= f[i]; j++) { int tmp = prime[j] * i; if (tmp > n) break; if (prime[j] == f[i]) miu[tmp] = 0; else miu[tmp] = (P - miu[i]) % P; f[tmp] = prime[j]; } } for (int i = 1; i <= n; i++) for (int j = i; j <= n; j += i) factors[j].push_back(i); } int calc(int lim, int sum) { int l = 1, r = lim; chkmax(l, sum - lim); chkmin(r, sum - 1); if (l > r) return 0; int ans = 0; for (auto x : factors[sum]) update(ans, 1ll * miu[x] * (P - (l - 1) / x + r / x) % P); return ans; } int main() { int n, k, ans = 0; read(n), read(k); sieve(2 * n), powk[0] = 1; for (int i = 1; i <= n; i++) powk[i] = 1ll * powk[i - 1] * k % P; for (int s = 2; s <= 2 * n - 2; s++) for (auto g : factors[s]) { int lim = (n - 1) / g, sum = s / g; if (lim <= 0 || sum <= 1) continue; update(ans, 1ll * calc(lim, sum) * powk[max(g, s - n)] % P); } writeln(1ll * ans * power(powk[n], P - 2) % P); return 0; }