题目链接
题意分析
\[\sum_{i=1}^n\sum_{j=1}^mμ^2(gcd(i,j))\]
\[=\sum_{i=1}^n\sum_{j=1}^m\sum_{d=1}^{min(n,m)}μ^2(d)[gcd(i,j)==d]\]
\[=\sum_{d=1}^{min(n,m)}μ^2(d)\sum_{i=1}^{\lfloor\frac{n}{d}\rfloor}\sum_{j=1}^{\lfloor\frac{m}{d}\rfloor}[gcd(i,j)==1]\]
\[=\sum_{d=1}^{min(n,m)}μ^2(d)\sum_{i=1}^{\lfloor\frac{n}{d}\rfloor}\sum_{j=1}^{\lfloor\frac{m}{d}\rfloor}\sum_{k|gcd(i,j)}μ(k)\]
\[=\sum_{d=1}^{min(n,m)}μ^2(d)\sum_{k=1}^{min(\lfloor\frac{n}{d}\rfloor,\lfloor\frac{m}{d}\rfloor)}μ(k)\lfloor\frac{n}{kd}\rfloor\lfloor\frac{m}{kd}\rfloor\]
令\(T=kd\)，改为枚举\(T\)：
\[=\sum_{T=1}^{min(n,m)}\lfloor\frac{n}{T}\rfloor\lfloor\frac{m}{T}\rfloor\sum_{d|T}μ^2(d)μ(\frac{T}{d})\]
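记\(f(T)=\sum_{d|T}μ^2(d)μ(\frac{T}{d})\)。\(f\)是积性函数\(μ^2\)与\(μ\)的狄利克雷卷积，因此也是积性函数；对质数幂有\(f(p)=-1+1=0\)，\(f(p^2)=μ^2(p)μ(p)=-1\)，\(a\ge3\)时\(f(p^a)=0\)。所以只有当\(T\)是完全平方数时\(f(T)\)才可能非零，且此时
\[\sum_{d|T}μ^2(d)μ(\frac{T}{d})=μ(\sqrt{T})\]
其余情况下各项\(μ^2(d)μ(\frac{T}{d})\)会互相抵消，\(f(T)=0\)。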
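于是只需用线性筛预处理\(μ\)的前缀和\(S(x)=\sum_{k=1}^{x}μ(k)\)（筛到\(\sqrt{\min(n,m)}\)即可），再对\(\lfloor\frac{n}{T}\rfloor\lfloor\frac{m}{T}\rfloor\)做整除分块：区间\([l,r]\)内所有完全平方数\(k^2\)的贡献之和为\(S(\lfloor\sqrt{r}\rfloor)-S(\lfloor\sqrt{l-1}\rfloor)\)。所以时间与空间复杂度均优化到\(O(\sqrt{n})\)。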
CODE:
#include<cctype>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<algorithm>
#include<deque>
#include<iostream>
#include<list>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<string>
#include<vector>
#define ll long long
#define inf 0x7fffffff      // not referenced in the code visible here
#define N 500008            // not referenced in the code visible here
#define IL inline
#define M 5008611           // size of the global tables; must exceed maxn
#define maxn 5000002        // upper bound of the sieve in work()
#define mod 998244353       // modulus for the answer and prefix sums
// `register` is no longer a valid storage-class specifier since C++17 (clang
// rejects it, gcc warns), so R expands to nothing. Existing `R ll i` loop
// variables keep compiling unchanged.
#define R
using namespace std;
/**
 * Reads one integer from stdin into `_`.
 * Every character before the first digit is skipped; a '-' seen while
 * skipping makes the result negative. Parsing stops at the first non-digit
 * or at EOF. If EOF is reached before any digit, `_` receives 0 instead of
 * the function looping forever.
 */
template<typename T>inline void read(T &_)
{
    T value=0;
    bool negative=false;
    int ch=getchar();  // int, not char: must be able to represent EOF
    // isdigit() is only defined for EOF or unsigned-char values, which is
    // exactly what getchar() returns, so ch can be passed as-is.
    while(ch!=EOF&&!isdigit(ch))
    {
        if(ch=='-') negative=true;
        ch=getchar();
    }
    while(ch!=EOF&&isdigit(ch))
    {
        value=value*10+(ch-'0');
        ch=getchar();
    }
    _=negative?-value:value;
}
/*-------------OI使我快乐-------------*/
// Input sizes and the running answer (main keeps ans reduced mod `mod`).
ll n,m,ans;
// Number of primes found so far by the sieve in work().
ll tot;
// prime[1..tot]: primes up to maxn; mul[i]: Mobius function mu(i) in {-1,0,1}.
// Both fit comfortably in int; storing them as long long doubled the memory
// of these two M-sized tables for no benefit.
int prime[M],mul[M];
// sum[i] = (mu(1)+...+mu(i)) mod `mod`, each term mapped into [0, mod).
ll sum[M];
// mark[i] becomes 1 once i is found to be composite by the sieve.
bool mark[M];
/**
 * Linear (Euler) sieve over [2, maxn].
 * Fills prime[1..tot] with the primes found, mul[i] with the Mobius function
 * mu(i), and sum[i] with (mu(1)+...+mu(i)) mod `mod`, each mu(i) first mapped
 * into [0, mod) so the prefix sums stay non-negative.
 * Loop variables no longer use `R` (register), which C++17 disallows.
 */
IL void work()
{
    mul[1]=1;
    for(ll i=2;i<=maxn;++i)
    {
        if(!mark[i]) {prime[++tot]=i;mul[i]=-1;}
        for(ll j=1;j<=tot&&prime[j]*i<=maxn;++j)
        {
            mark[prime[j]*i]=1;
            if(i%prime[j]==0)
            {
                // prime[j]^2 divides prime[j]*i, so mu is 0. Stopping here means
                // each composite is marked only via its smallest prime factor.
                mul[prime[j]*i]=0;break;
            }
            else mul[prime[j]*i]=-mul[i];  // one extra distinct prime flips the sign
        }
    }
    for(ll i=1;i<=maxn;++i) sum[i]=(sum[i-1]+(mul[i]+mod)%mod)%mod;
}
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
work();
read(n);read(m);
for(R ll l=1,r=0;l<=min(n,m);l=r+1)
{
r=min(n/(n/l),m/(m/l));
ans=((ans+((n/l)%mod)*((m/l)%mod)%mod*((sum[(ll)sqrt(r)]-sum[(ll)sqrt(l-1)]+mod)%mod)%mod)%mod+mod)%mod;
}
printf("%lld\n",ans);
// fclose(stdin);
// fclose(stdout);
return 0;
}