算法标签
LGV引理、NTT
思路
因为要求不相交的方案数,所以容易想到LGV引理
矩阵形式 M ( i , j ) = E ( a i , b j ) M(i,j)=E(a_i,b_j) M(i,j)=E(ai,bj),代表的是 a i a_i ai到 b j b_j bj 可走的路径的权值之积的和,计数问题权值位1,故 M ( i , j ) M(i,j) M(i,j)代表从 a i a_i ai到 b j b_j bj的方案数
从 ( 0 , a i ) (0,a_i) (0,ai)到 ( i , 0 ) (i,0) (i,0)横坐标要走 i i i个步,纵坐标要走 a i a_i ai步
所以移动到 A ( i , j ) A(i,j) A(i,j)方案数为 ( a i + j j ) \dbinom{a_i+j}{j} (jai+j)
故问题最终要求解如下行列式:
A
=
[
(
a
1
+
1
1
)
(
a
1
+
2
2
)
⋯
(
a
1
+
n
n
)
⋮
⋱
⋮
(
a
n
+
1
1
)
(
a
n
+
2
2
)
⋯
(
a
n
+
n
n
)
]
A=\begin{bmatrix} \dbinom{a_1+1}{1} &\dbinom{a_1+2}{2} & \cdots & \dbinom{a_1+n}{n} \\ \vdots & \ddots & \vdots \\ \dbinom{a_n+1}{1}& \dbinom{a_n+2}{2} & \cdots & \dbinom{a_n+n}{n} \end{bmatrix}
A=⎣⎢⎢⎢⎢⎢⎡(1a1+1)⋮(1an+1)(2a1+2)⋱(2an+2)⋯⋮⋯(na1+n)(nan+n)⎦⎥⎥⎥⎥⎥⎤
利用线性代数知识,对上矩阵化简,因为
(
a
i
+
j
j
)
=
(
a
i
+
j
)
!
j
!
a
i
!
\dbinom{a_i+j}{j}=\frac{(a_i+j)!}{j! a_i!}
(jai+j)=j!ai!(ai+j)!
故,每一列提取
1
j
!
\frac{1}{j!}
j!1后可以让每一行都变长
a
i
a_i
ai的
j
j
j次多项式,即
A
′
(
i
,
j
)
=
(
a
i
+
j
)
!
a
i
!
=
∏
k
=
1
j
(
a
i
+
k
)
A'(i,j)=\frac{(a_i+j)!}{a_i!}=\prod_{k=1}^j(a_i+k)
A′(i,j)=ai!(ai+j)!=k=1∏j(ai+k)
对于每一列来说
j
j
j式一个常数,每一列都变成一样的了,利用行列式初等变化,把每一行都变成
(
a
i
+
1
)
j
(a_i+1)^j
(ai+1)j ,例如
[
(
a
1
+
1
)
(
a
1
+
1
)
(
a
1
+
2
)
(
a
2
+
1
)
(
a
2
+
1
)
(
a
2
+
2
)
]
=
c
2
−
c
1
[
(
a
1
+
1
)
(
a
1
+
1
)
2
(
a
2
+
1
)
(
a
2
+
1
)
2
]
\begin{bmatrix} (a_1+1) &(a_1+1)(a_1+2)\\ (a_2+1) &(a_2+1)(a_2+2) \end{bmatrix}\overset{c2-c1}= \begin{bmatrix} (a_1+1) &(a_1+1)^2\\ (a_2+1) &(a_2+1)^2 \end{bmatrix}
[(a1+1)(a2+1)(a1+1)(a1+2)(a2+1)(a2+2)]=c2−c1[(a1+1)(a2+1)(a1+1)2(a2+1)2]
化简后可得
A
i
,
j
′
′
=
(
a
i
+
1
)
j
A''_{i,j}=(a_i+1)^j
Ai,j′′=(ai+1)j
我们会得到,这是一个范德蒙德行列式,所以
A
′
′
=
∏
(
a
i
+
1
)
⋅
∏
1
≤
i
<
j
≤
n
(
(
a
j
+
1
)
−
(
a
i
+
1
)
)
A''=\prod(a_i+1)\cdot\prod_{1\leq i<j\leq n}((a_j+1)-(a_i+1))
A′′=∏(ai+1)⋅1≤i<j≤n∏((aj+1)−(ai+1))
故
A
=
∏
j
=
1
n
1
j
!
⋅
∏
(
a
i
+
1
)
⋅
∏
1
≤
i
<
j
≤
n
(
(
a
j
+
1
)
−
(
a
i
+
1
)
)
A=\prod_{j=1}^n\frac{1}{j!}\cdot\prod(a_i+1)\cdot\prod_{1\leq i<j\leq n}((a_j+1)-(a_i+1))
A=j=1∏nj!1⋅∏(ai+1)⋅1≤i<j≤n∏((aj+1)−(ai+1))
最后使用卷积加速运算
代码
#include <bits/stdc++.h>
#define closeSync ios::sync_with_stdio(false);cin.tie(0);cout.tie(0)
using namespace std;
typedef long long ll;
const int MAXN = 1e7 + 10; // n + m << 1
const double PI = acos(-1);
const int D = 1000000;
const ll MOD = 998244353;
const int INF = 0x3f3f3f3f;
struct Complex { // 复数
double x, y;
Complex(double xx = 0, double yy = 0) {
x = xx, y = yy;
}
Complex operator+(const Complex &xx)const {
return Complex(x + xx.x, y + xx.y);
}
Complex operator-(const Complex &xx)const {
return Complex(x - xx.x, y - xx.y);
}
Complex operator*(const Complex &xx)const {
return Complex(x * xx.x - y * xx.y, x * xx.y + y * xx.x);
}
} A[MAXN], B[MAXN];
int n = 1000000, m = 1000000; // 多项式 A,B 的最高幂次
struct FFT {
int r[MAXN];
inline void fft(Complex *A, int lim, int type) {
// type = 1 : 傅利叶变换 ; type = -1 : 傅利叶逆变换
for (int i = 0; i < lim; i++)
if (i < r[i])
swap(A[i], A[r[i]]); // 初始状态
for (int i = 1; i < lim; i <<= 1) {
// 要被归并的区间长度
Complex wn(cos(PI / i), type * sin(PI / i));
for (int j = 0; j < lim; j += i << 1) {
// 一个区间一个区间的合并
Complex w(1, 0);
for (int k = 0; k <= i - 1; k++, w = w * wn) {
// w = w(2i,k)
Complex x = A[j + k], y = w * A[j + i + k];
A[j + k] = x + y;
A[j + i + k] = x - y;
}
}
}
return ;
}
inline vector<ll> poly_mul(Complex *A, Complex *B) {
int lim = 1, l = 0;
while (lim <= n + m)
lim <<= 1, l++;
for (int i = 0; i < lim; i++)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1)); //二进制反转取得i位置的最终位置r[i]
fft(A, lim, 1);
fft(B, lim, 1);
for (int i = 0; i < lim; i++)
A[i] = A[i] * B[i]; // A[i] 满足 A,B[i] 满足 B,A[i] * B[i] 满足 A * B
fft(A, lim, -1); // 根据点值求多项式
vector<ll> ans;
for (int i = 0; i <= n + m; i++)
ans.push_back(ll(A[i].x / lim + 0.5));
return ans;
}
} fff;
inline ll qpower(ll x,ll p,ll mod)
{
ll res = 1;
while (p)
{
if (p & 1) res *= x,res %= mod;
x *= x,x %= mod;
p >>= 1;
}
return res;
}
int a[MAXN];
ll fac[MAXN],invfac[MAXN];
inline void init()
{
int n = D + 3;
fac[0] = 1;
for (int i=1;i<=n;i++)
{
fac[i] = fac[i-1] * i;
fac[i] %= MOD;
}
invfac[n] = qpower(fac[n],MOD-2,MOD);
// cout << invfac[n] << "\n";
for (int i=n-1;i>=0;i--)
{
invfac[i] = ( invfac[i+1] * (i + 1) % MOD ),invfac[i] %= MOD;
// cout << invfac[i] << "\n";
}
// cout << invfac[0] << "\n";
}
inline void solve()
{
int N; cin >> N;
int mx = -1,mn = INF;
for (int i=1;i<=N;i++)
{
cin >> a[i];
mx = max(a[i],mx);
mn = min(a[i],mn);
A[a[i]].x = 1;
B[-a[i] + D].x = 1;
}
vector<ll> ans = fff.poly_mul(A,B);
// cout << "112\n";
// cout << ans[-2+D] << "\n";
// n + m
ll res = 1;
for (int i=1;i<=N;i++)
{
res *= a[i] + 1,res %= MOD;
// cout << res << "\n";
}
// cout << res << "\n";
for (int i=1;i<=N;i++)
{
// cout << invfac[i] << "\n";
res *= invfac[i],res %= MOD;
// cout << res << "\n";
}
// cout << res << "\n";
for (int i=1+D;i<=mx-mn+D;i++)
res *= qpower(i - D,ans[i],MOD),res %= MOD;
cout << res << "\n";
}
int main()
{//closeSync;
init();
// int T; cin >> T;
// while (T--)
solve();
return 0;
}
最终时间复杂度 O ( n l o g n ) O(nlogn) O(nlogn)