题目链接
题意:给出一个n和k,要求输出1到n中有多少个数的最大质因子都小于等于k
T组输入,n和k的范围都是1e9.
做法:如果
n
≤
k
n \leq k
n≤k 那么答案是
n
n
n.
如果
k
<
n
k< n
k<n:
- 首先考虑一个容斥很容易得到答案的表达如下:
- a n s = n − ∑ i 1 p i + ∑ i < j 1 p i p j − . . . ( − 1 ) k 1 p 1 p 2 . . . p k ( m i n ( p ) > k ) ans =n -\sum_{i}\frac{1}{p_i}+\sum_{i<j}\frac{1}{p_ip_j} -...(-1)^{k}\frac{1}{p_{1}p_{2}...p_{k}}(min(p)>k) ans=n−∑ipi1+∑i<jpipj1−...(−1)kp1p2...pk1(min(p)>k)
- 由上面那个式子如果知道容斥和莫比乌斯的肯定很熟悉,他就等于如下的式子:
- a n s = n + ∑ p > k μ ( i ) ⌊ n i ⌋ , p ∣ i . ans = n+\sum_{p>k}\mu(i)\lfloor \frac{n}{i}\rfloor,p|i. ans=n+∑p>kμ(i)⌊in⌋,p∣i.
- 上面的 p p p是 i i i的最小质因子,如果不清楚可以看看2013的国家集训队论文。
现在就是怎么求
∑
p
>
k
μ
(
i
)
⌊
n
i
⌋
\sum_{p>k}\mu(i)\lfloor \frac{n}{i}\rfloor
∑p>kμ(i)⌊in⌋。后面的整除我们可以用数论分块
O
(
n
)
O(\sqrt n)
O(n)来解决。但不过中间的莫比乌斯的值就不好求了。
中间的莫比乌斯函数的值其实就是最小质因子大于等于
k
k
k的数的莫比乌斯的函数值的和。我们考虑使用
m
i
n
25
min25
min25筛来解决这个问题。
因为
m
i
n
25
min25
min25的正是处理这些问题了。那么只需要找到最小的大于
k
k
k的质因子就可以筛出函数值了。
这里分类讨论:
- 首先 k > n k>\sqrt n k>n这种情况下,只要随便出现一个大于 k k k的质数那么肯定就是不合法的,那么对于数论分块的一段区间 [ l , r ] [l,r] [l,r]只需要判断有多少素数,然后乘上 ⌊ n l ⌋ \lfloor \frac{n}{l} \rfloor ⌊ln⌋就行了。
- 然后 k ≤ n k\leq\sqrt n k≤n这种情况下,我就考虑一个比较暴力的做法,因为第一个大于 k k k的素因子肯定很容易就能找到,那么对于数论分块的每一段区间 [ l , r ] [l,r] [l,r]的莫比乌斯函数值,就是 s ( r , p j ) − s ( l − 1 , p j ) , p j > k s(r,p_{j})-s(l-1,p_{j}),p_j>k s(r,pj)−s(l−1,pj),pj>k直接 m i n 25 min25 min25暴力求就行了。
这个复杂度我就不分析了,因为好像求和部分的复杂度不是很高,并没有TLE。题解和标程的代码我没有看懂,好像他只需要求一次 s s s数组可以 O ( 1 ) O(1) O(1)的得出答案了。
#include "bits/stdc++.h"
using namespace std;
inline int read() {
int x = 0;
bool f = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = 0;
for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
if (f) return x;
return 0 - x;
}
#define SZ(x) ((int)(x.size()))
#define all(x) (x).begin(),(x).end()
#define ll long long
const int maxn = 2e5 + 10;
const double PI = acos(-1.0);
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
int inv(int n) {
if (n == 1) return 1;
return 1ll * inv(mod % n) * (mod - mod / n) % mod;
}
int ksm(int a, ll n) {
int ans = 1;
while (n) {
if (n & 1) ans = 1ll * ans * a % mod;
n >>= 1;
a = 1ll * a * a % mod;
}
return ans;
}
inline int add(int u, int v) { return (u += v) >= mod ? u - mod : u; }
inline int sub(int u, int v) { return (u -= v) < 0 ? u + mod : u; }
inline int mul(int u, int v) { return 1ll * u * v % mod; }
int vis[maxn], pri[maxn], cnt, mu[maxn];
void init(int N = maxn - 1) {
mu[1] = 1;
cnt = 0;
for (int i = 2; i <= N; i++) {
if (!vis[i]) {
pri[++cnt] = i;
mu[i] = -1;
}
for (int j = 1; j <= cnt && i * pri[j] <= N; j++) {
vis[i * pri[j]] = 1;
if (i % pri[j] == 0) break;
mu[i * pri[j]] = -mu[i];
}
}
}
int id1[maxn], id2[maxn], sqr, n, k, c;
int w[maxn << 1], g[maxn << 1], h[maxn << 1];
inline int getid(int x, int m) {
if (x <= sqr) return id1[x];
return id2[m / x];
}
void calc_g(int _n) {
c = 0, sqr = sqrt(_n);
int m = _n;
for (int l = 1, r; l <= m; l = r + 1) {
r = m / (m / l);
w[++c] = m / l;
g[c] = w[c] - 1;
if (w[c] <= sqr) id1[w[c]] = c;
else id2[r] = c;
}
init(sqr << 1);
while (pri[cnt] > sqr) --cnt;
for (int j = 1; j <= cnt; j++) {
for (int i = 1; i <= c && pri[j] * pri[j] <= w[i]; i++) {
int id = getid(w[i] / pri[j], m);
g[i] -= g[id] - j + 1;
}
}
}
inline int fmu(int x) { return x == 1 ? -1 : 0; }
int get(int x, int y) {
if (x <= 1 || pri[y] > x) return 0;
int id = getid(x, n);
int ret = -g[id] + y - 1;
for (int i = y; i <= cnt && pri[i] * pri[i] <= x; i++) {
int t1 = pri[i], t2 = pri[i] * pri[i];
for (int e = 1; t2 <= x; ++e, t1 = t2, t2 *= pri[i])
ret += get(x / t1, i + 1) * fmu(e) + fmu(e + 1);
}
return ret;
}
int solve() {
int ans = n;
if (1ll * k * k > n) {
int res = 0, l = k + 1, r = n / (n / l);
calc_g(k);
res -= g[getid(l - 1, k)];
calc_g(n);
res += g[getid(r, n)];
ans -= res * (n / l);
l = r + 1;
for (; l <= n; l = r + 1) {
r = n / (n / l);
res = g[getid(r, n)] - g[getid(l - 1, n)];
ans -= res * (n / l);
}
return ans;
}
calc_g(n);
int t = 1;
while (t <= cnt && pri[t] <= k) ++t;
for (int l = k + 1, r; l <= n; l = r + 1) {
r = n / (n / l);
ans += (n / l) * (get(r, t) - get(l - 1, t));
// cout << (get(r, t) - get(l - 1, t)) << endl;
}
return ans;
}
int main() {
// freopen("1.in", "r", stdin);
int T;
scanf("%d", &T);
// init();
while (T--) {
scanf("%d%d", &n, &k);
int ans;
if (n <= k) ans = n;
else ans = solve();
printf("%d\n", ans);
}
return 0;
}