Description
给出两个长度为 n n n的序列 a 0 , . . . , a n − 1 a_0,...,a_{n-1} a0,...,an−1和 b 0 , . . . , b n − 1 b_0,...,b_{n-1} b0,...,bn−1,以此定义矩阵 A i , j = a i ⊙ j A_{i,j}=a_{i\odot j} Ai,j=ai⊙j,其中 ⊙ \odot ⊙为异或运算,求解矩阵同余方程 A x = b ( m o d p ) Ax=b(mod\ p) Ax=b(mod p)
Input
第一行一整数 n n n表示序列长度,之后输入两个长度为 n n n的序列 a i , b i a_i,b_i ai,bi
( 1 ≤ n ≤ 222144 , p = 1 0 9 + 7 , 0 ≤ a i , b i < p ) (1\le n\le 222144,p=10^9+7,0\le a_i,b_i<p) (1≤n≤222144,p=109+7,0≤ai,bi<p)
Output
保证有唯一解,输出 x 0 , . . . , x n − 1 x_0,...,x_{n-1} x0,...,xn−1,需保证 0 ≤ x i < p 0\le x_i<p 0≤xi<p
Sample Input
4
1 10 100 1000
1234 2143 3412 4321
Sample Output
4
3
2
1
Solution1
A i , j = a i ⊙ j A_{i,j}=a_{i\odot j} Ai,j=ai⊙j,原矩阵方程等价于方程组 ∑ j = 1 n a i ⊙ j x j = b i \sum\limits_{j=1}^n a_{i\odot j}x_j=b_i j=1∑nai⊙jxj=bi,也即 ∑ j ⊙ k = i a j x k = b i \sum\limits_{j\odot k=i}a_jx_k=b_i j⊙k=i∑ajxk=bi,这意味着 a ⊗ x = b a\otimes x=b a⊗x=b,其中 ⊙ \odot ⊙表示异或运算, ⊗ \otimes ⊗表示异或卷积,那么有 F W T ( a ) ⋅ F W T ( x ) = F W T ( b ) FWT(a)\cdot FWT(x)=FWT(b) FWT(a)⋅FWT(x)=FWT(b),其中 ⋅ \cdot ⋅表示点乘,故对 a , b a,b a,b序列做 F W T FWT FWT后得到 b i ⋅ a i − 1 b_i\cdot a_i^{-1} bi⋅ai−1序列再 U F W T UFWT UFWT即得到 x x x序列
Code1
#include<cstdio>
using namespace std;
typedef long long ll;
#define maxn 262144+5
#define mod 1000000007
#define inv2 500000004
int mul(int x,int y)
{
ll z=1ll*x*y;
return z-z/mod*mod;
}
int add(int x,int y)
{
x+=y;
if(x>=mod)x-=mod;
return x;
}
int Pow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1)ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
int n,a[maxn],b[maxn],x[maxn];
void FWT(int *a,int n)
{
for(int d=1;d<n;d<<=1)
for(int i=0;i<n;i+=(d<<1))
for(int j=0;j<d;j++)
{
int x=a[i+j],y=a[i+j+d];
a[i+j]=add(x,y),a[i+j+d]=add(x,mod-y);
}
}
void UFWT(int *a,int n)
{
for(int d=1;d<n;d<<=1)
for(int i=0;i<n;i+=(d<<1))
for(int j=0;j<d;j++)
{
int x=a[i+j],y=a[i+j+d];
a[i+j]=mul(inv2,add(x,y)),a[i+j+d]=mul(inv2,add(x,mod-y));
}
}
int main()
{
scanf("%d",&n);
for(int i=0;i<n;i++)scanf("%d",&a[i]);
for(int i=0;i<n;i++)scanf("%d",&b[i]);
FWT(b,n),FWT(a,n);
for(int i=0;i<n;i++)x[i]=mul(b[i],Pow(a[i],mod-2));
UFWT(x,n);
for(int i=0;i<n;i++)printf("%d\n",x[i]);
return 0;
}
Solution2
考虑
i
⊙
j
i\odot j
i⊙j矩阵,由于
n
n
n为
2
2
2的整数次幂,简单分析可知
A
A
A矩阵可以分块用两个
n
2
⋅
n
2
\frac{n}{2}\cdot \frac{n}{2}
2n⋅2n的矩阵表示
A
=
[
A
1
A
2
A
2
A
1
]
A= \left[\begin{matrix} A^1\ A^2\\ A^2\ A^1 \end{matrix}\right]
A=[A1 A2A2 A1]
同时将
x
,
b
x,b
x,b分成前后两个等长的部分,
x
=
[
x
1
,
x
2
]
T
,
b
=
[
b
1
,
b
2
]
T
x=[x^1,x^2]^T,b=[b^1,b^2]^T
x=[x1,x2]T,b=[b1,b2]T,那么原方程可以表示为
A
1
x
1
+
A
2
x
2
=
b
1
,
A
2
x
1
+
A
1
x
2
=
b
2
A^1x^1+A^2x^2=b^1,A^2x^1+A^1x^2=b^2
A1x1+A2x2=b1,A2x1+A1x2=b2
两式相加/做差得
(
A
1
+
A
2
)
(
x
1
+
x
2
)
=
b
1
+
b
2
,
(
A
1
−
A
2
)
(
x
1
−
x
2
)
=
(
b
1
−
b
2
)
(A^1+A^2)(x^1+x^2)=b^1+b^2,(A^1-A^2)(x^1-x^2)=(b^1-b^2)
(A1+A2)(x1+x2)=b1+b2,(A1−A2)(x1−x2)=(b1−b2)
令
a
i
1
=
a
i
+
a
i
+
n
2
,
x
i
1
=
x
i
+
x
i
+
n
2
,
b
i
1
=
b
i
+
b
i
+
n
2
,
a
i
2
=
a
i
−
a
i
+
n
2
,
x
i
2
=
x
i
−
x
i
+
n
2
,
b
i
2
=
b
i
−
b
i
−
n
2
a^1_{i}=a_i+a_{i+\frac{n}{2}},x^1_i=x_i+x_{i+\frac{n}{2}},b^1_i=b_i+b_{i+\frac{n}{2}},a^2_{i}=a_i-a_{i+\frac{n}{2}},x^2_i=x_i-x_{i+\frac{n}{2}},b^2_i=b_i-b_{i-\frac{n}{2}}
ai1=ai+ai+2n,xi1=xi+xi+2n,bi1=bi+bi+2n,ai2=ai−ai+2n,xi2=xi−xi+2n,bi2=bi−bi−2n,
那么简单分析可知
a
1
a^1
a1序列可以用同样的方式生成矩阵
B
1
=
A
1
+
A
2
B^1=A^1+A^2
B1=A1+A2,
a
2
a^2
a2也可以生成
B
2
=
A
1
−
A
2
B^2=A^1-A^2
B2=A1−A2,故求解原来的
n
×
n
n\times n
n×n规模的矩阵方程问题转化为两个
n
2
×
n
2
\frac{n}{2}\times \frac{n}{2}
2n×2n规模的矩阵方程问题
B
1
x
1
=
b
1
,
B
2
x
2
=
b
2
B^1x^1=b^1,B^2x^2=b^2
B1x1=b1,B2x2=b2
递归下去,以一般方程作为递归终点即可,时间复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn),注意到这种解法其实就是模拟了
F
W
T
FWT
FWT的过程
Code2
#include<cstdio>
using namespace std;
typedef long long ll;
#define maxn 262144+5
#define mod 1000000007
#define inv2 500000004
int n,a[19][maxn],b[19][maxn],x[19][maxn],y[19][maxn];
int mul(int x,int y)
{
ll z=1ll*x*y;
return z-z/mod*mod;
}
int add(int x,int y)
{
x+=y;
if(x>=mod)x-=mod;
return x;
}
int Pow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1)ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
void Solve(int n,int m)
{
if(m==0)
{
x[0][0]=mul(b[0][0],Pow(a[0][0],mod-2));
return ;
}
for(int i=0;i<n/2;i++)
{
a[m-1][i]=add(a[m][i],a[m][i+n/2]);
b[m-1][i]=add(b[m][i],b[m][i+n/2]);
}
Solve(n/2,m-1);
for(int i=0;i<n/2;i++)y[m][i]=x[m-1][i];//x[i]+x[i+n/2]
for(int i=0;i<n/2;i++)
{
a[m-1][i]=add(a[m][i],mod-a[m][i+n/2]);
b[m-1][i]=add(b[m][i],mod-b[m][i+n/2]);
}
Solve(n/2,m-1);
for(int i=0;i<n/2;i++)y[m][i+n/2]=x[m-1][i];//x[i]-x[i+n/2]
for(int i=0;i<n/2;i++)
{
x[m][i]=mul(inv2,add(y[m][i],y[m][i+n/2]));
x[m][i+n/2]=mul(inv2,add(y[m][i],mod-y[m][i+n/2]));
}
}
int main()
{
scanf("%d",&n);
int m=0;
while((1<<m)<n)m++;
for(int i=0;i<n;i++)scanf("%d",&a[m][i]);
for(int i=0;i<n;i++)scanf("%d",&b[m][i]);
Solve(n,m);
for(int i=0;i<n;i++)printf("%d\n",x[m][i]);
return 0;
}