题意
求 $$\sum_{i=1}^{n}\sum_{j=1}^{m}\operatorname{lcm}(i,j,i+j)$$（代码中答案对 $10^9+7$ 取模）。
$n,m\le 10^{10}$，时限 3s。
题解
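不妨设 $n\le m$。记 $g=\gcd(i,j)$，$i=ga$，$j=gb$，则 $\gcd(a,b)=1$，于是 $a,\ b,\ a+b$ 两两互质，故 $\operatorname{lcm}(i,j,i+j)=g\cdot\operatorname{lcm}(a,b,a+b)=g\,ab(a+b)$。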
$$
\begin{aligned}
\text{原式}&=\sum_{g=1}^{\min(n,m)}\sum_{i=1}^{\lfloor n/g\rfloor}\sum_{j=1}^{\lfloor m/g\rfloor} g\,ij(i+j)\,[\gcd(i,j)=1]\\
&=\sum_{g=1}^{n} g\sum_{i=1}^{\lfloor n/g\rfloor}\sum_{j=1}^{\lfloor m/g\rfloor} ij(i+j)\sum_{d\mid\gcd(i,j)}\mu(d)\\
&=\sum_{g=1}^{n} g\sum_{d=1}^{\lfloor n/g\rfloor}\mu(d)\sum_{i=1}^{\lfloor n/(dg)\rfloor}\sum_{j=1}^{\lfloor m/(dg)\rfloor} d^3\,ij(i+j)
\end{aligned}
$$
记
$$
\begin{aligned}
f(n,m)&=\sum_{i=1}^{n}\sum_{j=1}^{m} ij(i+j)\\
&=\sum_{i=1}^{n}\sum_{j=1}^{m} i^2 j+\sum_{i=1}^{n}\sum_{j=1}^{m} ij^2\\
&=\Big(\sum_{i=1}^{n} i^2\Big)\Big(\sum_{j=1}^{m} j\Big)+\Big(\sum_{i=1}^{n} i\Big)\Big(\sum_{j=1}^{m} j^2\Big)\\
&=\frac{n(n+1)(2n+1)}{6}\cdot\frac{m(m+1)}{2}+\frac{n(n+1)}{2}\cdot\frac{m(m+1)(2m+1)}{6}\\
&=\frac{1}{6}\,nm(n+1)(m+1)(n+m+1)
\end{aligned}
$$
$$
\begin{aligned}
\text{原式}&=\sum_{g=1}^{n} g\sum_{d=1}^{\lfloor n/g\rfloor}\mu(d)\sum_{i=1}^{\lfloor n/(dg)\rfloor}\sum_{j=1}^{\lfloor m/(dg)\rfloor} d^3\,ij(i+j)\\
&=\sum_{g=1}^{n} g\sum_{d=1}^{\lfloor n/g\rfloor} d^3\mu(d)\,f\big(\lfloor n/(dg)\rfloor,\lfloor m/(dg)\rfloor\big)\\
&=\sum_{T=1}^{n} f\big(\lfloor n/T\rfloor,\lfloor m/T\rfloor\big)\sum_{g\mid T} g\cdot\mu\!\left(\frac{T}{g}\right)\left(\frac{T}{g}\right)^{3}\qquad(T=dg)\\
&=\sum_{T=1}^{n} f\big(\lfloor n/T\rfloor,\lfloor m/T\rfloor\big)\,\big(\mathrm{Id}*(\mu\cdot\mathrm{Id}^3)\big)(T)
\end{aligned}
$$
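先对 $T$ 整除分块（在同一块内 $\lfloor n/T\rfloor$ 与 $\lfloor m/T\rfloor$ 均不变），然后只需计算 $h=\mathrm{Id}*(\mu\cdot\mathrm{Id}^3)$ 的前缀和 $S(x)=\sum_{t\le x}h(t)$。由于 $(\mu\cdot\mathrm{Id}^3)*\mathrm{Id}^3=(\mu*1)\cdot\mathrm{Id}^3=\epsilon$，卷上 $\mathrm{Id}^3$ 后得 $h*\mathrm{Id}^3=\mathrm{Id}*\epsilon=\mathrm{Id}$，杜教筛即可：$S(x)=\frac{x(x+1)}{2}-\sum_{d=2}^{x}d^3\,S(\lfloor x/d\rfloor)$。小范围用线性筛预处理：$h$ 为积性函数，$h(p)=p-p^3$，且 $h(p^{k+1})=p\cdot h(p^k)$（$k\ge 1$）。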
#include<bits/stdc++.h>
using namespace std;
#define ll long long
// Reads the next integer from stdin.
// Skips every non-digit character; a '-' among the skipped characters makes the
// result negative. Then consumes one run of decimal digits.
// Returns 0 if end of input is reached before any digit is seen.
long long getint(){
    long long ans=0;
    long long sign=1;
    // Keep c as an int: storing getchar()'s result in a char makes EOF
    // indistinguishable from a byte, and the skip loop would never end at EOF.
    int c=getchar();
    while(c!=EOF&&(c<'0'||c>'9')){
        if(c=='-')sign=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9'){
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*sign;
}
// Prime modulus for the answer and all intermediate values.
constexpr int mod = 1000000007;
// Modular inverses of 2, 3 and 6; mod + 1 is divisible by both 2 and 3.
constexpr int inv2 = (mod + 1) / 2;
constexpr int inv3 = (mod + 1) / 3;
constexpr int inv6 = static_cast<int>(1LL * inv2 * inv3 % mod);
// Capacity of the sieve tables below.
constexpr int N = 6000000;
// pri: primes found by the sieve in init(); sg: sieve values, later prefix sums.
int pri[N + 10];
int sg[N + 10];
// Number of primes stored in pri.
int cnt = 0;
// boo[x] is set by init() once x has been marked composite.
bool boo[N + 10];
// Largest x for which sg[x] is valid; 0 until init() runs.
int lim = 0;
// Linear sieve over [1, N] for h = Id * (mu . Id^3), followed by in-place prefix sums.
// Afterwards sg[x] = (h(1) + ... + h(x)) mod `mod` for every x <= N, and lim = N.
//   h is multiplicative, with h(p) = p - p^3 and h(p^(k+1)) = p * h(p^k) for k >= 1.
// N (this parameter shadows the table-size constant) must not exceed the capacity of
// the global pri/sg/boo tables.
// NOTE(review): cnt and boo are not reset here, so this assumes init is called once.
void init(int N){
    lim=N;
    sg[1]=1;
    for(int i=2;i<=N;i++){
        if(!boo[i]){
            pri[cnt++]=i;
            // h(p) = p - p^3 (mod `mod`)
            sg[i]=(mod-i*1ll*i%mod*i%mod+i)%mod;
        }
        for(int j=0;j<cnt&&i*pri[j]<=N;j++){
            boo[i*pri[j]]=1;
            if(i%pri[j]==0){
                // pri[j] is the smallest prime factor of i: raise its exponent by one.
                sg[i*pri[j]]=(sg[i]*1ll*pri[j])%mod;
                // Stop here so every composite is produced exactly once, from its
                // smallest prime factor (this is what keeps the sieve linear).
                break;
            }
            // gcd(i, pri[j]) == 1, so multiplicativity applies.
            sg[i*pri[j]]=(sg[i]*1ll*sg[pri[j]])%mod;
        }
    }
    // Turn h values into prefix sums modulo `mod`.
    for(int i=1;i<=N;i++){
        sg[i]+=sg[i-1];
        if(sg[i]>=mod)sg[i]-=mod;
    }
}
// Memo for calc(): prefix sums S(x), keyed by x, for x above the sieve limit.
std::unordered_map<long long,int> sumg;
// 1 + 2 + ... + x modulo `mod`.
// Reducing x first is safe: r(r+1) is even and congruent to x(x+1), and because
// r < mod the product fits in a long long.
inline int sum_id(ll x){
    const long long r = x % mod;
    const long long product = r * (r + 1);
    return static_cast<int>(product / 2 % mod);
}
// Sum of cubes 1^3 + 2^3 + ... + x^3 modulo `mod`, computed as (1 + ... + x)^2.
inline int sum_i(ll x){
    const long long tri = sum_id(x);
    return static_cast<int>(tri * tri % mod);
}
// f(n, m) = sum_{i<=n} sum_{j<=m} i*j*(i+j) modulo `mod`, via the closed form
// n*m*(n+1)*(m+1)*(n+m+1)/6; the division by 6 is a multiplication by inv6.
inline int f(ll n,ll m){
    const long long a = n % mod;
    const long long b = m % mod;
    long long res = inv6;
    res = res * a % mod;
    res = res * b % mod;
    res = res * (a + 1) % mod;
    res = res * (b + 1) % mod;
    res = res * (a + b + 1) % mod;
    return static_cast<int>(res);
}
// Prefix sum S(x) = h(1) + ... + h(x) modulo `mod`, where h = Id * (mu . Id^3).
// - Values up to lim are read from the sieve table sg.
// - Larger values use Du's sieve on the identity h * Id^3 = Id:
//       S(x) = x(x+1)/2 - sum_{d=2..x} d^3 * S(floor(x/d)),
//   grouping d by the value of floor(x/d).
// - Results for large x are memoised in sumg.
int calc(ll x){
    if(x==0)return 0;
    if(x<=lim)return sg[x];
    const auto cached=sumg.find(x);
    if(cached!=sumg.end())return cached->second;

    long long total=sum_id(x);
    int cubesBefore=1;  // 1^3: cubes already accounted for (d = 1 is S(x) itself)
    ll lo=2;
    while(lo<=x){
        const ll quotient=x/lo;   // floor(x/d) is the same for every d in [lo, hi]
        const ll hi=x/quotient;
        const int cubesUpTo=sum_i(hi);
        // Sum of d^3 over [lo, hi], kept non-negative.
        const int weight=cubesUpTo-cubesBefore+mod;
        const long long term=weight*1ll*calc(quotient)%mod;
        total=(total+mod-term)%mod;
        cubesBefore=cubesUpTo;
        lo=hi+1;
    }
    sumg[x]=static_cast<int>(total);
    return static_cast<int>(total);
}
int main(){
//freopen("lcm.in","r",stdin);
//freopen("lcm.out","w",stdout);
ll n=getint(),m=getint();
init(min(N,(int)pow(min(n,m),0.7)));
int ans=0,lst=0;
//for(int i=1;i<=20;i++)cerr<<i<<" "<<(mod+calc(i)-calc(i-1))%mod<<endl;
for(ll l=1,r=0;l<=min(n,m);l=r+1){
r=min(n/(n/l),m/(m/l));
int cr=calc(r);
ans=(ans+(cr-lst+mod)*1ll*f(n/l,m/l)%mod);
if(ans>=mod)ans-=mod;
lst=cr;
}
cout<<ans;
return 0;
}