题目链接
http://acm.zzuli.edu.cn/problem.php?id=2832
思路
图片是官方题解,官方题解省略了很多说明和细节,我来描述一下,先看第一句话,我们怎么求解这个
f
(
n
)
f(n)
f(n),我们可以很明显看出,应该是通过枚举x来得到个数,最后将所有x的取值的结果求和即为个数。
易得公式即为
∑
i
=
1
n
\sum_{i=1}^{n}
∑i=1n
⌊
n
/
i
⌋
\lfloor n/i \rfloor
⌊n/i⌋,通过题解我们得知我们要预处理这个公式1~1e5的所有值,那么哪怕使用数论分块一个个进行求解,时间复杂度也不尽人意,我们观察这个函数,可以得到,这个函数值实际上是区间
[
1
,
n
]
[1, n]
[1,n]所有数的约数个数的贡献之和,所以只要我们得到区间
[
1
,
n
]
[1, n]
[1,n]上每个数的约数个数再做一遍前缀和,就可以得到
f
(
n
)
f(n)
f(n)在
[
1
,
1
e
5
]
[1, 1e5]
[1,1e5]上的所有值。线性筛求解约数个数是参照这个链接。
解决了
f
(
n
)
f(n)
f(n)函数的求解和预处理,我们再回归本题,看官方题解第二行,在预处理这个函数值之后,我们
f
(
n
−
k
)
f(n-k)
f(n−k)的值就是
∑
i
=
1
n
−
k
\sum_{i=1}^{n-k}
∑i=1n−k
⌊
(
n
−
k
)
/
i
⌋
\lfloor (n-k)/i \rfloor
⌊(n−k)/i⌋,我们将所有这些取值对应的x, y的组合中的y加一个k就能保证
y
%
x
=
=
k
y\%x==k
y%x==k并且
y
<
=
n
y<=n
y<=n,但是这并不是最终答案,因为既然
y
%
x
=
=
k
y\%x==k
y%x==k,那么一定有
x
>
k
x>k
x>k,所以就有了官方题解上面的减去那个求和式。但是这也并不是最终答案,题解还是省略了一个值,在
k
!
=
0
k!=0
k!=0的情况下,那么还有
y
<
x
y<x
y<x的情况,这种情况对应的组合只有一种情况,
y
=
=
k
&
&
k
<
x
<
=
n
y==k\&\&k<x<=n
y==k&&k<x<=n,所有答案还要再加上一个
n
−
k
n-k
n−k(
k
=
=
0
k == 0
k==0就不用加了,因为y大于0)。
关于最后求解逆元,使用费马小定理或扩欧求解完后乘一个a即可,题目保证n不是23333的倍数,就是注意如果用费马小定理求解逆元,快速幂底数要先取模,底数可能达到1e10级别,不取模直接放进快速幂会爆long long。
费马小定理求解逆元代码
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("inline")
#pragma GCC optimize("-fgcse-lm")
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long int
#define mod 23333
#define inf (2e18)
#define eps 1e-4
#define lson (p << 1)
#define rson ((p << 1) | 1)
using namespace std;
const int N = 1e5 + 2, M = 510;
inline ll read()
{
ll x = 0;
char ch = getchar();
while (ch < '0' || ch > '9')
{
ch = getchar();
}
while (ch >= '0' && ch <= '9')
{
x = (x << 1) + (x << 3) + (ch ^ 48);
ch = getchar();
}
return x;
}
ll qmi(ll a, ll b, ll p)
{
ll res = 1;
while (b)
{
if (b & 1)
res = res * a % p;
b >>= 1;
a = a * a % p;
}
return res;
}
ll prime[N], num[N], f[N], cnt;
bool vis[N];
void init()
{
f[1] = 1;
for (ll i = 2; i < N; i++)
{
if (!vis[i])
{
prime[++cnt] = i;
num[i] = 1;
f[i] = 2;
}
for (ll j = 1; j <= cnt && i * prime[j] < N; j++)
{
vis[i * prime[j]] = 1;
if (i % prime[j] == 0)
{
num[i * prime[j]] = 1 + num[i];
f[i * prime[j]] = f[i] / (num[i] + 1) * (num[i] + 2);
break;
}
num[i * prime[j]] = 1;
f[i * prime[j]] = f[i] * f[prime[j]];
}
}
for (ll i = 1; i < N; i++)
f[i] += f[i - 1];
}
int main()
{
init();
ll res = 0;
ll _;
_ = read();
while (_--)
{
ll n, k;
n = read(), k = read();
if (n <= k)
continue;
ll a = f[n - k];
if (k > 0)
a += (n - k);
ll b = n * n;
for (ll i = 1; i <= k; i++)
a = a - ((n - k) / i);
ll z = __gcd(a, b);
a /= z, b /= z;
ll x, y;
x = qmi(b % mod, mod - 2, mod);
x = x * a % mod;
// cout << x << '\n';
res ^= x;
}
cout << res;
return 0;
}
扩欧求解逆元代码
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("inline")
#pragma GCC optimize("-fgcse-lm")
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long int
#define mod 23333
#define inf (2e18)
#define eps 1e-4
#define lson (p << 1)
#define rson ((p << 1) | 1)
using namespace std;
const int N = 1e5 + 2, M = 510;
inline ll read()
{
ll x = 0;
char ch = getchar();
while (ch < '0' || ch > '9')
{
ch = getchar();
}
while (ch >= '0' && ch <= '9')
{
x = (x << 1) + (x << 3) + (ch ^ 48);
ch = getchar();
}
return x;
}
ll qmi(ll a, ll b, ll p)
{
ll res = 1;
while (b)
{
if (b & 1)
res = res * a % p;
b >>= 1;
a = a * a % p;
}
return res;
}
ll exgcd(ll a, ll b, ll &x, ll &y)
{
if (b == 0)
{
x = 1;
y = 0;
return a;
}
ll d = exgcd(b, a % b, x, y);
ll z = x;
x = y;
y = z - (a / b) * y;
return d;
}
ll prime[N], num[N], f[N], cnt;
bool vis[N];
void init()
{
f[1] = 1;
for (ll i = 2; i < N; i++)
{
if (!vis[i])
{
prime[++cnt] = i;
num[i] = 1;
f[i] = 2;
}
for (ll j = 1; j <= cnt && i * prime[j] < N; j++)
{
vis[i * prime[j]] = 1;
if (i % prime[j] == 0)
{
num[i * prime[j]] = 1 + num[i];
f[i * prime[j]] = f[i] / (num[i] + 1) * (num[i] + 2);
break;
}
num[i * prime[j]] = 1;
f[i * prime[j]] = f[i] * f[prime[j]];
}
}
for (ll i = 1; i < N; i++)
f[i] += f[i - 1];
}
int main()
{
init();
ll res = 0;
ll _;
_ = read();
while (_--)
{
ll n, k;
n = read(), k = read();
if (n <= k)
continue;
ll a = f[n - k];
if (k > 0)
a += (n - k);
ll b = n * n;
for (ll i = 1; i <= k; i++)
a = a - ((n - k) / i);
ll z = __gcd(a, b);
a /= z, b /= z;
ll x, y;
exgcd(b, mod, x, y);
x = (x % mod + mod) % mod;
x = x * a % mod;
// cout << x << '\n';
res ^= x;
}
cout << res;
return 0;
}