别骂了别骂了真不会状压
题目描述
有一个 n × m n \times m n×m 的矩阵 A A A,求有多少个 A A A 的非空子矩阵满足每行每列要么升序要么降序。
一个矩阵的非空子矩阵是指,选出行的一个非空子集与列的一个非空子集,既在这些行又在这些列的元素排成的新矩阵。一个 n × m n\times m n×m 的矩阵的非空子矩阵数为 ( 2 n − 1 ) ( 2 m − 1 ) ( 2^n-1 )( 2^m-1 ) (2n−1)(2m−1) 。
n , m ≤ 20 n,m≤20 n,m≤20,保证 A A A 中所有元素构成一个排列,时限 4s。
题解
看到数据范围,八成是个状压。
我们按行考虑,发现填到某一行时转移只跟上一个被选的行和此时是升序/降序有关。
于是我的想法是, f i , j , k , l f_{i,j,k,l} fi,j,k,l 表示考虑到第 i i i 行且第 i i i 行必选,上一行为 j j j,当前列状态为 k k k(1为选这一列,0为不选),每一列升降序状态为 l l l (1表示升序,0表示降序)的方案数,转移时还需要判断是否合法。然后就发现复杂度达到了惊人的 O ( n 3 2 2 n ) O(n^32^{2n}) O(n322n) ,倒闭。
不难发现其实只需要知道上一个选的行和上上个选的行就可以判断出升降序,于是我们可以把一个 O ( 2 n ) O(2^{n}) O(2n) 优化成 O ( n ) O(n) O(n),复杂度变成了 O ( n 4 2 n ) O(n^42^n) O(n42n) 。进步了不少,但仍然无法通过。
于是我就在这里卡了非——常——非——常——久——
经过不懈的努力思考向dalao求助,我们发现有个很简单快速的判断每列是否合法的方法。那就是先
O
(
n
2
)
O(n^2)
O(n2) 预处理出所有行两两之间的列大小关系,判断只需要
O
(
1
)
O(1)
O(1) and一下就可以了。
具体来说,预处理出数组 a a a,其中将 a i , j a_{i,j} ai,j 转化为2进制后,若是第 t t t 位为1,代表 A i , t > A j , t A_{i,t}>A_{j,t} Ai,t>Aj,t,否则 A i , t < A j , t A_{i,t}<A_{j,t} Ai,t<Aj,t。转移时将当前列状态 k k k 分别 and 上 a i , l s t a_{i,lst} ai,lst 和 a l s t , l s t t a_{lst,lstt} alst,lstt, l s t lst lst 和 l s t t lstt lstt 分别表示上一个取的行和上上个取的行)再判断两数是否相等即可。
最终复杂度 O ( n 3 2 n ) O(n^32^n) O(n32n),常数较小,可以通过本题。
虽然机房大佬人均一眼秒,但对于我这种不擅长dp尤其是状压dp的萌新来说最后这个思路还是太高妙了/se。
你以为到这就结束了?并没有,写完以上之后我仍然调了很久才过这题,因为犯了一个很憨的小错误()枚举当前行
i
i
i 时还需要判断
i
i
i 这行在当前列状态下是否满足行单调的限制,若是不满足则直接跳过。一直思考列限制久了不知不觉就忘了还有行限制这回事嘻嘻。
Code
//代码中一些变量名称和题解中不同,请注意甄别
#include<bits/stdc++.h>
using namespace std;
const int N=20+5;
typedef long long ll;
int n,m,now,a[N][N],g[N][N];
ll ans,f[N][N];
int zf(int x){
if(x>0) return 1;
else return -1;
}
void dfs(int x){ //个人喜好,用dfs枚举状态
if(x==m+1){
if(now==0) return;
memset(f,0,sizeof(f));
for(int i=1;i<=n;i++){
int lst=0,lstt=0;
bool ck=true;
for(int j=1;j<=m;j++){
if(now&(1<<j-1)){
if(lstt!=0){
if(zf(a[i][j]-a[i][lst])!=zf(a[i][lst]-a[i][lstt])){
ck=false;
break;
}
}
lstt=lst,lst=j;
}
}
if(!ck) continue;
f[i][0]=1;
ans+=f[i][0];
for(int j=1;j<i;j++){
for(int k=0;k<j;k++)
if((g[j][k]&now)==(g[i][j]&now)||k==0)
f[i][j]+=f[j][k];
ans+=f[i][j];
}
}
return;
}
dfs(x+1);
now|=(1<<x-1);
dfs(x+1);
now^=(1<<x-1);
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
scanf("%d",&a[i][j]);
for(int k=1;k<i;k++)
if(a[i][j]>a[k][j]) g[i][k]|=(1<<j-1);
}
}
dfs(1);
printf("%lld",ans);
return 0;
}