[AGC034F] RNG and XOR
题解
比较有意思的一道题。
设 f ( i ) f(i) f(i) 表示第一次变成 i i i 的期望次数,首先肯定有 f ( 0 ) = 0 f(0)=0 f(0)=0,然后不妨枚举从 0 开始走的第一步,那么有 f ( i ) = 1 + ∑ j = 0 2 n − 1 P ( j ) f ( i ⨁ j ) f(i)=1+\sum_{j=0}^{2^n-1}P(j)f(i\bigoplus j) f(i)=1+j=0∑2n−1P(j)f(i⨁j)( P ( j ) P(j) P(j) 表示数异或上 j j j 的概率)我们发现这玩意长得像异或卷积,但又不完全相同。既然长得像,那么我们不妨先把它们转换成点值看看:
设
x
i
x_i
xi 表示
f
(
x
)
f(x)
f(x) 做了沃尔什变换后
i
i
i 处的点值,
p
i
p_i
pi 表示
P
(
x
)
P(x)
P(x) 变换后
i
i
i 处的点值,
y
i
y_i
yi 表示函数
g
(
x
)
=
1
g(x)=1
g(x)=1 在
i
i
i 处的点值,那么只看同一下标,可以得到关系式:
x
i
×
p
i
+
y
i
+
z
=
x
i
x_i\times p_i+y_i+z=x_i\\
xi×pi+yi+z=xi其中
z
z
z 是某个未知常数。为什么有这个常数,是因为 0 处的
f
f
f 值不满足最上面的那个式子,所以有偏差,而这个偏差的表现就是点值的上下平移。
观察发现 p i p_i pi 只有 0 处是 1,而其他地方都不为 1,因此可以得到 z = − y 0 z=-y_0 z=−y0。由于其它地方的 p p p 不为 1,所以知道了 z z z 过后可以直接解出非 0 处的 x i x_i xi。
虽然剩下一个 x 0 x_0 x0 未知,但是我们知道 f ( 0 ) = 1 2 n ∑ i = 0 2 n − 1 x i = 0 f(0)=\frac{1}{2^n}\sum_{i=0}^{2^n-1}x_i=0 f(0)=2n1∑i=02n−1xi=0(沃尔什逆变换),所以推出 x 0 = − ∑ i = 1 2 n − 1 x i x_0=-\sum_{i=1}^{2^n-1}x_i x0=−∑i=12n−1xi,于是我们巧妙求得了 f ( x ) f(x) f(x) 每一处的点值。
剩下只需要把点值再转换回来即可。
代码
#include<bits/stdc++.h>//JZM yyds!!
#define ll long long
#define lll __int128
#define uns unsigned
#define fi first
#define se second
#define IF (it->fi)
#define IS (it->se)
#define END putchar('\n')
#define lowbit(x) ((x)&-(x))
#define inline jzmyyds
using namespace std;
const int MAXN=1<<18;
const ll INF=1e18;
ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+(s^48),s=getchar();
return f?x:-x;
}
int ptf[50],lpt;
void print(ll x,char c='\n'){
if(x<0)putchar('-'),x=-x;
ptf[lpt=1]=x%10;
while(x>9)x/=10,ptf[++lpt]=x%10;
while(lpt>0)putchar(ptf[lpt--]^48);
if(c>0)putchar(c);
}
const ll MOD=998244353;
ll ksm(ll a,ll b,ll mo){
ll res=1;
for(;b;b>>=1,a=a*a%mo)if(b&1)res=res*a%mo;
return res;
}
int n;
ll y[MAXN],P[MAXN],x[MAXN],sum,z;
void FWTXOR(ll*a,int inv){
const ll cg=inv>0?1:((MOD+1)>>1);ll x,y;
for(int m=1;m<n;m<<=1)
for(int i=0;i<n;i+=(m<<1))
for(int j=i;j<i+m;j++)
x=a[j],y=a[j+m],a[j]=(x+y)*cg%MOD,a[j+m]=(x-y+MOD)*cg%MOD;
}
int main()
{
n=1<<read();
for(int i=0;i<n;i++)P[i]=read(),sum+=P[i],y[i]=1;
sum=ksm(sum,MOD-2,MOD);
for(int i=0;i<n;i++)(P[i]*=sum)%=MOD;
FWTXOR(P,1),FWTXOR(y,1),z=(MOD-y[0])%MOD;
for(int i=1;i<n;i++)
x[i]=(y[i]+z)*ksm(MOD+1-P[i],MOD-2,MOD)%MOD,x[0]+=MOD-x[i];
x[0]%=MOD,FWTXOR(x,-1);
for(int i=0;i<n;i++)print(x[i]);
return 0;
}