【题目链接】
【思路要点】
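- 显然问题可以转化为\(K=1\)的形式:\(\gcd\)恰为\(K\)的每个数都是\(K\)的倍数,将每个数除以\(K\)后,区间变为\([\lceil\frac{L}{K}\rceil,\lfloor\frac{R}{K}\rfloor]\),要求\(\gcd\)恰为\(1\)。下文沿用\(L,R\)表示变换后的区间端点。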
- 那么,我们实际上要求\(\sum_{i_1,i_2,\dots,i_N=L}^{R}\epsilon(\gcd(i_1,i_2,\dots,i_N))\),其中\(\epsilon(x)=[x=1]\)。
- \(=\sum_{i_1,i_2,\dots,i_N=L}^{R}\sum_{d\mid\gcd(i_1,i_2,\dots,i_N)}\mu(d)\)(由\(\sum_{d\mid n}\mu(d)=\epsilon(n)\))
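- \(=\sum_{d=1}^{R}\mu(d)(\lfloor\frac{R}{d}\rfloor-\lfloor\frac{L-1}{d}\rfloor)^N\)(交换求和顺序:\([L,R]\)中\(d\)的倍数恰有\(\lfloor\frac{R}{d}\rfloor-\lfloor\frac{L-1}{d}\rfloor\)个,每个\(i_j\)独立地从中选取)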
- 其中\((\lfloor\frac{R}{d}\rfloor-\lfloor\frac{L-1}{d}\rfloor)^N\)随\(d\)变化只有\(O(\sqrt{L}+\sqrt{R})\)种取值,可以整除分块求和;每一块需要\(\mu\)在该区间上的和,即两个前缀和之差,而\(\mu(d)\)的前缀和可以通过杜教筛求得。
- 时间复杂度\(O(L^{\frac{2}{3}}+R^{\frac{2}{3}}+(\sqrt{L}+\sqrt{R})\log N)\)。
【代码】
#include <algorithm>
#include <cctype>
#include <climits>
#include <cstdio>
#include <cstring>
using namespace std;

// Counts N-tuples (i_1, ..., i_N) with every i_j in [L, H] and gcd exactly K, mod P.
// After dividing by K the answer is
//     sum_{d=1}^{r} mu(d) * (r/d - l/d)^N,   r = floor(H/K), l = ceil(L/K) - 1.
// Prefix sums of mu (the Mertens function M) below MAXN come from a linear
// sieve; larger ones from Du's sieve with memoisation.

const int MAXN = 2e6 + 5;  // sieve limit: M(x) is read from s[] when x < MAXN
const int P = 1e9 + 7;
// A memo key is bound / x with x >= MAXN and bound <= INT_MAX, so it never
// exceeds INT_MAX / MAXN; the memo tables only need that many slots.
const int MEMO = INT_MAX / MAXN + 2;

template <typename T> void chkmax(T &x, T y) { x = max(x, y); }
template <typename T> void chkmin(T &x, T y) { x = min(x, y); }

// Reads one (optionally negative) integer from stdin.
// `c` is an int so EOF is distinguishable; the skip loop stops at EOF instead of spinning.
template <typename T> void read(T &x) {
    x = 0;
    int sign = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar())
        if (c == '-') sign = -sign;
    for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
    x *= sign;
}

// Writes an integer to stdout without a trailing newline.
template <typename T> void write(T x) {
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

// Writes an integer followed by a newline.
template <typename T> void writeln(T x) {
    write(x);
    puts("");
}

int tot, prime[MAXN], f[MAXN];  // prime[1..tot] = primes found; f[i] = smallest prime factor of i
int miu[MAXN], s[MAXN];         // Mobius function and its prefix sums mod P (s[i] = M(i))
int sl[MEMO], sr[MEMO];         // Mertens memo for arguments of the form l / j and r / j
int n, l, r, k;

// Linear sieve over [1, MAXN): fills prime[], f[], miu[] and s[].
void init() {
    miu[1] = s[1] = 1;
    for (int i = 2; i < MAXN; i++) {
        if (f[i] == 0) {
            f[i] = prime[++tot] = i;
            miu[i] = -1;
        }
        s[i] = (s[i - 1] + miu[i] + P) % P;
        // Each composite tmp is marked exactly once, via its smallest prime factor prime[j].
        for (int j = 1; j <= tot && prime[j] <= f[i]; j++) {
            int tmp = prime[j] * i;
            if (tmp >= MAXN) break;
            f[tmp] = prime[j];
            // A repeated prime factor zeroes mu; a new distinct prime flips its sign.
            miu[tmp] = (prime[j] == f[i]) ? 0 : -miu[i];
        }
    }
}

// Mertens function M(x) = sum_{i<=x} mu(i) mod P, via Du's sieve:
//     M(x) = 1 - sum_{i=2}^{x} M(floor(x / i)).
// x must be of the form floor(bound / j). Every recursive argument then has the
// same form, and for x >= MAXN (> sqrt(bound)) bound / x recovers j uniquely,
// which makes it a valid index into memo (pre-filled with -1).
int mertens(int x, int bound, int *memo) {
    if (x < MAXN) return s[x];
    int key = bound / x;
    if (memo[key] != -1) return memo[key];
    long long ans = 1;
    // Group i by the common value q = floor(x / i); [lo, hi] is that run of i.
    for (long long lo = 2, hi; lo <= x; lo = hi + 1) {
        int q = static_cast<int>(x / lo);
        hi = x / q;
        ans = (ans - (hi - lo + 1) % P * mertens(q, bound, memo) % P + P) % P;
    }
    return memo[key] = static_cast<int>(ans);
}

// M(x) for x of the form l / j.
int getsl(int x) { return mertens(x, l, sl); }

// M(x) for x of the form r / j.
int getsr(int x) { return mertens(x, r, sr); }

// x^y mod P by iterative binary exponentiation. x may be any int (it is reduced
// into [0, P) first); y >= 0, with power(x, 0) == 1.
int power(int x, int y) {
    long long base = (static_cast<long long>(x) % P + P) % P;
    long long res = 1;
    for (; y > 0; y >>= 1) {
        if (y & 1) res = res * base % P;
        base = base * base % P;
    }
    return static_cast<int>(res);
}

int main() {
    read(n), read(k), read(l), read(r);
    // Reduce to K = 1: multiples of K in [L, H] correspond to integers in (l, r],
    // with l = ceil(L / K) - 1 and r = floor(H / K).
    if (l % k == 0) l = l / k - 1;
    else l = l / k;
    r /= k;
    memset(sl, -1, sizeof(sl));
    memset(sr, -1, sizeof(sr));
    init();
    // answer = sum_{d=1}^{r} mu(d) * (r/d - l/d)^N. Both quotients are constant on
    // each block [now, nxt], so mu summed over the block is M(nxt) - M(now - 1).
    int now = 1, last = 0, ans = 0;  // last holds M(now - 1) mod P
    while (now <= r) {
        int nxtl = now <= l ? l / (l / now) : r;
        int nxtr = r / (r / now);
        // The block ends where either quotient first changes; use the memo table
        // that matches the bound the chosen endpoint was derived from.
        int nxt, cur;
        if (nxtr <= nxtl) {
            nxt = nxtr;
            cur = getsr(nxtr);
        } else {
            nxt = nxtl;
            cur = getsl(nxtl);
        }
        int cnt = r / now - l / now;  // multiples of any d in the block lying in (l, r]
        ans = (ans + 1ll * (cur - last + P) % P * power(cnt, n)) % P;
        last = cur;
        now = nxt + 1;
    }
    writeln(ans);
    return 0;
}