E. Modular Stability
time limit per test
2 seconds
memory limit per test
512 megabytes
input
standard input
output
standard output
We define xmodyxmody as the remainder of division of xx by yy (%% operator in C++ or Java, mod operator in Pascal).
Let's call an array of positive integers [a1,a2,…,ak][a1,a2,…,ak] stable if for every permutation pp of integers from 11 to kk, and for every non-negative integer xx, the following condition is met:
(((xmoda1)moda2)…modak−1)modak=(((xmodap1)modap2)…modapk−1)modapk(((xmoda1)moda2)…modak−1)modak=(((xmodap1)modap2)…modapk−1)modapk
That is, for each non-negative integer xx, the value of (((xmoda1)moda2)…modak−1)modak(((xmoda1)moda2)…modak−1)modak does not change if we reorder the elements of the array aa.
For two given integers nn and kk, calculate the number of stable arrays [a1,a2,…,ak][a1,a2,…,ak] such that 1≤a1<a2<⋯<ak≤n1≤a1<a2<⋯<ak≤n.
Input
The only line contains two integers nn and kk (1≤n,k≤5⋅1051≤n,k≤5⋅105).
Output
Print one integer — the number of stable arrays [a1,a2,…,ak][a1,a2,…,ak] such that 1≤a1<a2<⋯<ak≤n1≤a1<a2<⋯<ak≤n. Since the answer may be large, print it modulo 998244353998244353.
题意:给你n,k(1<=k<=n<=5e5),从1到n中选k个数组成一个严格递增序列,如果对任何正整数,依次模上这k个数,无论这k个数如何排列得到的答案都相同,那么称这个序列为好序列,求好序列的个数%998244353
思路:
直接打表就可以发现,一个好序列中所有的数都能整除第一个数(也就是最小的那个数),所以我们利用组合数,枚举最小的数i,然后从剩下的(n/i)-1个能整除i的数中随意挑k-1个即可。
即使不打表,也可以推理出来,感兴趣的可以试一试(不过直接打表还是香呢)。
代码:
#include<bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define rep(i,a,b) for(register int i=(a);i<=(b);i++)
#define dep(i,a,b) for(register int i=(a);i>=(b);i--)
using namespace std;
const int maxn=1000010;
const long long mod=998244353;
int n,m,q,k,flag,x,f,y,p;
long long ni[maxn];
long long a[maxn];
long long zuhe(int x,int y){ //组合数
return a[x]*ni[y]%mod*ni[x-y]%mod;
}
long long calc(long long x,long long y){
long long z=1;
while (y){
if (y&1)(z*=x)%=mod;
(x*=x)%=mod,y/=2;
}
return z;
}
int main()
{
ios::sync_with_stdio(false);
a[0]=1;
for (int i=1;i<maxn;i++)a[i]=a[i-1]*i%mod;
ni[maxn-1]=calc(a[maxn-1],mod-2);
for (int i=maxn-2;i>=0;i--)
ni[i]=ni[i+1]*(i+1)%mod;
cin>>n>>k;
ll ans=0;
rep(i,1,n){
ll t=n/i;
if(t<k) break;
ans=(ans+zuhe(t-1,k-1))%mod;
}
printf("%lld\n",ans);
return 0;
}