题面
样例
2 3
1 0 1
0 1 1
3
分析
考虑正着限制最大的数不超过一半不好做,那我们可以反着来。
令 d p [ i ] [ j ] [ k ] dp[i][j][k] dp[i][j][k] 为第 i i i 行,此列指定数 p o i n t point point 数量为 j j j 个,选的非指定数的个数为 k k k 的方案数。
则一共有三种情况
C a s e 1 : Case\ 1: Case 1: d p [ i ] [ j ] [ k ] = d p [ i − 1 ] [ j ] [ k ] dp[i][j][k] = dp[i - 1][j][k] dp[i][j][k]=dp[i−1][j][k] (此列不选数字)
C a s e 2 : d p [ i ] [ j ] [ k ] = d p [ i ] [ j ] [ k ] + d p [ i − 1 ] [ j ] [ k − 1 ] ∗ ( p r e [ i ] − a [ i ] [ p o i n t ] ) Case\ 2:dp[i][j][k] = dp[i][j][k] + dp[i - 1][j][k - 1] * (pre[i] - a[i][point]) Case 2:dp[i][j][k]=dp[i][j][k]+dp[i−1][j][k−1]∗(pre[i]−a[i][point])(此列选不为 p o i n t point point 的其他数字)
C a s e 3 : d p [ i ] [ j ] [ k ] = ( d p [ i ] [ j ] [ k ] + d p [ i − 1 ] [ j − 1 ] [ k ] ∗ a [ i ] [ p o i n t ] ) Case\ 3:dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j - 1][k] * a[i][point]) Case 3:dp[i][j][k]=(dp[i][j][k]+dp[i−1][j−1][k]∗a[i][point])(此列选 p o i n t point point)
r e s res res 极为 d p [ n ] [ j ] [ k ] ( j > k ) dp[n][j][k](j>k) dp[n][j][k](j>k)
这时我们把总情况算出来:
d p [ i ] = d p [ i − 1 ] × ( p r e [ i ] + 1 ) dp[i]=dp[i-1]\times(pre[i]+1) dp[i]=dp[i−1]×(pre[i]+1)
d p [ n ] − − dp[n] -- dp[n]−−
答案即为 d p [ n ] − r e s dp[n]-res dp[n]−res。
但是这样是 O ( n 3 m ) O(n^3m) O(n3m),只有 84 p t s 84pts 84pts。
考虑优化:会发现,对于 j j j 和 k k k,我们不用考虑他们具体的数值,而考虑他们的相对关系即可。
于是定义 d p [ i ] [ j ] dp[i][j] dp[i][j] 为第 i i i 行, p o i n t point point 数量与其他数的数量差为 j j j 的方案数。
然后,你懂得吧。。。
Code
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define LL long long
using namespace std;
const int Mod = 998244353, MAXN = 105, MAXM = 2e3 + 5;
int a[MAXN][MAXM], n, m, t;
LL dp[MAXN][MAXN << 1], pre[MAXN], ans, res;
// 考虑优化:会发现,对于 j 和 k,我们不用考虑他们具体的数值,而考虑他们的相对关系即可。
int main() {
// freopen("meal.in", "r", stdin);
// freopen("meal.out", "w", stdout);
scanf("%d%d", &n, &m); ans = 1; t = n - (n / 2) - 1; dp[0][n] = 1;
for(int i = 1; i <= n; i ++) for(int j = 1; j <= m; j ++) scanf("%d", &a[i][j]), pre[i] += a[i][j], pre[i] %= Mod;
for(int i = 1; i <= n; i ++) ans *= (pre[i] + 1), ans %= Mod; ans --;
for(int point = 1; point <= m; point ++) {
for(int i = 1; i <= n; i ++) {
for(int j = 0; j <= (n << 1); j ++) dp[i][j] = dp[i - 1][j]; // 不选
for(int j = 0; j < (n << 1); j ++) dp[i][j] = (dp[i][j] + dp[i - 1][j + 1] * (pre[i] - a[i][point])) % Mod; // 不选 point
for(int j = 1; j <= (n << 1); j ++) dp[i][j] = (dp[i][j] + dp[i - 1][j - 1] * a[i][point]) % Mod; // 选 point
}
for(int j = n + 1; j <= (n << 1); j ++) res = (res + dp[n][j]) % Mod;
}
ans -= res; ans = (ans % Mod + Mod) % Mod; printf("%lld", ans);
return 0;
}