题意:
你有
n
n
n个灯,每个灯有四种可能的颜色,一开始都是第一种颜色,有
m
m
m种操作,每种操作是一个
x
i
x_i
xi,表示把
x
i
x_i
xi的倍数全都变成下一个颜色,第四次变化后会变回第一次的颜色,问你从这
m
m
m种操作中随机选出一个集合的操作去进行,期望有多少个颜色为初始的第一种颜色的灯。
n
<
=
1
e
9
,
m
<
=
20
n<=1e9,m<=20
n<=1e9,m<=20,答案对998244353取模。
题解:
首先感觉这个题可能会
2
m
2^m
2m枚举集合,于是这样我们可以转化成求总方案数,最后再除以一个
2
m
2^m
2m就可以了。我们可能会跟着直觉往dp或者容斥方向去想,但是这个题确实不好想。先%%%y_immortal大佬,他向我推荐的这个题,并且写了全网第一篇题解,我也是跟他学的。
首先我们应该可以想到,一个灯最后还是初始颜色的条件是它被变化颜色的次数是4的倍数。我们发现一个数字它被变化多次的条件是它是两个操作的lcm的倍数,那么我们先设 f [ S ] f[S] f[S]表示 [ 1 , n ] [1,n] [1,n]中有多少个数是集合 S S S的公倍数,但是我们会发现,这样考虑是会有重复的,就是小集合里的数的公倍数可能会在它的超集里再次被计算。所以我们想定义一个 g [ S ] g[S] g[S],表示在 [ 1 , n ] [1,n] [1,n]中有多少个数是 S S S中元素的lcm的倍数,并且不是任何一个超集的lcm的倍数的个数。 g [ S ] g[S] g[S]需要用容斥去算。
但是就算是算出来了刚才那些,还是会复杂度爆炸啊,于是这个题正解的做法是考虑把元素个数相同的 f f f和 g g g一起算出来。对于 f f f比较好算,再原来的基础上算的时候记录一下当前集合有多少个元素,然后加到对应元素数的地方就行了。重点是求这个 g g g。首先,我们先把 g g g的初始值设为 f f f的初始值,然后考虑容斥,容斥的思路还是减去超集中的答案。我们用一个 m 2 m^2 m2的复杂度来计算每一个 g [ i ] g[i] g[i]的答案,计算的方法是对于当前的 i i i,枚举所有比它大的集合,考虑从当前的答案中减去是当前集合lcm的倍数同时大集合lcm的倍数的数,过程中要乘一个组合数。求出 g g g之后就可以计算答案了,答案就是 ∑ i = 0 m g [ i ] ∗ ∑ j = 0 & j % 4 = 0 m C i j ∗ 2 m − i \sum_{i=0}^mg[i]*\sum_{j=0\&j\%4=0}^mC_{i}^j*2^{m-i} ∑i=0mg[i]∗∑j=0&j%4=0mCij∗2m−i。原因是你在一个 i i i个操作里选任意4的倍数个操作都会是合法的。
感觉还是很神仙的一道题。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,m,a[30];
long long f[30],g[30],c[51][51],ans;
const long long mod=998244353;
inline long long ksm(long long x,long long y)
{
long long res=1;
while(y)
{
if(y&1)
res=res*x%mod;
x=x*x%mod;
y>>=1;
}
return res;
}
inline long long gcd(long long x,long long y)
{
return y?gcd(y,x%y):x;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=m;++i)
scanf("%d",&a[i]);
for(int i=0;i<=50;++i)
c[i][0]=1;
for(int i=1;i<=50;++i)
{
for(int j=1;j<=i;++j)
c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
}
int mx=(1<<m);
for(int i=0;i<mx;++i)
{
int ji=0;
long long lcm=1;
for(int j=1;j<=m;++j)
{
if(i&(1<<(j-1)))
{
lcm=lcm*a[j]/gcd(lcm,a[j]);
if(lcm>n)
break;
++ji;
}
}
if(lcm>n)
continue;
f[ji]=(f[ji]+n/lcm)%mod;
}
for(int i=0;i<=m;++i)
g[i]=f[i];
for(int i=m;i>=0;--i)
{
for(int j=i+1;j<=m;++j)
g[i]=(g[i]-g[j]*c[j][i]%mod+mod)%mod;
}
for(int i=0;i<=m;++i)
ans=(ans+g[i]*((c[i][0]+c[i][4]+c[i][8]+c[i][12]+c[i][16]+c[i][20])%mod)%mod*ksm(2,m-i)%mod)%mod;
ans=ans*ksm(ksm(2,m),mod-2)%mod;
printf("%lld\n",ans);
return 0;
}