Description
给出
n,m
n
,
m
,求
其中, μ(i) μ ( i ) 为莫比乌斯函数。
Input
一行,两个整数,表示
n,m
n
,
m
。
Output
一行一个整数,表示答案(对
998244353
998244353
取模)。
Sample Input
10 3
Sample Output
714
Data Constraint
分析:
一道卡常题……
显然是用杜教筛了。令
f(i)=μ(i)∗im
f
(
i
)
=
μ
(
i
)
∗
i
m
,
g(i)=im
g
(
i
)
=
i
m
,则
(f∗g)(i)=im∗∑j|iμ(j)
(
f
∗
g
)
(
i
)
=
i
m
∗
∑
j
|
i
μ
(
j
)
。
所以只有当
i=1
i
=
1
时,
(f∗g)(i)=1
(
f
∗
g
)
(
i
)
=
1
,否则都为
0
0
。很显然就杜教筛了。
考虑怎样求的前缀和,这个东西显然可以拉格朗日插值,是一个
m+1
m
+
1
次的多项式。
因为
nmd=nm∗d
n
m
d
=
n
m
∗
d
,所以所有需要用到的插值位置都是
n
n
的约数。
对于前面一部分约数,基本都是连续的,可以连续的预处理,时间复杂度是的;而后面的约数比较稀疏,考虑使用拉个朗日插值法,复杂度是
O(km)
O
(
k
m
)
的,
k
k
为查询次数。
考虑预处理的数,这样只有大约
350
350
个约数没有被预处理出来。第一部分的大概需要
(3∗106)∗20=6∗107
(
3
∗
10
6
)
∗
20
=
6
∗
10
7
次循环,而第二部分有
350∗(2∗105)=7∗107
350
∗
(
2
∗
10
5
)
=
7
∗
10
7
次。
然后我就被卡掉了……大概时间是标程的1.5倍吧,不想卡了。
代码:
#include <iostream>
#include <cstdio>
#include <cmath>
#include <map>
#define LL long long
const int maxn=3e6;
const int maxm=2e5+7;
const int mod=998244353;
using namespace std;
int f[maxn+7],a[maxn+7],l[maxm],r[maxm],njc[maxm],jc[maxm];
int s[1007];
int x,y;
int n,m,cnt;
int prime[maxn+7],not_prime[maxn+7];
map <int,int> h;
int power(int x,int y)
{
if (y==1) return x;
int c=power(x,y/2);
c=((LL)c*c)%mod;
if (y&1) c=((LL)c*x)%mod;
return c;
}
void getmul(int n)
{
f[1]=1;
for (int i=2;i<=n;i++)
{
if (!not_prime[i])
{
prime[++cnt]=i;
f[i]=-1;
}
for (int j=1;j<=cnt;j++)
{
if (i*prime[j]>n) break;
not_prime[i*prime[j]]=1;
if (i%prime[j]==0)
{
f[i*prime[j]]=0;
break;
}
else f[i*prime[j]]=-f[i];
}
}
int d;
for (int i=1;i<=n;i++)
{
d=(a[i]+mod-a[i-1])%mod;
if (f[i]<0) f[i]=mod+f[i];
f[i]=(f[i-1]+(LL)f[i]*d)%mod;
}
}
int calc(int n)
{
int ans=0;
l[0]=1; r[m+3]=1;
for (int i=1;i<=m+2;i++) l[i]=((LL)l[i-1]*(n-i))%mod;
for (int i=m+2;i>0;i--) r[i]=((LL)r[i+1]*(n-i))%mod;
for (int i=1;i<=m+2;i++)
{
ans=(ans+(LL)a[i]*l[i-1]%mod*r[i+1]%mod*njc[m+2-i]%mod*jc[i-1]%mod)%mod;
}
return ans;
}
int get(int x)
{
if (x<=maxn) return a[x];
return s[n/x];
}
int getsum(int n)
{
if (n<=maxn) return f[n];
int c=h[n];
if (c) return c;
int sum=0;
int x=0,y=1;
for (int i=2,last;i<=n;i=last+1)
{
last=n/(n/i);
x=get(last);
sum=(sum+(LL)(x+mod-y)%mod*getsum(n/i)%mod)%mod;
y=x;
}
c=(1-sum+mod)%mod;
h[n]=c;
return c;
}
int main()
{
freopen("calc.in","r",stdin);
freopen("calc.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=1;i<=maxn;i++) a[i]=(LL)(a[i-1]+power(i,m))%mod;
jc[0]=njc[0]=1;
for (int i=1;i<=m+2;i++) jc[i]=(LL)jc[i-1]*power(i,mod-2)%mod;
for (int i=1;i<=m+2;i++) njc[i]=(LL)njc[i-1]*power(mod-i,mod-2)%mod;
getmul(maxn);
for (int i=1,last;i<=n;i=last+1)
{
last=n/(n/i);
if (last>=maxn) s[n/last]=calc(last);
}
printf("%d",getsum(n));
}
std:
#include <bits/stdc++.h>
using namespace std;
const int N=3000005,M=200005,mo=998244353;
typedef long long LL;
int n,m,tot,p[216900],f[N],ans,Inv[M],pre[M],suck[M],zil[M];
int sb1[M],sb2[M];
short mu[N];
bool bz[N];
map <int,int> h,sb;
int quick(int x,int y)
{
if (!y) return 1;
int s=quick(x,y>>1); s=(LL)s*s%mo;
if (y&1) s=(LL)s*x%mo;
return s;
}
int sp(int n)
{
if (n<N) return p[n];
if (sb[n]) return sb[n];
int s=0,t=1;
for (int i=1;i<=m+2;i++) pre[i]=n-i;
suck[m+3]=1;
for (int i=m+2;i;i--) suck[i]=(LL)suck[i+1]*pre[i]%mo;
pre[0]=1;
for (int i=1;i<=m+2;i++)
{
pre[i]=(LL)pre[i-1]*pre[i]%mo;
s=(s+(LL)zil[i]*pre[i-1]%mo*suck[i+1]%mo*p[i])%mo;
}
return sb[n]=s;
}
int Van(int Boy)
{
return (Boy<M)?sb1[Boy]:sb2[n/Boy];
}
int calc(int n)
{
if (n<N) return (f[n]+mo)%mo;
if (h[n]) return h[n];
int s=1,i,la,j,now,yeah;
for (la=1,i=2;i<=n;i=j+1)
{
j=n/(n/i);
yeah=calc(n/i); now=Van(j);
s=(s-(LL)(now-la)*yeah)%mo;
la=now;
}
if (s<0) s+=mo;
return h[n]=s;
}
int main()
{
freopen("calc.in","r",stdin);
freopen("calc.out","w",stdout);
scanf("%d%d",&n,&m);
mu[1]=f[1]=Inv[1]=1;
for (int i=2;i<M;i++) Inv[i]=(LL)Inv[mo%i]*(mo-mo/i)%mo;
Inv[0]=1;
for (int i=1;i<M;i++) Inv[i]=(LL)Inv[i-1]*Inv[i]%mo;
for (int i=1,sig=(m&1)?1:-1;i<=m+2;i++,sig=-sig) zil[i]=(LL)sig*Inv[m+2-i]*Inv[i-1]%mo;
for (int i=2;i<N;i++)
{
if (!bz[i])
{
p[tot++]=i; mu[i]=-1; f[i]=quick(i,m);
}
for (int j=0;j<tot && i*p[j]<N;j++)
{
bz[i*p[j]]=1;
f[i*p[j]]=(LL)f[i]*f[p[j]]%mo;
if (i%p[j]==0)
{
mu[i*p[j]]=0; break;
}
mu[i*p[j]]=-mu[i];
}
}p[0]=0;
memset(bz,0,sizeof(bz));
for (int i=1,j;i<N && i<=n;i=j+1)
{
j=n/(n/i);
if (j>=N) break;
bz[j]=1;
}
for (int i=1;i<=m+2;i++) p[i]=(p[i-1]+f[i])%mo;
for (int i=1,j=0;i<N;i++)
{
j=(j+f[i])%mo;
if (bz[i])
{
if (i<M) sb1[i]=j;else sb2[n/i]=j;
}
f[i]=(f[i]*mu[i]+f[i-1])%mo;
}
for (int i=1,j;i<=n;i=j+1)
{
j=n/(n/i);
if (j>=N) sb2[n/j]=sp(j);
}
printf("%d\n",calc(n));
return 0;
}