题目
n*m(n,m<=2e3)的矩阵,
每个位置,#表示初始已经是黑块,.表示初始是白块
你可以将任意个白块染色成黑块,使得矩阵成为美丽矩阵
美丽矩阵:
对于任意在矩阵中的(i,j),即1<=i<=n且1<=j<=m,
如果(i,j)是黑块
①若(i+1,j)也在矩阵中,则(i+1,j)也是黑块
②若(i+1,j+1)也在矩阵中,则(i+1,j+1)也是黑块
求染色的方案数,答案对998244353取模
思路来源
一个5年前打gym的题,第一印象是很像,然后发现确实很像
Grid Coloring - Gym 101615J - Virtual Judge
题解
dp[i][j]表示第j列从下到上,最后一个黑块位于第i行的时候的方案数
转移从右到左,从下到上,也是先把#先补全,
然后从每一列初始局面最上面那个黑块开始才有值
而官方题解的状态,则是和对角线对齐
这个状态是怎么定义出来的,我思考了一下,
dp其实就是这么几个因素,感觉需要多想清楚这么几个问题
1. 状态怎么定义
相当于有两个方向的链,为了保证无后效性,肯定需要枚举链取多长,
所以,dp状态需要有一维和一条链平行,而另一条链则通过转移的时候控制合法性
因为(i,j)黑块取了的话,这一列下面的黑块都会取,
所以该把第i列最终取了几个黑块作为定义,也就是最后一个黑块位于哪一行
2. 转移的顺序
本题从左到右或从右到左都可以、从上到下或从下到上也都可以,
但是,两个for循环,一定是列的这一维在外层,
从右到左:只有处理好了第i+1列,才能处理第i列,(i,j)涂黑当且仅当(i+1,k)涂黑,k<=j+1
从左到右:只有处理好了第i-1列,才能处理第i列,(i,j)涂黑当且仅当(i+1,k)不涂黑,k<j-1
3. 什么时候有值(什么状态是合法的)
由于初始局面的黑块必取,
那未取走初始局面黑块的就是非法状态,值为0
4. 答案如何计算
最后一列决策完之后,局面就是唯一的,
所以,对最后一列的dp答案求和
5. 需不需要优化,如何优化
暴力是O(n*m*n)的,做个前缀和优化就可以了
6. corner case有哪些
某一列没有黑块的情形,此时,我把最后一个黑块定到了第n+1行,
第i行又需要下一列的第i+1行,所以前缀和做到第n+2行
如果把dp状态定义为,
第j列自下到上总共连着取了i(0<=i<=n)个黑块,
可能转移的时候会更加自然
代码
#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
#define fi first
#define se second
#define pb push_back
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
#define SZ(a) (int)(a.size())
#define sci(a) scanf("%d",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d\n",a)
#define ptlle(a) printf("%lld\n",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
std::mt19937_64 gen(std::chrono::system_clock::now().time_since_epoch().count());
ll get(ll l, ll r) { std::uniform_int_distribution<ll> dist(l, r); return dist(gen); }
const int N=2e3+10,mod=998244353;
int n,m,dp[N][N],sum[N],tp[N];
char s[N][N];
void add(int &x,int y){
x=(x+y)%mod;
}
int main(){
sci(n),sci(m);
rep(i,1,n){
scanf("%s",s[i]+1);
}
rep(j,1,m)tp[j]=n+1;
rep(i,1,n){
rep(j,1,m){
if(s[i][j]=='#'){
s[i+1][j]='#';
s[i+1][j+1]='#';
tp[j]=min(tp[j],i);
}
}
}
per(j,m,1){
per(i,tp[j],1){
if(j==m)dp[i][j]=1;
else dp[i][j]=sum[i+1];
}
memset(sum,0,sizeof sum);
rep(i,1,n+2){
sum[i]=(sum[i-1]+dp[i][j])%mod;
}
}
pte(sum[n+1]);
return 0;
}
/*
8
3 7 4 7 3 3 8 2
*/