题目
https://www.luogu.com.cn/problem/P5325
定义积性函数
f
(
x
)
f(x)
f(x),且
f
(
p
k
)
=
p
k
(
p
k
−
1
)
(
p
为
质
数
)
f(p^k)=p^k(p^k-1)(p为质数)
f(pk)=pk(pk−1)(p为质数),求
∑
i
=
1
n
f
(
i
)
\sum_{i=1}^{n}f(i)
i=1∑nf(i)
对
1
0
9
+
7
10^9+7
109+7取模。
n
≤
1
0
10
n \le 10^{10}
n≤1010。
思路
把
f
(
p
)
=
p
2
−
p
f(p)=p^2-p
f(p)=p2−p拆成两个完全积性函数
g
(
p
)
=
p
2
g(p)=p^2
g(p)=p2,
g
1
(
p
)
=
p
g_1(p)=p
g1(p)=p,注意
−
p
-p
−p不是完全积性函数
#include<bits/stdc++.h>
#define ll long long
#define mod 1000000007
using namespace std;
const int N=1000009;//别开小了
int isp[N],tot,cnt,sqr;
ll psum[N],psum1[N];//psum为质数前缀
ll g[N],g1[N];//g为积性函数,注意下标是离散化后的
ll n,prime[N];
ll w[N];//w存n/x的所有数
ll id[N],id1[N];//id,id1存n/x的离散化下标
ll qpow(ll a,ll b){ll res=1;a%=mod;while(b){if(b&1)res=res*a%mod;a=a*a%mod;b>>=1;}return res;}
ll calsum(ll x){//第一个积性函数求和,平方和,注意减f(1)
x%=mod;
return (x*(2*x+1)%mod*(x+1)%mod*qpow(6,mod-2)%mod-1+mod)%mod;
}
ll calsum1(ll x){//第二个积性函数求和,等差数列和,注意减f1(1)
x%=mod;
return ((1+x)*x%mod*qpow(2,mod-2)%mod-1+mod)%mod;
}
ll cal(ll x){//第一个积性函数
x%=mod;
return x*x%mod;
}
ll cal1(ll x){//第二个积性函数
x%=mod;
return x;
}
ll f(ll x){//f(p^e)
x%=mod;
return x*(x-1+mod)%mod;
}
int findl(ll x){
return x<=sqr?id[x]:id1[n/x];
}
ll S(ll x,int y){
if(prime[y]>=x)return 0;
int locate=findl(x);
ll res=(g[locate]-psum[y]+mod-(g1[locate]-psum1[y]+mod)+mod)%mod;//两个函数合并
for(int i=y+1;i<=tot&&prime[i]*prime[i]<=x;i++){//prime[i]*prime[i]<=x优化,否则会超时,但答案不会错
ll tmp=prime[i];
for(int e=1;tmp<=x;e++,tmp*=prime[i]){
ll tem=tmp%mod;//不先取余会溢出
res=(res+f(tem)*(S(x/tmp,i)+(e>1))%mod)%mod;
}
}
//printf("%lld %d %lld\n",x,y,res);
return res;
}
void solve(){
sqr=sqrt(n);
isp[1]=1;
tot=cnt=0;
psum[0]=psum1[0]=g[0]=g1[0]=0;
for(int i=2;i<=sqr;i++){
if(!isp[i])
prime[++tot]=i;
for(int j=1;j<=tot&&i*prime[j]<=sqr;j++){
isp[i*prime[j]]=1;
if(i%prime[j]==0)
break;
}
}
for(int i=1;i<=tot;i++)
psum[i]=(psum[i-1]+cal(prime[i]))%mod,psum1[i]=(psum1[i-1]+cal1(prime[i]))%mod;
for(ll l=1,r;l<=n;l=r+1){
r=n/(n/l);
ll tmp=n/l;//tmp要运算记得先取余
w[++cnt]=tmp;//tmp从大到小存
g[cnt]=calsum(tmp);
g1[cnt]=calsum1(tmp);
if(tmp<=sqr)
id[tmp]=cnt;
else
id1[n/tmp]=cnt;
//printf("%lld %lld %d\n",tmp,n/tmp,cnt);
}
for(int i=1;i<=tot;i++){
for(int j=1;j<=cnt&&prime[i]*prime[i]<=w[j];j++){//滚动dp,w[j]从大到小遍历
ll tmp=w[j]/prime[i];
int locate=findl(tmp);
g[j]=(g[j]-cal(prime[i])*(g[locate]-psum[i-1]+mod)%mod+mod)%mod;
g1[j]=(g1[j]-cal1(prime[i])*(g1[locate]-psum1[i-1]+mod)%mod+mod)%mod;
//printf("%lld %lld %d %lld\n",w[j],prime[i],locate,g[j]);
}
}
//cout<<endl;
printf("%lld",(S(n,0)+1)%mod);
}
int main(){
scanf("%lld",&n);
solve();
}