测试地址:神犇和蒟蒻
做法:本题需要用到杜教筛。
啊,差不多一年没碰过这东西了,想当初学这个东西学出心理阴影了都……然而不能因为菜就停下自己的脚步,所以先做一道杜教筛基础题复健一下。
对于这道题目,第一问就是玩的,显然当
i>1
i
>
1
时
μ(i2)=0
μ
(
i
2
)
=
0
,仅有
μ(1)=1
μ
(
1
)
=
1
,所以答案就是
1
1
。
对于第二问,根据欧拉函数的公式,我们知道,因此要求的就是这样一个积性函数前缀和:
∑ni=1iφ(i)
∑
i
=
1
n
i
φ
(
i
)
。按照杜教筛的套路,要找到一个好求前缀和的积性函数
g
g
,使得它和要求的函数(这道题中
f(n)=nφ(n)
f
(
n
)
=
n
φ
(
n
)
)的狄利克雷卷积也是一个好求前缀和的函数。这里我们找的函数是
g(n)=n
g
(
n
)
=
n
(在某些地方也写作
id
i
d
),因为显然这个函数是完全积性函数,而它和
f
f
的狄利克雷卷积:,也显然是一个完全积性函数,而且这两个函数都可以
O(1)
O
(
1
)
求出前缀和,那么令
S(n)=∑ni=1f(i)
S
(
n
)
=
∑
i
=
1
n
f
(
i
)
,套上杜教筛的公式:
g(1)S(n)=∑ni=1(f∗g)(i)−∑ni=2g(i)S(⌊ni⌋)
g
(
1
)
S
(
n
)
=
∑
i
=
1
n
(
f
∗
g
)
(
i
)
−
∑
i
=
2
n
g
(
i
)
S
(
⌊
n
i
⌋
)
直接杜教筛即可,注意杜教筛要预处理前
n23
n
2
3
项前缀和,要用哈希表处理记忆化,这样就可以做到
O(n23)
O
(
n
2
3
)
的复杂度了。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=1000000007;
const ll hashsiz=2000003;
ll limit,n,phi[1000010],sum[1000010],prime[1000010];
ll hashlist[2000010]={0},hashval[2000010];
bool vis[1000010]={0};
void calc()
{
phi[1]=1;
prime[0]=0;
for(ll i=2;i<=limit;i++)
{
if (!vis[i])
{
prime[++prime[0]]=i;
phi[i]=i-1;
}
for(ll j=1;j<=prime[0]&&i*prime[j]<=limit;j++)
{
vis[i*prime[j]]=1;
if (i%prime[j]==0)
{
phi[i*prime[j]]=phi[i]*prime[j];
break;
}
phi[i*prime[j]]=phi[i]*(prime[j]-1);
}
}
sum[0]=0;
for(ll i=1;i<=limit;i++)
sum[i]=(sum[i-1]+i*phi[i])%mod;
}
ll sumg(ll n)
{
ll inv=500000004;
return n*(n+1)%mod*inv%mod;
}
ll sumfg(ll n)
{
ll inv=166666668;
return n*(n+1)%mod*(2*n+1)%mod*inv%mod;
}
void hashinsert(ll x,ll v)
{
ll pos=x%hashsiz;
while(hashlist[pos]&&hashlist[pos]!=x) pos++;
hashlist[pos]=x;
hashval[pos]=v;
}
ll hashfind(ll x)
{
ll pos=x%hashsiz;
while(hashlist[pos]&&hashlist[pos]!=x) pos++;
if (hashlist[pos]==x) return pos;
else return -1;
}
ll solve(ll n)
{
ll pos=hashfind(n);
if (n<=limit) return sum[n];
if (pos!=-1) return hashval[pos];
ll ans=sumfg(n);
for(ll i=n;i>=2;i=n/(n/i+1))
{
ll l=max(2ll,n/(n/i+1)+1),r=i;
ans-=(solve(n/i)*(sumg(r)-sumg(l-1))%mod+mod)%mod;
ans=(ans+mod)%mod;
}
hashinsert(n,ans);
return ans;
}
int main()
{
scanf("%lld",&n);
printf("1\n");
for(ll i=1;i*i*i<=n;i++)
if ((i+1)*(i+1)*(i+1)>n)
{
limit=i*i;
break;
}
calc();
printf("%lld",solve(n));
return 0;
}