题目传送门:https://www.luogu.org/problem/show?pid=3768
题目分析:我们来看一下,原先题目要我们求:
∑i=1n∑j=1nijgcd(i,j)
我们变形一下,将j只枚举到i。则原式转化成:
2∑i=1ni∑j=1ijgcd(i,j)−∑i=1ni3
由于我们将i=j时候的答案算了两次,所以最后要减去i*i*gcd(i,i),即i的立方。
接下来我们重点关注左边那个部分。我们把枚举j改为枚举i的约数d,则得到:
2∑i=1ni∑d|id2∑j=1idj[gcd(id,j=1)]−∑i=1ni3
我们又发现,其实可以不用枚举j。我们其实就是在求和 id 互质的数的和,我们将那一段改成 id∗ϕ(id)+[id=1]2 ,然后把它和前面的东西化简。于是我们机智地发现 [id=1] 其实就是要求i=d,这刚好和右边减号后面的部分抵消。于是最后我们得到了一条很简洁的式子:
∑i=1ni2∑d|idϕ(id)
我们换一下枚举的顺序,先枚举因数d,然后再枚举 j=id ,得到:
∑d=1nd3∑j=1⌊nd⌋j2ϕ(j)
我们记右边为 f(⌊nd⌋) ,那么对于不同的d, ⌊nd⌋ 顶多只有 2n√ 个取值。于是我们对这一部分跑杜教筛,左边做部分和(即下底函数分块)。用线性筛预处理出前 n23 个,时间复杂度为 O(n23∗log2n) 。
(PS:不知杜教筛为何物的同学可以去看一下 这篇博客,写得非常棒,简单易懂)
以下附杜教筛的推导过程:
我们现在要求
f=Id2⋅ϕ
(Id指单位函数,即Id(i)=i,
⋅
为乘法)的前缀和。我们可以发现
f∗Id2=Id3
(其中
∗
指狄利克雷卷积),于是令j等于
∑i=1ni3=∑i=1n∑d|id2ϕ(d)(id)2=∑j=1nj2∑d=1⌊nj⌋d2ϕ(d)
故:
f(n)=∑i=1ni3−∑j=1nj2f(⌊nj⌋)
另外,我们知道:
13+23+……n3=(1+2+……+n)2=(n(n+1)2)2
这可以O(1)地算出来。
由于n达到 1010 ,两个数相乘之前不取模会炸long long,但取了模之后做除法又会出错。我们可以写扩展gcd,或者像我一样写一个很水的分类讨论……
于是本题完美解决。
CODE:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;
const int maxn=4600001;
const int M=300509;
typedef long long LL;
struct data
{
LL num;
LL val;
} Hash[M];
LL f[maxn];
bool vis[maxn];
int prime[maxn];
int cur=0;
LL p,n,ans=0;
void Make()
{
f[1]=1;
for (int i=2; i<maxn; i++)
{
if (!vis[i]) f[i]=i-1,prime[++cur]=i;
for (int j=1; j<=cur && i*prime[j]<maxn; j++)
{
int k=i*prime[j];
vis[k]=true;
if (i%prime[j]) f[k]=f[i]*f[ prime[j] ];
else
{
f[k]=f[i]*prime[j];
break;
}
}
}
for (int i=2; i<maxn; i++) f[i]=f[i]*(long long)i%p*(long long)i%p;
for (int i=2; i<maxn; i++)
{
f[i]+=f[i-1];
if (f[i]>p) f[i]-=p;
}
}
LL Get3(LL x)
{
LL y;
if (x%2LL) y=((x+1LL)/2LL%p)*(x%p)%p;
else y=(x/2LL%p)*((x+1LL)%p)%p;
y=(y*y)%p;
return y;
}
LL Find(LL x)
{
int y=x%M;
while ( Hash[y].num && Hash[y].num!=x ) y=(y+1)%M;
if (!Hash[y].num) return -1;
return Hash[y].val;
}
LL Get2(LL x)
{
LL y=x+1LL,z=2LL*x+1LL;
bool f2=false,f3=false;
if ( !f2 && !(x%2LL) ) x/=2LL,f2=true;
if ( !f3 && !(x%3LL) ) x/=3LL,f3=true;
if ( !f2 && !(y%2LL) ) y/=2LL,f2=true;
if ( !f3 && !(y%3LL) ) y/=3LL,f3=true;
if ( !f2 && !(z%2LL) ) z/=2LL,f2=true;
if ( !f3 && !(z%3LL) ) z/=3LL,f3=true;
x%=p;
y%=p;
z%=p;
return x*y%p*z%p;
}
void Push(LL x,LL v)
{
int y=x%M;
while (Hash[y].num) y=(y+1)%M;
Hash[y].num=x;
Hash[y].val=v;
}
LL Dfs(LL x)
{
if (x<maxn) return f[x];
LL temp=Find(x);
if (temp!=-1) return temp;
else temp=Get3(x);
LL last;
for (LL i=2; i<=x; i=last+1)
{
last=x/(x/i);
LL Dec=Dfs(x/i);
Dec=( Get2(last)-Get2(i-1)+p )%p*Dec%p;
temp=(temp-Dec+p)%p;
}
Push(x,temp);
return temp;
}
int main()
{
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
cin>>p>>n;
Make();
LL last;
for (LL i=1; i<=n; i=last+1)
{
last=n/(n/i);
ans=(ans+ ( Get3(last)-Get3(i-1)+p )%p*Dfs(n/i)%p )%p;
}
cout<<ans<<endl;
return 0;
}