题意:
题解:
感觉挺妙的。
对于任意矩阵都有 1*1 - 1*2 - 2*1 + 2*2 = 1。我们用这个来做容斥就好了。
对于每个点被覆盖的次数,可以用单调队列。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=2e3+50, mod=998244353;
inline int add(int x,int y) {return (x+y>=mod) ? (x+y-mod) : (x+y);}
inline int dec(int x,int y) {return (x-y<0) ? (x-y+mod) : (x-y);}
inline int mul(int x,int y) {return (LL)x*y%mod;}
inline int power(int a,int b,int rs=1) {for(;b;b>>=1,a=mul(a,a)) if(b&1) rs=mul(rs,a); return rs;}
char mp[N][N];
int n,m,k,ans;
int s0[N][N],s1[N][N],s2[N][N],s3[N][N];
LL ss[N][N];
template <typename T>
inline void revx(T a[][N]) {
for(int i=1;i<=n/2;i++)
for(int j=1;j<=m;j++)
swap(a[i][j],a[n-i+1][j]);
}
template <typename T>
inline void revy(T a[][N]) {
for(int i=1;i<=n;i++)
for(int j=1;j<=m/2;j++)
swap(a[i][j],a[i][m-j+1]);
}
inline void solve(int *h,int *s) {
static int l[N],v[N],sum,tl;
sum=tl=0;
for(int i=1;i<=m;i++) {
int len=1;
while(tl && v[tl]>=h[i]) len+=l[tl], sum=dec(sum,mul(v[tl],l[tl])), --tl;
v[++tl]=h[i]; l[tl]=len; sum=add(sum,mul(v[tl],l[tl])); s[i]=sum;
}
}
inline void solve(char o[][N],int s[][N]) {
static int h[N][N];
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
h[i][j]=(o[i][j]=='0') ? 0 : (h[i-1][j]+1);
for(int i=1;i<=n;i++)
solve(h[i],s[i]);
}
inline int calc_sum(LL rs=0) {
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++) {
ss[i][j]=ss[i-1][j]+ss[i][j-1]+ss[i][j]-ss[i-1][j-1];
rs=rs+power(ss[i][j]%mod,k);
}
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++) ss[i][j]=0;
return rs%mod;
}
int main() {
cin>>n>>m>>k;
for(int i=1;i<=n;i++) cin>>(mp[i]+1);
solve(mp,s3);
revx(mp);
solve(mp,s1);
revx(s1);
revy(mp);
solve(mp,s0);
revx(s0); revy(s0);
revx(mp);
solve(mp,s2);
revy(s2);
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++) {
ss[i][j]+=s0[i][j];
ss[i+1][j]-=s2[i][j];
ss[i][j+1]-=s1[i][j];
ss[i+1][j+1]+=s3[i][j];
}
ans=add(ans,calc_sum());
memset(ss,0,sizeof(ss));
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++) {
ss[i][j]+=s0[i][j];
ss[i][j]-=s1[i][j];
ss[i+1][j]-=s2[i][j];
ss[i+1][j]+=s3[i][j];
}
ans=dec(ans,calc_sum());
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++) {
ss[i][j]+=s0[i][j];
ss[i][j+1]-=s1[i][j];
ss[i][j]-=s2[i][j];
ss[i][j+1]+=s3[i][j];
}
ans=dec(ans,calc_sum());
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++) {
ss[i][j]+=s0[i][j];
ss[i][j]-=s1[i][j];
ss[i][j]-=s2[i][j];
ss[i][j]+=s3[i][j];
}
ans=add(ans,calc_sum());
cout<<ans<<'\n';
}