翻译
思路
最麻烦的是 m ≤ 1 e 9 m \le1e9 m≤1e9,甚至 O ( n m ) O(n \sqrt{m}) O(nm)都不行,线性筛也 p a s s pass pass。
先简化一下题意,由于 g c d gcd gcd具有传递性,即:
g c d ( a , b , c ) = g c d ( g c d ( a , b ) , c ) gcd(a,b,c)=gcd(gcd(a,b),c) gcd(a,b,c)=gcd(gcd(a,b),c)
所以只需要每次找到 b i b_{i} bi的数量,满足:
g c d ( b i , a i − 1 ) = a i gcd(b_{i}, a_{i-1})=a_{i} gcd(bi,ai−1)=ai
显然,所有 a i , i = 2 , 3 , . . . , n a_{i},i=2,3,...,n ai,i=2,3,...,n,满足 a i ∣ a i − 1 a_{i} | a_{i-1} ai∣ai−1,否则直接输出 0 0 0。
在这个条件下,设 c i = b i / a i c_{i}=b_{i}/a_{i} ci=bi/ai,则问题转化为寻找 c i c_{i} ci的个数,满足:
g c d ( c i , a i − 1 / a i ) = 1 gcd(c_{i}, a_{i-1}/a_{i})=1 gcd(ci,ai−1/ai)=1
s . t . c i ≤ m / a i s.t. \ c_{i} \le m / a_{i} s.t. ci≤m/ai
哎,这不就是求互质数的个数嘛,欧拉函数就好了呀
显然没有这么简单,有很多问题要解决:首先欧拉函数是求
1
1
1到
x
x
x中与
x
x
x互质的数的个数,但这里边界
R
R
R和
x
x
x可以不相等;其次,直接处理所有质数是不行的,因为最快的线性筛也要
O
(
m
)
O(m)
O(m)。
再想想欧拉函数,化简后的形式用不了了,但是欧拉函数还有个原始的式子(容斥原理):
这时发现遇到的所有有关的质数都一定会出现在
a
1
a_{1}
a1中,因为
a
i
∣
a
1
a_{i}|a_{1}
ai∣a1对所有
i
i
i均成立,所以只需要
O
(
m
)
O(\sqrt{m})
O(m)遍历寻找
m
m
m的因数
p
j
p_{j}
pj,再用
O
(
p
j
)
O(\sqrt{p_{j}})
O(pj)判断一下是不是质数即可。而这些质数的数量是很有限的(小于
O
(
l
o
g
a
1
)
O(loga_{1})
O(loga1))
需要注意的是,虽然
p
j
p_{j}
pj的最大可能达到
m
m
m,但总复杂度远远达不到
O
(
m
)
O(m)
O(m),具体怎么算我也不知道 。
然后,我们就可以用二进制来表示当前的 a i a_{i} ai拥有哪些质数。然后用 l o w b i t lowbit lowbit来遍历这些质数,进行容斥原理的计算即可。
代码
// gcd(b[i], a[i-1]) = a[i]
// b[i] = a[i] * c, (c, a[i-1] / a[i]) = 1
#include<bits/stdc++.h>
using namespace std;
#define N 200005
#define int long long
int t, n, m, a[N], ans[N], mod = 998244353;
int P[N], cnt;
int p[N], ln;
inline int lowbit(int x) {
return x & (-x);
}
inline int gcd(int x, int y) {
if(y == 0) return x;
return gcd(y, x % y);
}
inline bool is_p(int x) {
if(x == 0 || x == 1) return false;
for(int i=2;i*i<=x;i++)
if(x % i == 0) return false;
return true;
}
inline int f(int R, int x) {
// 计算1 ~ R内与x互质的数的个数
// printf("1~%lld, x: %lld\n", R, x);
ln = 0;
for(int i=1;i<=cnt;i++) {
if(x % P[i] == 0) p[ln++] = P[i];
}
// for(int i=0;i<ln;i++) cout<<p[i]<<' '; cout<<endl;
int ANS = 0;
for(int i=0;i<(1ll<<ln);i++) {
int now = 1, num = 1;
int j = i;
while(j) {
int d = lowbit(j);
int id = log2(d);
now *= p[id];
num *= -1;
j -= d;
}
ANS += num * (R / now);
ANS = (ANS % mod + mod) % mod;
}
// printf("ANS: %lld\n", ANS);
return ANS;
}
signed main() {
cin>>t;
while(t--) {
cin>>n>>m;
for(int i=1;i<=n;i++) scanf("%lld", a+i);
cnt = 0;
for(int i=1;i * i<=a[1];i++) {
if(a[1] % i != 0) continue;
if(is_p(i)) P[++cnt] = i;
if(i * i != a[1] && is_p(a[1]/i)) P[++cnt] = a[1] / i;
}
// for(int i=1;i<=cnt;i++) cout<<P[i]<<' '; cout<<endl;
ans[1] = 1;
for(int i=2;i<=n;i++) {
if(a[i-1] % a[i] != 0) {
ans[i] = 0;
break;
}
if(a[i] == a[i-1]) ans[i] = m / a[i-1];
else ans[i] = f(m / a[i], a[i-1] / a[i]);
}
int ANS = 1;
for(int i=1;i<=n;i++) {
ANS = ANS * ans[i] % mod;
}
cout<<ANS<<endl;
}
return 0;
}