n % d + m % d ≥ d n \% d + m \% d \geq d n%d+m%d≥d 可以写成 n − ⌊ n d ⌋ ∗ d + m − ⌊ m d ⌋ ∗ d ≥ d n - \lfloor\frac{n}{d}\rfloor * d +m - \lfloor\frac{m}{d}\rfloor * d \geq d n−⌊dn⌋∗d+m−⌊dm⌋∗d≥d
移项一下可以得出: ⌊ n d ⌋ + ⌊ m d ⌋ + 1 ≥ ⌊ n + m d ⌋ \lfloor\frac{n}{d}\rfloor + \lfloor\frac{m}{d}\rfloor + 1\geq \lfloor\frac{n + m}{d}\rfloor ⌊dn⌋+⌊dm⌋+1≥⌊dn+m⌋
枚举 d d d , c h e c k check check 这个条件,这个条件可以整除分块。注意到 d d d 的 上限是 n + m n + m n+m,设 n ≤ m n \leq m n≤m,将答案分成三个部分进行分块统计,即: 1 ≤ d ≤ n , n + 1 ≤ d ≤ m , m + 1 ≤ d ≤ n + m 1 \leq d \leq n,n + 1 \leq d \leq m,m + 1 \leq d \leq n + m 1≤d≤n,n+1≤d≤m,m+1≤d≤n+m
最后一个块整个块都是答案。
现在只需要能快速处理块间数字的因数和即可。
∑
x
=
1
n
σ
(
x
)
=
∑
x
=
1
n
∑
d
∣
x
d
=
∑
d
=
1
n
d
∑
x
=
1
⌊
n
d
⌋
1
=
∑
d
=
1
n
d
∗
⌊
n
d
⌋
\displaystyle\sum_{x = 1}^n\sigma(x) = \displaystyle\sum_{x = 1}^n\sum_{d | x}d = \displaystyle\sum_{d = 1}^nd\sum_{x = 1}^{\lfloor\frac{n}{d}\rfloor}1=\displaystyle\sum_{d = 1}^nd * \lfloor\frac{n}{d}\rfloor
x=1∑nσ(x)=x=1∑nd∣x∑d=d=1∑ndx=1∑⌊dn⌋1=d=1∑nd∗⌊dn⌋
这个式子也可以分块,但这两个分块并不是嵌套的形式,直接暴力做复杂度会达到
1
0
9
10^9
109,而不是
n
3
4
n^{\frac{3}{4}}
n43
用线性筛先预处理一部分(大概
1
0
7
10^7
107),把复杂度降下来,然后就可以过了
代码:
#include<bits/stdc++.h>
using namespace std;
const int mod = 1e9 + 7;
typedef long long ll;
ll n,m;
const int N=1e7 + 5;
bool mark[N];
int prim[N];
ll sd[N],sp[N];
int cnt;
void initial()
{
cnt=0;
sd[1]=1;
for (int i=2 ; i<N ; ++i)
{
if (!mark[i])
{
prim[cnt++]=i;
sd[i]=i+1;
sp[i]=i+1;
}
for (int j=0 ; j<cnt && i*prim[j]<N ; ++j)
{
mark[i*prim[j]]=1;
if (!(i%prim[j]))
{
sp[i*prim[j]]=sp[i]*prim[j]+1;
sd[i*prim[j]]=sd[i]/sp[i]*sp[i*prim[j]];
break;
}
sd[i*prim[j]]=sd[i]*sd[prim[j]] % mod;
sp[i*prim[j]]=1+prim[j];
}
}
for(int i = 1; i < N; i++)
sd[i] = (sd[i] + sd[i - 1]) % mod;
}
ll fpow(ll a,ll b) {
ll r = 1;
while(b) {
if(b & 1) r = r * a % mod;
b >>= 1;
a = a * a % mod;
}
return r;
}
ll inv2;
ll cal(ll x) {
return 1ll * x * (x + 1) % mod * inv2 % mod;
}
ll getsum(ll x) {
if(x < N) return sd[x];
int l,r;
ll ans = 0;
for(l = 1; l <= x; l = r + 1) {
r = x / (x / l);
ans = ans + (cal(r) - cal(l - 1) + mod) % mod * (x / l) % mod;
ans %= mod;
}
return ans;
}
int main() {
initial();
inv2 = fpow(2,mod - 2);
scanf("%lld%lld",&n,&m);
if(n > m) swap(n,m);
ll l,r;
ll res = 0;
for(l = 1; l <= n; l = r + 1) {
r = min(n / (n / l),m / (m / l));
r = min(r,(n + m) / ((n + m) / l));
if((n / l + m / l + 1) <= (n + m) / l) {
res = (res + getsum(r) - getsum(l - 1) + mod) % mod;
}
}
for(l = n + 1; l <= m; l = r + 1) {
r = min(m / (m / l),(n + m) / ((n + m) / l));
if((m / l + 1) <= (m + n) / l)
res = (res + getsum(r) - getsum(l - 1) + mod) % mod;
}
res = (res + getsum(n + m) - getsum(m) + mod) % mod;
printf("%lld\n",res);
return 0;
}