【思路要点】
- 所求的式子是约数和的形式,考虑枚举一个约数 $i$,计算其被算了多少次,则答案为
$$
\begin{aligned}
&\sum_{i=1}^{N}\sum_{j=1}^{\lfloor\frac{N}{i}\rfloor}[\gcd(i,j)=1]\,i\\
=&\sum_{i=1}^{N}\sum_{j=1}^{\lfloor\frac{N}{i}\rfloor}\sum_{g\mid i,\,g\mid j}\mu(g)\,i\\
=&\sum_{g=1}^{\lfloor\sqrt{N}\rfloor}\mu(g)\,g\sum_{i=1}^{\lfloor\frac{N}{g^2}\rfloor}i\left\lfloor\frac{N}{g^2i}\right\rfloor
\end{aligned}
$$
(最后一步令 $i=gi'$、$j=gj'$,由 $j\le\lfloor\frac{N}{i}\rfloor$ 得 $j'\le\lfloor\frac{N}{g^2i'}\rfloor$,因此只有 $g^2\le N$ 的 $g$ 有贡献。)
- 不难发现靠后的求和符号为关于 $m=\lfloor\frac{N}{g^2}\rfloor$ 的函数 $f(m)=\sum_{i=1}^{m}i\lfloor\frac{m}{i}\rfloor$,用整除分块计算它的时间复杂度为 $O(\sqrt{m})$。
- 因此该算法的时间复杂度为 $O\left(\sum_{g=1}^{\sqrt{N}}\sqrt{\frac{N}{g^2}}\right)=O(\sqrt{N}\log N)$。
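- 需要一定常数优化,例如:对较小的 $m$ 用线性筛预处理 $f(m)=\sum_{k=1}^{m}\sigma(k)$;相邻的 $g$ 若 $\lfloor\frac{N}{g^2}\rfloor$ 相同则复用 $f$ 的值;整除分块累加时延迟取模。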
【代码】
// Reads N and P from sum.in and writes
//   sum_{g=1}^{floor(sqrt N)} mu(g) * g * f(floor(N / g^2))  (mod P)
// to sum.out, where f(m) = sum_{i=1}^{m} i * floor(m / i) = sum_{k=1}^{m} sigma(k).
#include <cctype>
#include <cmath>
#include <cstdio>

typedef long long ll;

// Sieve table size. The sieve runs up to floor(sqrt(N)) + 1, so N may be at
// most about 1.6e13.
const int MAXN = 4e6 + 5;

// Reads a decimal integer from stdin. Every '-' seen before the first digit
// flips the sign. Stops at EOF instead of looping forever on truncated input.
template <typename T> void read(T &x) {
  x = 0;
  int sign = 1;
  // int, not char: keeps EOF distinguishable and std::isdigit's argument valid.
  int c = std::getchar();
  for (; c != EOF && !std::isdigit(c); c = std::getchar())
    if (c == '-') sign = -sign;
  for (; std::isdigit(c); c = std::getchar()) x = x * 10 + c - '0';
  x *= sign;
}

// Writes x in decimal to stdout.
template <typename T> void write(T x) {
  if (x < 0) x = -x, std::putchar('-');
  if (x > 9) write(x / 10);
  std::putchar(x % 10 + '0');
}

// Writes x followed by a newline.
template <typename T> void writeln(T x) {
  write(x);
  std::puts("");
}

ll n, Q;          // N, and the lazy-reduction threshold Q = 7 * P^2
int limit, ans, P;
int d[MAXN];      // sigma(i) mod P during the sieve; prefix sums f(i) mod P afterwards
int g[MAXN];      // largest power of the smallest prime factor that divides i
int tot, prime[MAXN];
int f[MAXN];      // smallest prime factor of i (0 while i has not been reached)
int miu[MAXN];    // Moebius function mu(i) mod P

// x = (x + y) mod P for 0 <= x, y < P. The sum is formed in 64 bits so a
// large P cannot overflow int.
void update(int &x, int y) {
  ll s = static_cast<ll>(x) + y;
  if (s >= P) s -= P;
  x = static_cast<int>(s);
}

// Linear sieve up to n: fills f, g, miu and sigma (in d), then turns d into
// prefix sums so that d[m] = f(m) mod P for every m <= n.
void sieve(int n) {
  miu[1] = d[1] = 1;
  for (int i = 2; i <= n; i++) {
    if (f[i] == 0) {  // i is prime
      prime[++tot] = f[i] = g[i] = i;
      d[i] = i + 1, miu[i] = P - 1;
      // sigma(p^(k+1)) = p * sigma(p^k) + 1 for every prime power <= n.
      for (int j = i; 1ll * i * j <= n; j = i * j)
        d[i * j] = (1ll * i * d[j] + 1) % P;
    }
    // sigma is multiplicative and gcd(g[i], i / g[i]) = 1.
    d[i] = 1ll * d[g[i]] * d[i / g[i]] % P;
    for (int j = 1; j <= tot && prime[j] <= f[i]; j++) {
      ll tmp = 1ll * prime[j] * i;
      if (tmp > n) break;
      if (prime[j] == f[i]) {
        miu[tmp] = 0;
        g[tmp] = g[i] * prime[j];
      } else {
        miu[tmp] = (P - miu[i]) % P;
        g[tmp] = prime[j];
      }
      f[tmp] = prime[j];
    }
  }
  for (int i = 2; i <= n; i++) update(d[i], d[i - 1]);
}

// Returns f(m) = sum_{i=1}^{m} i * floor(m / i) mod P. Small m comes straight
// from the sieve table; otherwise runs of equal floor(m / i) are summed.
int func(ll m) {
  if (m <= limit) return d[m];
  // T(x) = x (x + 1) / 2. Reducing x modulo 2P (not P) keeps T(x) mod P exact
  // for every modulus, including even P where 2 has no inverse.
  const ll twoP = 2ll * P;
  ll acc = 0;    // kept in [0, Q): reduced lazily by Q instead of % P per step
  int last = 0;  // T(i - 1) mod P
  for (ll i = 1, hi; i <= m; i = hi + 1) {
    ll q = m / i;
    hi = m / q;  // last index whose quotient is still q
    ll r = hi % twoP;
    int now = static_cast<int>(r * (r + 1) / 2 % P);
    acc += 1ll * (now - last + P) * (q % P);
    if (acc >= Q) acc -= Q;
    last = now;
  }
  return static_cast<int>(acc % P);
}

int main() {
  std::freopen("sum.in", "r", stdin);
  std::freopen("sum.out", "w", stdout);
  read(n), read(P), Q = 7ll * P * P;
  const double root = std::sqrt(static_cast<double>(n));
  // Negated comparison so a NaN root (negative N) is rejected as well.
  if (!(root + 1 < MAXN)) {
    std::fputs("N is too large for the sieve table\n", stderr);
    return 1;
  }
  limit = static_cast<int>(root) + 1;
  sieve(limit);
  ll last = n + 1;  // floor(N / i^2) for the previous i; n + 1 means "none yet"
  int now = 0;      // f(last), reused while floor(N / i^2) stays the same
  for (int i = 1; i <= limit; i++) {
    int mul = 1ll * miu[i] * i % P;
    ll m = n / i / i;
    if (m != last) {
      last = m;
      now = func(last);
    }
    update(ans, 1ll * mul * now % P);
  }
  writeln(ans);
  return 0;
}