题目大意:
令 $f_k(n)$ 表示满足以下条件的序列的数量:长度为 $k$,每个元素都在 $[1,n]$ 内,并且所有元素的 $\gcd$ 为 $1$。求:
$$\sum_{i=1}^n f_k(i),\qquad n\le 10^9,\ k\le 10^5$$
(下面的代码将答案对 $998244353$ 取模后输出。)
题解:
$$
\begin{aligned}
Ans&=\sum_{i=1}^n f_k(i)=\sum_{i=1}^n\sum_{j=1}^i\left\lfloor\frac ij\right\rfloor^k\mu(j)=\sum_{j=1}^n\mu(j)\sum_{i=1}^n\left\lfloor\frac ij\right\rfloor^k\\
&=\sum_{j=1}^n\mu(j)\left[\left(\sum_{i=1}^{\left\lfloor\frac nj\right\rfloor-1}i^k j\right)+\left\lfloor\frac nj\right\rfloor^k\left(n-\left\lfloor\frac nj\right\rfloor j+1\right)\right]\\
&=\sum_{j=1}^n\mu(j)\left(jS_k\left(\left\lfloor\frac nj\right\rfloor-1\right)+\left\lfloor\frac nj\right\rfloor^k(n+1)-\left\lfloor\frac nj\right\rfloor^{k+1}j\right)\\
&=\sum_{j=1}^n\left[\mu(j)\,j\,S_k\left(\left\lfloor\frac nj\right\rfloor-1\right)+\mu(j)\left\lfloor\frac nj\right\rfloor^k(n+1)-\mu(j)\,j\left\lfloor\frac nj\right\rfloor^{k+1}\right]
\end{aligned}
$$
(第二行把 $i$ 按 $\left\lfloor\frac ij\right\rfloor$ 的取值分组:每个取值 $v\in\left[1,\left\lfloor\frac nj\right\rfloor-1\right]$ 恰好对应 $j$ 个 $i$,而最大的取值 $\left\lfloor\frac nj\right\rfloor$ 对应 $n-\left\lfloor\frac nj\right\rfloor j+1$ 个 $i$。)
其中:
$$S_k(n)=\sum_{i=1}^n i^k$$
然后对 $\mu(j)$ 和 $\mu(j)j$ 的前缀和用杜教筛求解,对 $S_k(n)$ 做插值即可。
取 $\mathrm{block\_size}=\sqrt{nk}$,可以做到 $O\left(n^{\frac23}+\sqrt{nk}\right)$。
$\mathrm{std}$ 本意是要写一个多点插值,把后半部分做到 $O\left(\sqrt{n}\lg k\right)$,但是数据出小了,暴力就过去了……
#include<bits/stdc++.h>
// Shorthand for reading one character from stdin.
#define gc getchar()
// Inclusive integer loop: i runs from a up to and including b.
#define rep(i,a,b) for(int i=a;i<=b;i++)
// Index loop over every position of container v.
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
// Modulus applied to every modular computation in this file.
#define mod 998244353
#define db double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
// Size of the sieve tables; arguments below N are answered from the tables.
#define N 10000010
// Size of the interpolation tables (factorials, sample values, prefix/suffix products).
#define K 100020
// Debug-print helpers; not used by the visible code.
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
typedef unordered_map<int,int> mpii;
typedef unordered_map<int,lint> mpil;
// savg: memo for g(n) when n >= N.  sg[i]: exact (not reduced) prefix sum of mu(j)*j.
// fac / facinv: factorials and inverse factorials modulo mod, filled by prelude2.
mpil savg;lint sg[N];int fac[K],facinv[K];
// savh: memo for h(n) when n >= N.  sh[i]: prefix sum of mu(j).
// k: the exponent read in main.  mu[i]: Moebius function.
// ik[i]: i^k modulo mod.  sk[i]: prefix sum of ik modulo mod.
mpii savh;int sh[N],k,mu[N],ik[N],sk[N];
// np[i]: set for composite i.  p[1..]: primes found by the sieve (1-based).
// pre / suf: prefix and suffix products used by S.  y[i]: prefix sum of j^k modulo mod (interpolation samples).
bool np[N];int p[N],pre[K],suf[K],y[K];
// Reads the next non-negative decimal integer from stdin.
// Any characters before the first digit are skipped. The character that
// terminates the number is consumed as well.
// Returns 0 when end of input is reached before any digit is found; the
// previous version looped forever in that case, because EOF compares
// below '0' and was treated as just another character to skip.
inline int inn()
{
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
        ch = getchar();
    if (ch == EOF)
        return 0;

    int x = ch - '0';
    for (ch = getchar(); ch >= '0' && ch <= '9'; ch = getchar())
        x = x * 10 + (ch - '0');
    return x;
}
// Binary exponentiation: returns ans * x^k modulo mod (ans defaults to 1).
// The exponent is consumed bit by bit from the least significant end.
inline int fast_pow(int x,int k,int ans=1)
{
    while (k)
    {
        if (k & 1)
            ans = (lint)ans * x % mod;
        x = (lint)x * x % mod;
        k >>= 1;
    }
    return ans;
}
// Linear sieve over [1, n] that fills, for every index i:
//   mu[i]  - Moebius function,
//   ik[i]  - i^k modulo mod (powered directly at primes, multiplied for composites),
//   sg[i]  - exact prefix sum of mu(j)*j,
//   sh[i]  - prefix sum of mu(j),
//   sk[i]  - prefix sum of j^k modulo mod.
// Always returns 0.
inline int prelude(int n)
{
    mu[1] = 1;
    sg[1] = mu[1] * 1;
    sh[1] = mu[1];
    ik[1] = 1;
    sk[1] = 1;

    int primeCount = 0;
    for (int v = 2; v <= n; ++v)
    {
        if (!np[v])
        {
            // v is prime: record it and seed its multiplicative values.
            p[++primeCount] = v;
            mu[v] = -1;
            ik[v] = fast_pow(v, k);
        }

        sg[v] = sg[v - 1] + mu[v] * v;
        sh[v] = sh[v - 1] + mu[v];
        sk[v] = sk[v - 1] + ik[v];
        if (sk[v] >= mod)
            sk[v] -= mod;

        // Mark v*q for every prime q up to the smallest prime factor of v,
        // so that each composite is produced exactly once.
        const int limit = n / v;
        for (int idx = 1; idx <= primeCount; ++idx)
        {
            const int q = p[idx];
            if (q > limit)
                break;
            const int prod = q * v;
            np[prod] = 1;
            ik[prod] = (lint)ik[v] * ik[q] % mod;
            if (v % q == 0)
            {
                mu[prod] = 0;   // q^2 divides prod
                break;
            }
            mu[prod] = -mu[v];
        }
    }
    return 0;
}
inline int prelude2(int n)
{
rep(i,1,n) y[i]=y[i-1]+fast_pow(i,k),(y[i]>=mod?y[i]-=mod:0);
rep(i,fac[0]=1,n) fac[i]=(lint)fac[i-1]*i%mod;
facinv[n]=fast_pow(fac[n],mod-2);
for(int i=n-1;i>=0;i--) facinv[i]=(i+1ll)*facinv[i+1]%mod;
return 0;
}
// Prefix sum of mu(i)*i over i = 1..n, as an exact (non-modular) value.
// Arguments below N come from the sieve table sg; larger ones are computed
// with the recurrence g(n) = 1 - sum_{s=2..n} s * g(n/s), grouping s by the
// value of n/s, and memoised in savg.
inline lint g(int n)
{
    if (n < N)
        return sg[n];

    auto cached = savg.find(n);
    if (cached != savg.end())
        return cached->second;

    lint rest = 0;
    int s = 2;
    while (s <= n)
    {
        const int q = n / s;
        const int t = n / q;          // last s with the same quotient q
        rest += (s + t) * (t - s + 1ll) / 2 * g(q);
        s = t + 1;
    }

    const lint value = 1 - rest;
    savg[n] = value;
    return value;
}
// Prefix sum of mu(i) over i = 1..n.
// Arguments below N come from the sieve table sh; larger ones are computed
// with the recurrence h(n) = 1 - sum_{s=2..n} h(n/s), grouping s by the
// value of n/s, and memoised in savh.
inline int h(int n)
{
    if (n < N)
        return sh[n];

    auto cached = savh.find(n);
    if (cached != savh.end())
        return cached->second;

    int tail = 0;
    for (int lo = 2, hi; lo <= n; lo = hi + 1)
    {
        const int q = n / lo;
        hi = n / q;                   // last index sharing quotient q
        tail += h(q) * (hi - lo + 1);
    }
    return savh[n] = 1 - tail;
}
inline int g(int l,int r) { return ((g(r)-g(l-1))%mod+mod)%mod; }
// Sum of mu(i) over i = l..r, shifted by mod and reduced modulo mod.
inline int h(int l,int r)
{
    const int upper = h(r);
    const int lower = h(l - 1);
    return (upper - lower + mod) % mod;
}
inline int S(int x)
{
if(x<N) return sk[x];int n=k+3;static lint xs[3],ans;
pre[0]=suf[n+1]=1,xs[0]=1,xs[1]=-1,ans=0;
for(int i=1;i<=n;i++) pre[i]=(lint)pre[i-1]*(x-i)%mod;
for(int i=n;i>=1;i--) suf[i]=(lint)suf[i+1]*(x-i)%mod;
for(int i=1;i<=n;i++)
ans+=xs[(n-i)&1]*y[i]*pre[i-1]%mod*suf[i+1]%mod*facinv[i-1]%mod*facinv[n-i]%mod;
return int((ans%mod+mod)%mod);
}
// i^k modulo mod: table lookup below N, binary exponentiation otherwise.
inline int IK(int i)
{
    return i < N ? ik[i] : fast_pow(i, k);
}
// Brute-force sum of i^kk for i = 1..n modulo mod (kk defaults to the global k).
// Not called by the visible code.
inline int qs(int n,int kk=k)
{
    int total = 0;
    for (int i = 1; i <= n; ++i)
        total = (total + fast_pow(i, kk)) % mod;
    return total;
}
int main()
{
int n=inn();k=inn();lint ans=0;prelude(N-1),prelude2(K-1);
for(int l=1,r,t;l<=n;l=r+1) r=n/(n/l),t=n/l,
ans+=(g(l,r)*(S(t-1)-(lint)IK(t)*t%mod)+h(l,r)*(n+1ll)%mod*IK(t)%mod)%mod;
return !printf("%lld\n",(ans%mod+mod)%mod);
}