// 杜教筛 (Du's sieve)
// 参考了这里 (based on an external reference; the link target is not present in this text)
# include <bits/stdc++.h>
using namespace std;
// 64-bit signed integer used for counts and modular intermediates.
using ll = long long ;
// Modulus applied to the table entries in pre_task() and to the results of sum().
const int mod = 998244353 ;
// Upper bound on the size of the precomputed table: pre_task() uses min(n, N).
const int N = ( int ) 1e6 + 6 ;
// Per-test-case inputs read by input(): n is the argument of the final sum()
// query, m is the number of entries in k.
int n, m;
// The m values read by input(); F() and pre_task() treat them as divisors.
vector< int > k;
// Reads one test case from standard input: first n and m, then m integers
// which replace the previous contents of the global list k.
void input() {
    cin >> n >> m;
    k.assign(m, 0);
    for (int idx = 0; idx < m; ++idx) {
        cin >> k[idx];
    }
}
// Inclusion-exclusion over every non-empty subset of the global list k:
// subsets with an odd number of members add x / lcm(subset), subsets with an
// even number of members subtract it. For x >= 0 this is the number of
// integers in [1, x] divisible by at least one k[j].
//
// The lcm is accumulated in 64 bits, dividing by the gcd before multiplying,
// and a subset is abandoned as soon as its lcm exceeds |x|. Accumulating it in
// a plain int can overflow, while any lcm larger than |x| makes x / lcm equal
// to 0, so skipping such subsets does not change the result.
//
// Reads globals m (number of divisors) and k (the divisors).
// NOTE(review): assumes every k[j] is positive and that m is small enough for
// 1 << m to fit in an int -- confirm against the problem constraints.
ll F(int x) {
    // |x| as a 64-bit value; once lcm > limit the quotient x / lcm is 0.
    const ll limit = x < 0 ? -static_cast<ll>(x) : static_cast<ll>(x);

    ll ans = 0;
    for (int mask = 1; mask < (1 << m); ++mask) {
        ll lcm = 1;
        int picked = 0;
        bool too_large = false;
        for (int j = 0; j < m; ++j) {
            if (!(mask & (1 << j))) {
                continue;
            }
            ++picked;
            const ll kj = k[j];
            // lcm <= limit < 2^32 and kj fits in an int, so the product
            // stays well inside the 64-bit range.
            lcm = lcm / __gcd(lcm, kj) * kj;
            if (lcm > limit) {
                too_large = true;
                break;
            }
        }
        if (too_large) {
            continue;  // contributes x / lcm == 0
        }
        const ll term = x / lcm;
        if (picked & 1) {
            ans += term;
        } else {
            ans -= term;
        }
    }
    return ans;
}
vector< ll> s;
int t;
void pre_task ( ) {
t = min ( n, N) ;
s = vector< ll> ( t + 1 , 0 ) ;
vector< int > is ( t + 1 , 0 ) ;
vector< ll> f ( t + 1 , 0 ) ;
for ( int i = 0 ; i < m; ++ i) {
for ( int j = k[ i] ; j <= t; j += k[ i] ) {
is[ j] = 1 ;
}
}
for ( int i = 2 ; i <= t; ++ i) {
f[ i] <<= 1 ; f[ i] += is[ i] == 1 ; f[ i] %= mod;
int y = sqrt ( i) ; if ( y * y == i) f[ i] += f[ y] * f[ y] ; f[ i] %= mod;
for ( int j = 1 ; j < i && i * j <= t; ++ j) {
f[ i * j] += f[ i] * f[ j] ; f[ i * j] %= mod;
}
s[ i] = s[ i - 1 ] + f[ i] ;
s[ i] %= mod;
}
}
// Memo for sum(): maps an argument x > t to its already computed result.
// Cleared by the caller between test cases.
map<int, int> mp;

// Returns the prefix sum at x, reduced to [0, mod).
//
// For x <= t the value comes straight from the precomputed table s. Larger
// arguments use the recurrence
//     sum(x) = F(x) + SUM over a in [2, x] of f(a) * sum(x / a),
// where the a-range is split into blocks [l, r] on which x / a is constant, so
// each block contributes (sum(r) - sum(l - 1)) * sum(x / l).
//
// The memo is probed with find() rather than by testing mp[x] for truthiness:
// a cached result of 0 is a legitimate answer and must count as a hit,
// otherwise such arguments are recomputed on every call (and every miss would
// also insert a placeholder entry).
ll sum(int x) {
    if (x <= t) {
        return s[x];
    }
    const auto cached = mp.find(x);
    if (cached != mp.end()) {
        return cached->second;
    }

    // Seed with F(x), normalized into [0, mod) so the running value stays
    // reduced regardless of how many blocks follow.
    ll re = (F(x) % mod + mod) % mod;
    for (int l = 2, r; l <= x; l = r + 1) {
        const int q = x / l;
        r = x / q;  // last a with x / a == q
        if (r >= x) {
            break;  // only q == 1 remains, and sum(1) is 0
        }
        re += (sum(r) - sum(l - 1)) * sum(q) % mod;
        re = (re + mod) % mod;  // the difference above may be negative
    }

    mp[x] = static_cast<int>(re);  // re < mod, so it fits in an int
    return re;
}
signed main ( ) {
ios:: sync_with_stdio ( false ) ;
cin. tie ( 0 ) ;
int T;
cin >> T;
while ( T-- ) {
input ( ) ;
pre_task ( ) ;
mp. clear ( ) ;
cout << sum ( n) << '\n' ;
}
return 0 ;
}