Description
给定一个 n × m n \times m n×m 的矩阵,每一行最多选一个数,每一列可以选若干个数,但是每一列选的数不能超总数的一半。求有多少个不同的方案数。
Solution
容斥 + 计数 dp。 a n s = ans = ans= 全部的方案数 − - − 超过 ⌊ k 2 ⌋ \lfloor \frac{k}{2 }\rfloor ⌊2k⌋ 的方案数。
令
s
u
m
i
sum_i
sumi 为
s
u
m
i
=
∑
j
=
1
m
a
i
,
j
sum_i = \sum_{j=1}^m a_{i,j}
sumi=j=1∑mai,j
全部的方案数即为
∏
i
=
1
n
(
s
u
m
i
+
1
)
−
1
\prod_{i=1}^n (sum_i + 1) - 1
i=1∏n(sumi+1)−1
用计数 dp 求超过
⌊
k
2
⌋
\lfloor \frac{k}{2 }\rfloor
⌊2k⌋ 的方案数。枚举一个列,令
f
i
,
j
,
k
f_{i,j,k}
fi,j,k 为在前
i
i
i 行中,枚举的列选了
j
j
j 个,其他列选了
k
k
k 个的方案数。转移为
f
i
,
j
,
k
=
f
i
−
1
,
j
,
k
+
f
i
−
1
,
j
−
1
,
k
×
a
i
,
j
+
f
i
−
1
,
j
,
k
−
1
×
(
s
u
m
i
−
a
i
,
j
)
f_{i,j,k} = f_{i - 1, j, k} + f_{i - 1, j - 1, k} \times a_{i,j} + f_{i - 1, j, k - 1} \times (sum_i - a_{i,j})
fi,j,k=fi−1,j,k+fi−1,j−1,k×ai,j+fi−1,j,k−1×(sumi−ai,j)
超过
⌊
k
2
⌋
\lfloor \frac{k}{2 }\rfloor
⌊2k⌋ 的方案数为
∑
j
>
k
f
n
,
j
,
k
\sum_{j>k} f_{n,j,k}
j>k∑fn,j,k
如何优化?
j
,
k
j,k
j,k 两维换成
j
−
k
j - k
j−k。为什么?每一层转移中
j
−
k
j - k
j−k 不相同 ,统计答案时只计算了
j
−
k
>
0
j - k > 0
j−k>0 的部分。
枚举一个列,令
f
i
,
k
f_{i,k}
fi,k 为在前
i
i
i 行中,枚举的列选了比其他列多选了
k
k
k 个的方案数。转移为
f
i
,
k
=
f
i
−
1
,
k
+
f
i
−
1
,
k
−
1
×
a
i
,
k
+
f
i
−
1
,
k
+
1
×
(
s
u
m
i
−
a
i
,
k
)
f_{i,k} = f_{i-1,k} + f_{i-1,k-1} \times a_{i,k} + f_{i-1,k+1} \times (sum_i - a_{i,k})
fi,k=fi−1,k+fi−1,k−1×ai,k+fi−1,k+1×(sumi−ai,k)
综上所述
a
n
s
=
∏
i
=
1
n
(
s
u
m
i
+
1
)
−
1
−
∑
i
=
1
n
f
n
,
i
ans = \prod_{i=1}^n (sum_i + 1) - 1 - \sum_{i=1}^n f_{n,i}
ans=i=1∏n(sumi+1)−1−i=1∑nfn,i
细节问题如初值,对负值增加一个偏移量,详见代码。
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 100 + 5, M = 2000 + 5, p = 998244353;
ll all = 1, ans, f[N][N + N], a[N][M], sum[N];
int n, m;
int main(){
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++)
scanf("%lld", &a[i][j]), sum[i] = (sum[i] + a[i][j]) % p;
all = all * (sum[i] + 1) % p;
}
for (int j = 1; j <= m; j++) {
f[0][N] = 1;
for (int i = 1; i <= n; i++)
for (int k = -n + N; k <= n + N; k++)
f[i][k] = (f[i - 1][k] + f[i - 1][k - 1] * a[i][j] % p + f[i - 1][k + 1] * (sum[i] - a[i][j]) % p) % p;
for (int i = 1; i <= n; i++) ans = (ans + f[n][i + N]) % p;
}
printf("%lld\n", (all - ans - 1 + p) % p);
return 0;
}