题目描述
Emiya 是个擅长做菜的高中生,他共掌握 n� 种烹饪方法,且会使用 m� 种主要食材做菜。为了方便叙述,我们对烹饪方法从 1∼n1∼� 编号,对主要食材从 1∼m1∼� 编号。
Emiya 做的每道菜都将使用恰好一种烹饪方法与恰好一种主要食材。更具体地,Emiya 会做 ai,j��,� 道不同的使用烹饪方法 i� 和主要食材 j� 的菜(1≤i≤n1≤�≤�、1≤j≤m1≤�≤�),这也意味着 Emiya 总共会做 ∑i=1n∑j=1mai,j∑�=1�∑�=1���,� 道不同的菜。
Emiya 今天要准备一桌饭招待 Yazid 和 Rin 这对好朋友,然而三个人对菜的搭配有不同的要求,更具体地,对于一种包含 k� 道菜的搭配方案而言:
- Emiya 不会让大家饿肚子,所以将做至少一道菜,即 k≥1�≥1
- Rin 希望品尝不同烹饪方法做出的菜,因此她要求每道菜的烹饪方法互不相同
- Yazid 不希望品尝太多同一食材做出的菜,因此他要求每种主要食材至多在一半的菜(即 ⌊k2⌋⌊�2⌋ 道菜)中被使用
这里的 ⌊x⌋⌊�⌋ 为下取整函数,表示不超过 x� 的最大整数。
这些要求难不倒 Emiya,但他想知道共有多少种不同的符合要求的搭配方案。两种方案不同,当且仅当存在至少一道菜在一种方案中出现,而不在另一种方案中出现。
Emiya 找到了你,请你帮他计算,你只需要告诉他符合所有要求的搭配方案数对质数 998,244,353998,244,353 取模的结果。
输入
第 1 行两个用单个空格隔开的整数 n,m�,�。
第 2 行至第 n+1�+1 行,每行 m� 个用单个空格隔开的整数,其中第 i+1�+1 行的 m� 个数依次为 ai,1,ai,2,⋯,ai,m��,1,��,2,⋯,��,�。
输出
仅一行一个整数,表示所求方案数对 998,244,353998,244,353 取模的结果。
样例输入 复制
2 3 1 0 1 0 1 1
样例输出 复制
3
提示
样例输入2
3 3 1 2 3 4 5 0 6 0 0
样例输出2
190
【样例 1 解释】
由于在这个样例中,对于每组 i,j�,�,Emiya 都最多只会做一道菜,因此我们直接通过给出烹饪方法、主要食材的编号来描述一道菜。
符合要求的方案包括:
- 做一道用烹饪方法 1、主要食材 1 的菜和一道用烹饪方法 2、主要食材 2 的菜
- 做一道用烹饪方法 1、主要食材 1 的菜和一道用烹饪方法 2、主要食材 3 的菜
- 做一道用烹饪方法 1、主要食材 3 的菜和一道用烹饪方法 2、主要食材 2 的菜
因此输出结果为 3mod998,244,353=33mod998,244,353=3。 需要注意的是,所有只包含一道菜的方案都是不符合要求的,因为唯一的主要食材在超过一半的菜中出现,这不满足 Yazid 的要求。
【样例 2 解释】
Emiya 必须至少做 2 道菜。
做 2 道菜的符合要求的方案数为 100。
做 3 道菜的符合要求的方案数为 90。
因此符合要求的方案数为 100 + 90 = 190。
【数据范围】
测试点编号 | n=�= | m=�= | ai,j<��,�< | 测试点编号 | n=�= | m=�= | ai,j<��,�< |
---|---|---|---|---|---|---|---|
11 | 22 | 22 | 22 | 77 | 1010 | 22 | 103103 |
22 | 22 | 33 | 22 | 88 | 1010 | 33 | 103103 |
33 | 55 | 22 | 22 | 9∼129∼12 | 4040 | 22 | 103103 |
44 | 55 | 33 | 22 | 13∼1613∼16 | 4040 | 33 | 103103 |
55 | 1010 | 22 | 22 | 17∼2117∼21 | 4040 | 500500 | 103103 |
66 | 1010 | 33 | 22 | 22∼2522∼25 | 100100 | 2×1032×103 | 998244353998244353 |
对于所有测试点,保证 1≤n≤1001≤�≤100,1≤m≤20001≤�≤2000,0≤ai,j<998,244,3530≤��,�<998,244,353。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int _ = 100 + 10;
const int __ = 2000 + 10;
int n, m, A[_][__];
ll sum[_], f[_][_ << 1], tmp, ans = 1;
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
scanf("%d", &A[i][j]);
sum[i] = (sum[i] + A[i][j]) % mod;
}
ans = ans * (sum[i] + 1) % mod;
}
ans = (ans - 1 + mod) % mod;
for (int k = 1; k <= m; ++k) {
memset(f, 0, sizeof(f));
f[0][n] = 1;
for (int i = 1; i <= n; ++i) {
for (int j = -i + n; j <= i + n; ++j) {
f[i][j] = (f[i - 1][j] + f[i - 1][j - 1] * A[i][k] % mod + f[i - 1][j + 1] * (sum[i] - A[i][k]) % mod) % mod;
if (i == n && j > n) tmp = (tmp + f[i][j]) % mod;
}
}
}
ans = (ans - tmp + mod) % mod;
printf("%lld\n", ans);
return 0;
}