题意
题目链接
设 $n=p_1^{a_1}\dots p_k^{a_k}$，则定义 $f(n)=a_1+\dots+a_k$。给出 $n$，求 $\sum_{i=1}^n f(i!)$。
$n\le10^{10}$
分析
由 Legendre 公式，$i!$ 中素数 $p$ 的次数为 $\sum_{k\ge1}\lfloor\frac{i}{p^k}\rfloor$，所以显然要求的是

$$\sum_{p\in \mathbb{P}}\sum_{k\ge 1}\sum_{i=1}^n\left\lfloor\frac{i}{p^k}\right\rfloor=\sum_{p\in \mathbb{P}}\sum_{k\ge 1}\left[(n+1)\left\lfloor\frac{n}{p^k}\right\rfloor-p^k\,\mathrm{sum}\left(\left\lfloor\frac{n}{p^k}\right\rfloor\right)\right]$$

（记 $m=p^k$、$q=\lfloor n/m\rfloor$，则 $\sum_{i=1}^n\lfloor i/m\rfloor=\sum_{j=1}^{q}(n-jm+1)=(n+1)q-m\,\mathrm{sum}(q)$。）
其中 $\mathrm{sum}(x)=\frac{x(x+1)}{2}$。
对于不超过 $\sqrt n$ 的素数，$k\ge2$ 的项可以暴力计算贡献（$p>\sqrt n$ 时 $p^2>n$，这些项均为 $0$）。剩下的 $k=1$ 部分就是

$$\sum_{p\in \mathbb{P}}\left[(n+1)\left\lfloor\frac{n}{p}\right\rfloor-p\,\mathrm{sum}\left(\left\lfloor\frac{n}{p}\right\rfloor\right)\right]$$
用 min25 筛求出每个 $\lfloor n/i\rfloor$ 处的素数个数与素数和，再按 $\lfloor n/p\rfloor$ 的取值分块求和即可。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
typedef long long LL;
// Capacity of the per-value arrays. floor(n/i) takes at most 2*sqrt(n) = 2*B
// distinct values when n <= B*B, so N > 2*B is enough.
const int N=200005;
const int MOD=998244353;
// Sieve bound: primes up to B are generated, and values <= B are indexed via id1.
const int B=100000;
// n: the input.
// w[m]: the m-th distinct value of floor(n/i); it decreases as m grows.
// h[m]: starts at w[m]-1 (up to 10^10-1, which does not fit in a 32-bit int) and
//       becomes the number of primes <= w[m] after the min25 sieve in main.
LL n,w[N],h[N];
// tot, prime[1..tot]: the primes up to B.
// s[j]: (prime[1]+...+prime[j]) mod MOD.
// m: number of registered values in w.
// g[m]: sum of the primes <= w[m] mod MOD after the sieve. It is reduced with %,
//       so it may be negative.
// id1[x]: index of value x when x <= B; id2[n/x]: index of value x when x > B.
int tot,prime[N],s[N],m,g[N],id1[N],id2[N];
// not_prime[x]: set by get_prime for composite x.
bool not_prime[N];
// Linear sieve over [2, n].
// Appends every prime to prime[1..tot] and records the running prime sum mod
// MOD in s[tot]. Each composite is marked in not_prime exactly once, by its
// smallest prime factor.
void get_prime(int n)
{
    for (int x=2;x<=n;++x)
    {
        if (!not_prime[x])
        {
            ++tot;
            prime[tot]=x;
            s[tot]=(s[tot-1]+x)%MOD;
        }
        for (int j=1;j<=tot;++j)
        {
            int p=prime[j];
            if (x*p>n) break;
            not_prime[x*p]=1;
            // p is the smallest prime factor of x. Larger primes are left for
            // later x so that each composite is struck only once.
            if (x%p==0) break;
        }
    }
}
// Returns n*(n+1)/2 mod MOD.
// Division by 2 is done as multiplication by (MOD+1)/2, the modular inverse
// of 2 for the odd modulus MOD.
int sum(LL n)
{
    const LL inv2=(MOD+1)/2;
    LL r=n%MOD;
    return (int)(inv2*r%MOD*(r+1)%MOD);
}
// Contribution of a single prime power (passed as p) to the answer.
// Computes sum_{i=1}^{n} floor(i/p) = (n+1)*floor(n/p) - p*sum(floor(n/p)),
// reduced mod MOD. The result lies in (-MOD, MOD); the caller normalises
// the sign.
int get(LL p)
{
    LL q=n/p;
    LL linear=(n+1)%MOD*(q%MOD);
    LL quad=p*(LL)sum(q);
    return (int)((linear-quad)%MOD);
}
// Number of primes <= x, read from h.
// x must be one of the floor(n/k) values registered in id1/id2 by main, and
// the result is only meaningful once main's sieve loop has filled h.
int find1(LL x)
{
    if (x>B) return h[id2[n/x]];
    return h[id1[x]];
}
// Sum of primes <= x (mod MOD, possibly negative), read from g.
// x must be one of the floor(n/k) values registered in id1/id2 by main, and
// the result is only meaningful once main's sieve loop has filled g.
int find2(LL x)
{
    if (x>B) return g[id2[n/x]];
    return g[id1[x]];
}
int main()
{
scanf("%lld",&n);
get_prime(B);
for (LL i=1,last;i<=n;i=last+1)
{
last=n/(n/i);w[++m]=n/i;
g[m]=sum(n/i)-1;h[m]=n/i-1;
if (n/i<=B) id1[n/i]=m;
else id2[last]=m;
}
for (int j=1;j<=tot;j++)
for (int i=1;i<=m&&(LL)prime[j]*prime[j]<=w[i];i++)
{
int k=w[i]/prime[j]<=B?id1[w[i]/prime[j]]:id2[n/(w[i]/prime[j])];
(g[i]-=(LL)prime[j]*(g[k]-s[j-1])%MOD)%=MOD;
h[i]-=h[k]-j+1;
}
int ans=0;
for (int i=1;i<=tot;i++)
for (LL j=(LL)prime[i]*prime[i];j<=n;j=(LL)j*prime[i])
(ans+=get(j))%=MOD;
for (LL i=2,last;i<=n;i=last+1)
{
last=n/(n/i);
(ans+=(LL)(n+1)%MOD*(n/i%MOD)%MOD*(find1(last)-find1(i-1))%MOD)%=MOD;
(ans-=(LL)sum(n/i)*(find2(last)-find2(i-1))%MOD)%=MOD;
}
printf("%d\n",(ans+MOD)%MOD);
return 0;
}