终于在考试中碰到了一题不能用杜教筛的函数,被迫来学这个。。。
概述
首先这个函数
f(x)
f
(
x
)
要求是积性函数,而且
f(p)
f
(
p
)
和
f(pc)
f
(
p
c
)
都要很好计算,设一个“假的”
f′(x)
f
′
(
x
)
表示把
x
x
直接当成质数时的,
f′(x)
f
′
(
x
)
是(或者能拆成)完全积性函数(比如说简单多项式),且
∑ni=1f′(x)
∑
i
=
1
n
f
′
(
x
)
要很好算。
min_25筛的过程有用到埃氏筛法(就是每次选一个质数筛掉其倍数的筛法)的思想,我们先把所有数当成质数,得到
∑ni=1f′(x)
∑
i
=
1
n
f
′
(
x
)
,然后不断地筛得到
∑ni=1[i∈P]f′(x)
∑
i
=
1
n
[
i
∈
P
]
f
′
(
x
)
也就是
∑ni=1[i∈P]f(x)
∑
i
=
1
n
[
i
∈
P
]
f
(
x
)
。得到这个所有质数函数值之和后,我们用类似的方法倒着推回去,得到真的函数值
∑ni=1f(x)
∑
i
=
1
n
f
(
x
)
。
筛质数的函数值
设质数集合
P={p1,p2,...,pt}
P
=
{
p
1
,
p
2
,
.
.
.
,
p
t
}
,
p2t≤n
p
t
2
≤
n
,
minp(x)
m
i
n
p
(
x
)
表示
x
x
的最小素因子。
设,显然对于
p2j>m
p
j
2
>
m
的
j
j
,都是相同的,只关心
p2j≤m
p
j
2
≤
m
的情况。
我们现在知道
g(m,0)
g
(
m
,
0
)
,要求
g(m,t)
g
(
m
,
t
)
。不难推出转移:
其中 −g(pj−1,j−1) − g ( p j − 1 , j − 1 ) 就是 −∑j−1i=1f′(pi) − ∑ i = 1 j − 1 f ′ ( p i ) ,去掉这些质数的之后就都是最小素因子恰好为 pj p j 的情况。注意到 p2j≤m p j 2 ≤ m ,所以 ⌊mpj⌋>pj−1 ⌊ m p j ⌋ > p j − 1 ,不会减出问题。
我们发现第一维只有 O(n−−√) O ( n ) 种,开两个数组分别记下 ≤n−−√ ≤ n 和 >n−−√ > n 的函数值,按照第一维从大到小转移(这样就不用记第二维)即可。
结果就是我们在 O(n34log(n)) O ( n 3 4 log ( n ) ) 的复杂度(并不知道为什么)的情况下得到了 ∑ni=1[i∈P]f′(i)=∑ni=1[i∈P]f(i) ∑ i = 1 n [ i ∈ P ] f ′ ( i ) = ∑ i = 1 n [ i ∈ P ] f ( i ) 。
筛所有数的函数值
类似的我们设
s(m,j)=∑mi=1[i∈P∨minp(i)≥pj]f(i)
s
(
m
,
j
)
=
∑
i
=
1
m
[
i
∈
P
∨
m
i
n
p
(
i
)
≥
p
j
]
f
(
i
)
。我们已知
s(m,t+1)=g(m,t)
s
(
m
,
t
+
1
)
=
g
(
m
,
t
)
要求
s(m,1)
s
(
m
,
1
)
。转移是:
这次要枚举次幂是因为 f f 不一定是完全积性函数,要把全部提出来。
类似之前的实现方法转移即可。
一种更快的做法
假设我们只需要
s(n,1)
s
(
n
,
1
)
而不需要第一维其它取值的答案,我们有一种常数更小的做法。
我们新设
s(m,j)=∑mi=1[minp(i)≥pj]f(i)
s
(
m
,
j
)
=
∑
i
=
1
m
[
m
i
n
p
(
i
)
≥
p
j
]
f
(
i
)
这里的 s s 用递归来实现,玄学的是不用记忆化都跑得比递推的快。
LOJ6053
注意到并不是什么简单多项式,但只有 f(2)=2+1 f ( 2 ) = 2 + 1 ,其它奇质数 f(p)=p−1 f ( p ) = p − 1 ,拆成 p p 和两部分筛,最后特判下 f(2) f ( 2 ) 即可。
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define N 200010
#define ll long long
#define up(x,y) (x=(x+(y))%mod)
using namespace std;
const int mod=1000000007,i2=(mod+1)/2;
ll n,qn,tot,g[2][N],s[2][N],h[2][N];
int pri[N];
bool flag[N];
ll sqrtt(ll n)
{
ll tmp=sqrt(n);
for(ll x=max(tmp-5,0ll);;x++)
if(x*x>n) return x-1;
}
ll get(ll a[2][N],ll x)
{
return (x<=qn?a[0][x]:a[1][n/x]);
}
ll f(ll p,ll k)
{
if(!k) return 1;
return p^k;
}
void getpri(int n)
{
flag[1]=1;
for(int i=2;i<=n;i++)
{
if(!flag[i]) pri[++tot]=i;
for(int j=1;j<=tot&&i*pri[j]<=n;j++)
{
flag[i*pri[j]]=1;
if(i%pri[j]==0) break;
}
}
}
ll sum(ll x)
{
x%=mod;
return x*(x+1)%mod*i2%mod;
}
void solg()
{
for(int i=1;i<=qn;i++)
g[0][i]=sum(i)-1,g[1][i]=sum(n/i)-1;
for(int i=1;i<=qn;i++)
h[0][i]=i-1,h[1][i]=n/i%mod-1;
for(int j=1;j<=tot;j++)
{
ll p=pri[j];
for(int i=1;i<=qn&&n/i>=p*p;i++)
up(g[1][i],-p*(get(g,n/(p*i))-get(g,pri[j-1])));
for(int i=qn;i>=p*p;i--)
up(g[0][i],-p*(get(g,i/p)-get(g,pri[j-1])));
for(int i=1;i<=qn&&n/i>=p*p;i++)
up(h[1][i],-(get(h,n/(p*i))-get(h,pri[j-1])));
for(int i=qn;i>=p*p;i--)
up(h[0][i],-(get(h,i/p)-get(h,pri[j-1])));
}
for(int i=1;i<=qn;i++)
up(g[0][i],-h[0][i]+2ll*(i>=2)),up(g[1][i],-h[1][i]+2);
}
void sols()
{
for(int i=1;i<=qn;i++)
s[0][i]=get(g,i),s[1][i]=get(g,n/i);
for(int j=tot;j;j--)
{
ll p=pri[j];
for(int i=1;i<=qn&&n/i>=p*p;i++)
for(ll d=p,k=1;d*p<=n/i;d*=p,k++)
up(s[1][i],f(p,k)*(get(s,n/(d*i))-get(s,p))+f(p,k+1));
for(int i=qn;i>=p*p;i--)
for(ll d=p,k=1;d*p<=i;d*=p,k++)
up(s[0][i],f(p,k)*(get(s,i/d)-get(s,p))+f(p,k+1));
}
}
int main()
{
scanf("%lld",&n);
qn=sqrtt(n);
getpri(qn);
solg();
sols();
printf("%lld",(s[1][1]+1)%mod);
return 0;
}
递归的:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define N 200010
#define ll long long
#define up(x,y) (x=(x+(y))%mod)
using namespace std;
const int mod=1000000007,i2=(mod+1)/2;
ll n,qn,tot,g[2][N],s[2][N],h[2][N];
int pri[N];
bool flag[N];
ll sqrtt(ll n)
{
ll tmp=sqrt(n);
for(ll x=max(tmp-5,0ll);;x++)
if(x*x>n) return x-1;
}
ll get(ll a[2][N],ll x)
{
return (x<=qn?a[0][x]:a[1][n/x]);
}
ll f(ll p,ll k)
{
if(!k) return 1;
return p^k;
}
void getpri(int n)
{
flag[1]=1;
for(int i=2;i<=n;i++)
{
if(!flag[i]) pri[++tot]=i;
for(int j=1;j<=tot&&i*pri[j]<=n;j++)
{
flag[i*pri[j]]=1;
if(i%pri[j]==0) break;
}
}
}
ll sum(ll x)
{
x%=mod;
return x*(x+1)%mod*i2%mod;
}
void solg()
{
for(int i=1;i<=qn;i++)
g[0][i]=sum(i)-1,g[1][i]=sum(n/i)-1;
for(int i=1;i<=qn;i++)
h[0][i]=i-1,h[1][i]=n/i%mod-1;
for(int j=1;j<=tot;j++)
{
ll p=pri[j];
for(int i=1;i<=qn&&n/i>=p*p;i++)
up(g[1][i],-p*(get(g,n/(p*i))-get(g,pri[j-1])));
for(int i=qn;i>=p*p;i--)
up(g[0][i],-p*(get(g,i/p)-get(g,pri[j-1])));
for(int i=1;i<=qn&&n/i>=p*p;i++)
up(h[1][i],-(get(h,n/(p*i))-get(h,pri[j-1])));
for(int i=qn;i>=p*p;i--)
up(h[0][i],-(get(h,i/p)-get(h,pri[j-1])));
}
for(int i=1;i<=qn;i++)
up(g[0][i],-h[0][i]+2ll*(i>=2)),up(g[1][i],-h[1][i]+2);
}
ll ask(ll n,int m)
{
ll p=pri[m];
if(n<p) return 0;
ll ans=get(g,n)-get(g,p-1);
if(n<p*p) return ans;
for(int i=m;i<=tot&&pri[i]*pri[i]<=n;i++)
for(ll d=pri[i],k=1;d*pri[i]<=n;d*=pri[i],k++)
up(ans,f(pri[i],k)*ask(n/d,i+1)+f(pri[i],k+1));
return ans;
}
int main()
{
scanf("%lld",&n);
qn=sqrtt(n);
getpri(qn<<1);
solg();
printf("%lld",(ask(n,1)+1)%mod);
return 0;
}