题意
这个题的意思就是说,每一行只能选一个(你可以选你也可以不选,但是你只能选一个),你选出来的种类数 n n n,其中 n / 2 n/2 n/2不能是同一列的东西,请问你有多少种选法。
思路
这道题,其实容斥的思想非常简单,主要还是dp来计数比较复杂。
容斥的思想其实就是正难则反,我记录全部的合法数量,然后减去我们的不合法数量就可以。
我们的合法数量很好计算,我们计算每一行的前缀和
s
u
m
[
i
]
sum[i]
sum[i]那么一共有
s
u
m
[
i
]
+
1
sum[i]+1
sum[i]+1种选法,(不选的情况),然后根据乘法原理,求可以求出全部的情况了,值得注意的是,全部的情况要减去一种,代表的是什么也不选的一个情况。
那么我们现在来计算一下不满足题意的情况,首先通过分析,我们可以轻松的得出,这个一定是一个子状态的问题,而且没有后效性,也就是后面的选法是根据前面的选法推出来的,而且不管你前面选的是什么后面一定不会因为你前面选的东西而改变这个情况。
不满足题意的情况那就是我在同一列选的东西超过了我全部的选的
k
k
k,那么也就是代表的是,我们现在当前这一列选了
k
/
2
+
1
k/2+1
k/2+1个以后,在其他列里面无论怎么选,都是符合要求的,也就是我们只用枚举列数就可以。
我们可以开三个维度的dp数组,第一个维度代表当前枚举到的行,第二个维度代表的是当前列,我选了多少个,第三个维度代表的是其他列我选了多少个,dp数组的含义就是我当前遍历到的这一行且当前列选了j个,其他列选了k个的一共的种类数。
那么状态转移方程就是
d
p
[
i
]
[
j
]
[
k
]
=
d
p
[
i
−
1
]
[
j
]
[
k
]
dp[i][j][k]=dp[i-1][j][k]
dp[i][j][k]=dp[i−1][j][k]
这个状态就是我什么也不选
如果
j
>
0
j>0
j>0
d
p
[
i
]
[
j
]
[
k
]
+
=
d
p
[
i
−
1
]
[
j
−
1
]
[
k
]
∗
a
[
i
]
[
u
]
dp[i][j][k]+=dp[i-1][j-1][k]*a[i][u]
dp[i][j][k]+=dp[i−1][j−1][k]∗a[i][u]
代表的是我当前列要选
j
j
j个数,但是上一个状态只选了
j
−
1
j-1
j−1个因此我这一次要选
a
[
i
]
[
j
]
a[i][j]
a[i][j]这个状态的(当前列是
u
u
u)
如果
k
>
0
k>0
k>0
d
p
[
i
]
[
j
]
[
k
]
+
=
d
p
[
i
−
1
]
[
j
]
[
k
−
1
]
∗
(
s
u
m
[
i
]
−
a
[
i
]
[
u
]
)
dp[i][j][k]+=dp[i-1][j][k-1]*(sum[i]-a[i][u])
dp[i][j][k]+=dp[i−1][j][k−1]∗(sum[i]−a[i][u])
代表的是我其他列要选
k
k
k个数,但是我上一个状态只选了
k
−
1
k-1
k−1个数,那么我当前的这一列要选其他数。
#include<iostream>
#include<cstring>
using namespace std;
#define int long long
#define mod 998244353
int dp[51][101][101];
int a[2100][2100];
int sum[210];
signed main(){
int n, m;
cin >> n >> m;
for (int i = 1; i <= n;i++){
for (int j = 1; j <= m;j++){
cin >> a[i][j];
sum[i] += a[i][j];
sum[i] %= mod;
}
}
int tot = 0;
for (int u = 1; u <= m;u++){
memset(dp, 0, sizeof dp);
dp[0][0][0] = 1;
for (int i = 1; i <= n;i++){
for (int j = 0; j <= i;j++){
for (int k = 0; k + j <= i;k++){
dp[i][j][k] = dp[i - 1][j][k];
if(j){
dp[i][j][k] += dp[i - 1][j - 1][k] * a[i][u];
dp[i][j][k] %= mod;
}
dp[i][j][k] %= mod;
if(k){
dp[i][j][k] += dp[i - 1][j][k - 1] * (sum[i] - a[i][u]);
dp[i][j][k] %= mod;
}
dp[i][j][k] %= mod;
}
}
}
for (int i = 1; i <= n;i++){
for (int j = 0; j + i <= n;j++){
if(i>j){
tot += dp[n][i][j];
tot %= mod;
}
}
}
}
int ans = 1;
for (int i = 1; i <= n;i++){
int sum2 = 0;
for (int j = 1; j <= m;j++){
sum2 += a[i][j];
sum2 %= mod;
}
sum2++;//*********
ans *= sum2;
ans %= mod;
}
ans--;
//cout << ans << " " << tot << endl;
cout << (ans - tot+mod) % mod << endl;
}
当然我们这样是一定超时的,因此我们可以使用差值dp来缩小空间的范围。(+n)是防止出现负数。
#include<iostream>
#include<cstring>
using namespace std;
#define int long long
#define mod 998244353
//int dp[51][101][101];
int dp[101][500];
int a[101][2100];
int sum[210];
signed main(){
int n, m;
cin >> n >> m;
for (int i = 1; i <= n;i++){
for (int j = 1; j <= m;j++){
cin >> a[i][j];
sum[i] += a[i][j];
sum[i] %= mod;
}
}
int tot = 0;
for (int u = 1; u <= m;u++){
memset(dp, 0, sizeof dp);
dp[0][n] = 1;
for (int i = 1; i <= n;i++){
for (int j = -i; j <= i;j++){
dp[i][j + n] = dp[i - 1][j + n] + dp[i - 1][j + n-1] * a[i][u] + dp[i - 1][j + 1 + n] * (sum[i] - a[i][u]);
dp[i][j+n] %= mod;
}
}
for (int i = n + 1; i <= 2 * n; i++)
{
tot += dp[n][i];
tot %= mod;
}
}
//cout << tot << endl;
int ans = 1;
for (int i = 1; i <= n;i++){
int sum2 = 0;
for (int j = 1; j <= m;j++){
sum2 += a[i][j];
sum2 %= mod;
}
sum2++;
ans *= sum2;
ans %= mod;
}
ans--;
cout << (ans - tot + mod) % mod << endl;
}