逆天题
题目概述
题解
首先,我们考虑这东西该怎么用生成函数进行表示。
单纯子集和的生成函数显然就是
∏
i
=
1
m
n
(
1
+
x
i
)
\prod_{i=1}^{mn}(1+x^i)
∏i=1mn(1+xi),如果要求出将子集和取模
n
n
n后得到的
f
i
f_i
fi的话。
有一个经典的技巧就是给整个生成函数取模上一个
x
n
−
1
x^n-1
xn−1的话,就可以将所有的
x
k
n
+
b
x^{kn+b}
xkn+b的系数叠加到
x
b
x^b
xb上面。
这样,我们就能够求出
f
f
f的生成函数了,也就是
F
=
∏
i
=
0
n
−
1
(
1
+
x
i
)
m
(
m
o
d
x
n
−
1
)
F=\prod_{i=0}^{n-1}(1+x^{i})^m \pmod{x^n-1}
F=∏i=0n−1(1+xi)m(modxn−1)。
那么我们答案的
∑
f
2
\sum f^2
∑f2又该怎么求了。
可以观察到这相当于一个子集和减去另一个子集和模
n
n
n为
0
0
0的方案数。
显然,我们子集和的函数是高度对称的,也就是说,减去一个子集需要乘的生成函数就是加上它的生成函数。
所以我们的答案为
[
x
0
]
∏
i
=
0
n
−
1
(
1
+
x
i
)
2
m
(
m
o
d
x
n
−
1
)
[x^0]\prod_{i=0}^{n-1}(1+x^i)^{2m} \pmod{x^n-1}
[x0]∏i=0n−1(1+xi)2m(modxn−1)。
好的,我们又该怎么计算这东西呢?
显然,这个
2
m
2m
2m次方可以做了
D
F
T
DFT
DFT后再次方计算。
由于我们取模的是
x
n
−
1
x^n-1
xn−1,显然
D
F
T
DFT
DFT也就是通过单位根计算的。
可以发现,
1
+
x
i
1+x^i
1+xi再
D
F
T
DFT
DFT后对
x
k
x^k
xk的贡献为
1
+
w
n
i
k
1+w^{ik}_n
1+wnik。
那我们上面的式子
D
F
T
DFT
DFT后,可以得到
f
i
^
=
∏
k
=
0
n
−
1
(
1
+
w
n
i
k
)
\widehat{f_i}=\prod_{k=0}^{n-1}(1+w^{ik}_n)
fi
=∏k=0n−1(1+wnik)。
怎么算出这东西的真实值呢?考虑
x
n
−
1
=
∏
(
x
−
w
n
i
)
x^n-1=\prod(x-w_{n}^i)
xn−1=∏(x−wni)。
带入
x
=
−
1
x=-1
x=−1,可以得到
f
i
^
=
2
[
2
∣
n
(
n
,
i
)
]
(
n
,
i
)
\widehat{f_i}=2[2|\frac{n}{(n,i)}]^{(n,i)}
fi
=2[2∣(n,i)n](n,i)。
再对这东西乘上
2
m
2m
2m次方,也就是我们答案的
F
^
\widehat{F}
F
。
答案再
I
D
F
T
IDFT
IDFT回去可以得到
A
n
s
=
1
n
∑
i
=
0
n
−
1
f
i
^
Ans=\frac{1}{n}\sum_{i=0}^{n-1} \widehat{f_i}
Ans=n1∑i=0n−1fi
显然,这东西只与
(
n
,
i
)
(n,i)
(n,i)的大小有关,也就是说,我们只需要对于
n
n
n的所有因数
d
d
d,算出它的
f
d
^
\widehat{f_d}
fd
以及
ϕ
(
n
d
)
\phi(\frac{n}{d})
ϕ(dn)表示它的出现次数,就刻意知道它的答案了。
n
n
n的因数肯定不会太多,但是要怎么求出这些因数呢?
一种方法是将
n
n
n质因数分解后暴力计算。
由于这里的
n
n
n比较大,我们需要利用
Pollard_rho
\text{Pollard\_rho}
Pollard_rho算法快速分解。
时间复杂度 O ( n 1 4 + d ( n ) log m ) O\left(n^{\frac{1}{4}}+d(n)\log m\right) O(n41+d(n)logm),其中 d ( n ) d(n) d(n)表示 n n n的因子个数。
源码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef __int128 Li;
typedef pair<int,int> pii;
#define MAXN 1000005
#define pb push_back
#define mkpr make_pair
#define fir first
#define sec second
#define lson (rt<<1)
#define rson (rt<<1|1)
const int INF=0x3f3f3f3f;
const int mo=998244353;
template<typename _T>
void read(_T &x){
_T f=1;x=0;char s=getchar();
while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+(s^48);s=getchar();}
x*=f;
}
template<typename _T>
_T Fabs(_T x){return x<0?-x:x;}
LL gcd(LL a,LL b){return !b?a:gcd(b,a%b);}
int add(int x,int y,int p){return x+y<p?x+y:x+y-p;}
void Add(int &x,int y,int p){x=add(x,y,p);}
int qkpow(int a,int s,int p){int t=1;while(s){if(s&1)t=1ll*a*t%p;a=1ll*a*a%p;s>>=1;}return t;}
LL qkpowll(LL a,LL s,LL p){LL t=1;while(s){if(s&1)t=(Li)a*t%p;a=(Li)a*a%p;s>>=1;}return t;}
mt19937 e(time(0));
bool MillerRabin(const LL &x){
if(x==2||x==3||x==5||x==7||x==11||x==13||x==17)return 1;
if(!(x%2)||!(x%3)||!(x%5)||!(x%7)||!(x%11)||!(x%13)||!(x%17))return 0;
LL k=0,r=x-1;uniform_int_distribution<LL> gx(2,x-2);
while(!(r&1))r>>=1,k++;
for(int i=0;i<30;i++){
bool flag=0;LL a=qkpowll(gx(e),r,x);
if(a==1||a==x-1)continue;
for(int j=1;j<=k;j++){
a=(Li)a*a%x;
if(a==x-1){flag=1;break;}
if(a==1)return 0;
}
if(!flag)return 0;
}
return 1;
}
LL PollardRho(const LL &n){
uniform_int_distribution<LL> gx(1,n-1);
while(1){
int times=0;LL d=1,x,y,C;x=y=gx(e);C=gx(e);
for(int stp=1;;stp<<=1,x=y){
bool cir=0;
for(int i=1;i<stp;i++){
y=((Li)y*y+C)%n;d=(Li)d*Fabs(x-y)%n;times++;
if(x==y||d==0){cir=1;break;}
if(times==127){d=gcd(d,n);if(d>1)return d;times=0;}
}
if(cir)break;d=gcd(d,n);if(d>1)return d;times=0;
}
}
return -1;
}
LL sta[55],g[55][65];int stak,ct[55],ans;
void work(LL n){
if(n==1)return ;if(MillerRabin(n)){sta[++stak]=n;return ;}
LL d=PollardRho(n);while(n%d==0)n/=d;work(d);work(n);
}
int T;LL n,m;
void dfs(int id,LL sum,LL phi){
if(id==stak+1){
int t=2*((n/sum)%2);
t=qkpow(qkpow(t,sum%(mo-1),mo),(m+m)%(mo-1),mo);
Add(ans,1ll*phi%mo*t%mo,mo);
return ;
}
for(int i=0;i<=ct[id];i++){
dfs(id+1,sum,phi*g[id][ct[id]-i]);
if(i<ct[id])sum*=sta[id];
}
}
int main(){
//freopen("ntt.in","r",stdin);
//freopen("ntt.out","w",stdout);
read(T);
while(T--){
read(n);read(m);work(n);LL now=n;
for(int i=1;i<=stak;i++){
g[i][0]=1;g[i][1]=sta[i]-1;ct[i]=0;
while(now%sta[i]==0)now/=sta[i],ct[i]++;
for(int j=2;j<=ct[i];j++)
g[i][j]=g[i][j-1]*sta[i];
}
dfs(1,1,1);ans=1ll*qkpow(n%mo,mo-2,mo)*ans%mo;
printf("%d\n",ans);ans=stak=0;
}
return 0;
}