题意简述
给出
n
个数
定义
yi,j=xi∗xj(modp)
其中
p=359999
求
gcd(ya,b,yc,d,ye,f)=1
的三元组数目。
数据范围
1≤T≤3
1≤∑n≤106
1≤ai≤106
思路
这个题分为两部分,第一部分是求出 y 每种数的个数,第二个部分是反演求出答案。
先看第一部分。
一个比较显然的思路是利用原根乘法转化为加法,使用FFT求得。
但是
但是观察到
p=599×601
是两个质数的乘积,
我们可以令
f[i][j]
表示
x≡gi(mod599)x≡gj(mod601)
的方案数。
我们要求的就是
f2
,二维FFT。
具体做法是对每一维进行FFT,点值表达平方,对每一维进行IFFT。
需要注意的是,有可能会出现模某一个质数
x≡0
的情况。
我们需要对这种情况单独讨论,做法是对另一维做FFT。
模两个质数都为
0
的,我们直接记录下来。
这样就可以求出每一种数的个数。
第二个部分。
计算
令
f(n)
表示gcd是
n
的方案数,
g(n)=∑n|d,d<pf(n)
反演得
f(n)=∑n|d,d<pf(d)μ(dn)
求
f(1)
即为答案。
复杂度 O(plogp)
UPD17.3.26:二维FFT的另一种做法,把一维的最大值*2,作为第二维的位权,压成一维做FFT,这样可以减不少常数。然而被卡精度了QwQ。应该写NTT的……
UPD17.3.27:预处理单位复根就可以了……
代码
#include<cstdio>
#include<cmath>
#include<cstring>
#include<iostream>
using namespace std;
#define MAXL 2097160
const double pi=acos(-1);
const int mo=1000000007;
const int p=359999;
const int p1=601;
const int p2=599;
const int phi1=600;
const int phi2=598;
const int M=300;
const int g=7;
struct C{
double x,y;
C(double _x=0,double _y=0)
{
x=_x,y=_y;
}
void operator = (const int &n1)
{
x=n1;
y=0;
}
C operator * (const C &n1)
{
return C(x*n1.x-y*n1.y,x*n1.y+y*n1.x);
}
C operator + (const C &n1)
{
return C(x+n1.x,y+n1.y);
}
C operator - (const C &n1)
{
return C(x-n1.x,y-n1.y);
}
double real()
{
return x;
}
void operator /= (const int &n1)
{
x/=n1,y/=n1;
}
void operator *=(const C &n1)
{
double a=x,b=y;
x=a*n1.x-b*n1.y;
y=a*n1.y+b*n1.x;
}
};
int T,n,v,w1,w2;
int gp1[610],gp2[610],regp1[610],regp2[610],num1[610],num2[610],num[610][610],zero,zero1,zero2;
long long sum[360010],gg[360010],ans;
int len,ti;
int r[MAXL];
C a[MAXL],wn[MAXL];
void fft(C *a,int f)
{
for (int i=0;i<len;i++)
if (i<r[i])
swap(a[i],a[r[i]]);
for (int i=1,t=len;i<len;i<<=1,t>>=1)
for (int j=0;j<len;j+=(i<<1))
{
C w=1;
for (int k=0;k<i;k++,f==1 ? w=wn[k*t] : w=C(wn[k*t].x,-wn[k*t].y))
{
C x=a[j+k],y=w*a[j+k+i];
a[j+k]=x+y,a[j+k+i]=x-y;
}
}
}
int cnt;
int prime[360010],mu[360010];
bool not_prime[360010];
void sieve(int n)
{
mu[1]=1;
for (int i=2;i<=n;i++)
{
if (!not_prime[i])
{
prime[++cnt]=i;
mu[i]=-1;
}
for (int j=1;j<=cnt&&i*prime[j]<=n;j++)
{
not_prime[i*prime[j]]=1;
if (i%prime[j]==0)
{
mu[i*prime[j]]=0;
break;
}
mu[i*prime[j]]=-mu[i];
}
}
}
int main()
{
sieve(p);
for (int i=0,u=1;i<p1-1;u=u*g%p1,i++)
{
gp1[u]=i;
regp1[i]=u;
}
for (int i=0,u=1;i<p2-1;u=u*g%p2,i++)
{
gp2[u]=i;
regp2[i]=u;
}
scanf("%d",&T);
while (T--)
{
scanf("%d",&n);
zero=0,zero1=0,zero2=0;
memset(num1,0,sizeof(num1));
memset(num2,0,sizeof(num2));
memset(num,0,sizeof(num));
memset(sum,0,sizeof(sum));
for (int i=1;i<=n;i++)
{
scanf("%d",&v);
if (v%p==0)
++zero;
else
{
if (v%p1!=0)
++num1[gp1[v%p1]];
else
++zero1;
if (v%p2!=0)
++num2[gp2[v%p2]];
else
++zero2;
if (v%p1!=0&&v%p2!=0)
++num[gp1[v%p1]][gp2[v%p2]];
}
}
len=2048,ti=11;
for (int i=0;i<len;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(ti-1));
for (int i=0;i<len;i++)
wn[i]=C(cos(pi*i/len),sin(pi*i/len));
for (int i=0;i<p1;i++)
a[i]=num1[i];
for (int i=p1;i<len;i++)
a[i]=0;
fft(a,1);
for (int i=0;i<len;i++)
a[i]=a[i]*a[i];
fft(a,-1);
for (int i=0;i<len;i++)
a[i]/=len;
for (int i=0;i<len;i++)
sum[regp1[(i%phi1)]*p2*M%p]+=(long long)(a[i].real()+0.5);
for (int i=0;i<p2;i++)
a[i]=num2[i];
for (int i=p2;i<len;i++)
a[i]=0;
fft(a,1);
for (int i=0;i<len;i++)
a[i]=a[i]*a[i];
fft(a,-1);
for (int i=0;i<len;i++)
a[i]/=len;
for (int i=0;i<len;i++)
sum[regp2[(i%phi2)]*p1*M%p]+=(long long)(a[i].real()+0.5);
len=2097152,ti=21;
for (int i=0;i<len;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(ti-1));
for (int i=0;i<len;i++)
wn[i]=C(cos(pi*i/len),sin(pi*i/len));
memset(a,0,sizeof(a));
for (int i=0;i<p1;i++)
for (int j=0;j<p2;j++)
if (num[i][j])
a[i+2*p1*j]=num[i][j];
fft(a,1);
for (int i=0;i<len;i++)
a[i]=a[i]*a[i];
fft(a,-1);
for (int i=0;i<len;i++)
a[i]/=len;
for (int i=0;i<len;i++)
{
w1=i%(2*p1);
w2=i/(2*p1);
sum[(regp1[(w1%phi1)]*p2*M+regp2[(w2%phi2)]*p1*M)%p]+=(long long)(a[i].real()+0.5);
sum[regp1[(w1%phi1)]*p2*M%p]-=(long long)(a[i].real()+0.5);
sum[regp2[(w2%phi2)]*p1*M%p]-=(long long)(a[i].real()+0.5);
}
memset(gg,0,sizeof(gg));
for (int i=1;i<p;i++)
for (int j=i;j<p;j+=i)
gg[i]=(gg[i]+sum[j])%mo;
gg[0]=(1LL*zero*n*2-1LL*zero*zero+1LL*zero1*zero2*2)%mo;
ans=0;
for (int i=1;i<p;i++)
{
ans=(ans+mu[i]*gg[i]*gg[i]%mo*gg[i])%mo;
ans=(ans+mu[i]*gg[i]*gg[i]%mo*gg[0]*3)%mo;
ans=(ans+mu[i]*gg[i]*gg[0]%mo*gg[0]*3)%mo;
}
printf("%lld\n",(ans+mo)%mo);
}
return 0;
}