考虑从大往小枚举左上角
(
i
,
j
)
(i,j)
(i,j),求出所有能到达的点的权值和
s
u
m
[
i
]
[
j
]
sum[i][j]
sum[i][j]。
如果
(
i
,
j
+
1
)
(i,j+1)
(i,j+1)和
(
i
+
1
,
j
)
(i+1,j)
(i+1,j)中至多只有一个非障碍格子,那么容易计算。否则直接加上
s
u
m
[
i
]
[
j
+
1
]
sum[i][j+1]
sum[i][j+1]和
s
u
m
[
i
+
1
]
[
j
]
sum[i+1][j]
sum[i+1][j]会记重,考虑减掉两个点都能到达的。
考虑记录
m
a
x
n
[
i
]
[
j
]
[
k
]
maxn[i][j][k]
maxn[i][j][k]和
m
i
n
n
[
i
]
[
j
]
[
k
]
minn[i][j][k]
minn[i][j][k]表示
(
i
,
j
)
(i,j)
(i,j)能到达的点中,第
k
k
k行纵坐标最大和最小的,显然对于
(
i
,
j
+
1
)
(i,j+1)
(i,j+1)和
(
i
+
1
,
j
)
(i+1,j)
(i+1,j)都能到的行
k
k
k有
m
i
n
n
[
i
+
1
]
[
j
]
[
k
]
≤
m
i
n
n
[
i
]
[
j
+
1
]
[
k
]
minn[i+1][j][k]\leq minn[i][j+1][k]
minn[i+1][j][k]≤minn[i][j+1][k],
m
a
x
n
[
i
+
1
]
[
j
]
[
k
]
≤
m
a
x
n
[
i
]
[
j
+
1
]
[
k
]
maxn[i+1][j][k]\leq maxn[i][j+1][k]
maxn[i+1][j][k]≤maxn[i][j+1][k]。注意到路径相交可以交换,于是可以发现一些性质:如果第
k
k
k行有
(
i
,
j
+
1
)
(i,j+1)
(i,j+1)和
(
i
+
1
,
j
)
(i+1,j)
(i+1,j)都能到的点,那么必然包括
(
k
,
m
i
n
n
[
i
]
[
j
+
1
]
[
k
]
)
(k,minn[i][j+1][k])
(k,minn[i][j+1][k]),并且更好的性质是两个点都能到的点是若干个
m
i
n
n
[
i
]
[
j
+
1
]
[
k
]
minn[i][j+1][k]
minn[i][j+1][k]可达点集的并,且这些点集所占据的行和列不相交,于是容易在
O
(
N
)
\mathcal O(N)
O(N)的复杂度内求出它的大小。
这样实现是
O
(
N
3
)
\mathcal O(N^3)
O(N3)的,由于常数很小可以通过。标算貌似是
O
(
N
2
log
N
)
\mathcal O(N^2\log N)
O(N2logN)的,有时间再补。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int sumv[1505][1505],rpos[1505][1505];
int minn[2][1505][1505],maxn[2][1505][1505];
char str[1505][1505];
int main() {
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%s",str[i]+1);
int cur=0;
ll ans=0;
for(int i=n;i>0;i--) {
cur^=1;
for(int j=n;j>0;j--)
if (str[i][j]!='#') {
sumv[i][j]=sumv[i][j+1]+sumv[i+1][j]+str[i][j]-'0';
rpos[i][j]=max(max(rpos[i][j+1],rpos[i+1][j]),i);
minn[cur][j][i]=j;
for(int k=i+1;k<=rpos[i+1][j];k++) minn[cur][j][k]=minn[cur^1][j][k];
for(int k=max(rpos[i+1][j],i)+1;k<=rpos[i][j+1];k++) minn[cur][j][k]=minn[cur][j+1][k];
maxn[cur][j][i]=j;
for(int k=i;k<=rpos[i][j+1];k++) maxn[cur][j][k]=maxn[cur][j+1][k];
for(int k=max(rpos[i][j+1],i)+1;k<=rpos[i+1][j];k++) maxn[cur][j][k]=maxn[cur^1][j][k];
int d=i+1,r=min(rpos[i][j+1],rpos[i+1][j]);
while (d<=r) {
if (maxn[cur^1][j][d]>=minn[cur][j+1][d]) {
sumv[i][j]-=sumv[d][minn[cur][j+1][d]];
d=rpos[d][minn[cur][j+1][d]];
}
d++;
}
ans+=(ll)(str[i][j]-'0')*(sumv[i][j]-(str[i][j]-'0'));
}
}
printf("%lld\n",ans);
return 0;
}