先将所有位置按权值从小到大排序，设排序后第 $i$ 个位置的坐标为 $(x_i,y_i)$，权值为 $a_i$。
设 $f[i]$ 表示从第 $i$ 个位置出发不断移动、直到无法移动为止的期望得分。
很容易得出转移方程：

$$f[i]=\frac{1}{cnt_i}\sum_{j=1,\,a_j<a_i}^{i-1}\left(f[j]+(x_i-x_j)^2+(y_i-y_j)^2\right)$$
其中 $cnt_i$ 表示权值严格小于 $a_i$ 的位置个数（若 $cnt_i=0$，则无法移动，$f[i]=0$）。
直接按此式转移的复杂度为 $O(n^2m^2)$。
强行推一波式子：把平方项展开。由于 $x_i^2+y_i^2$ 在和式中恰好出现 $cnt_i$ 次，除以 $cnt_i$ 后可以直接提到和式外面：

$$f[i]=\frac{1}{cnt_i}\sum_{j=1,\,a_j<a_i}^{i-1}\left(f_j+x_j^2+y_j^2-2x_ix_j-2y_iy_j\right)+x_i^2+y_i^2$$
$$=\frac{1}{cnt_i}\left(\sum_{j=1,\,a_j<a_i}^{i-1}\left(f_j+x_j^2+y_j^2\right)-2x_i\sum_{j=1,\,a_j<a_i}^{i-1}x_j-2y_i\sum_{j=1,\,a_j<a_i}^{i-1}y_j\right)+x_i^2+y_i^2$$
由于已按权值排序，参与转移的 $j$ 恰好构成一段前缀，并且这段前缀的右端点随 $i$ 增大单调不降。
因此可以用一个指针维护参与转移的 $j$ 的 $f_j+x_j^2+y_j^2$、$x_j$、$y_j$ 三者之和，每个 $f[i]$ 都能 $O(1)$ 求出。其中除以 $cnt_i$ 用模 $998244353$ 意义下的乘法逆元实现。
DP 部分的复杂度为 $O(nm)$，加上排序，总复杂度为 $O(nm\log(nm))$。
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>

// Expected-score DP over the cells of an n x m grid.
//
// Input (stdin): n, m, then the n*m cell values row by row, then the start cell (r, c).
// Output (stdout): f at the start cell, modulo 998244353, where, after sorting cells by
// value,
//   f[i] = (1 / cnt_i) * sum_{j < i, a_j < a_i} (f[j] + (x_i - x_j)^2 + (y_i - y_j)^2)
// and cnt_i is the number of cells with value strictly below a_i (f[i] = 0 if cnt_i = 0).
// Expanding the squares reduces each f[i] to O(1) work on running prefix sums.

namespace {

const int MOD = 998244353;  // prime modulus, so every 1..MOD-1 has an inverse

struct node
{
    int x, y, val;  // 1-based row, 1-based column, cell value
};

inline bool comp(const node &lhs, const node &rhs) { return lhs.val < rhs.val; }

}  // namespace

// Reads one decimal integer (optional leading '-') from stdin, skipping any other
// characters before it. Returns 0 if EOF is reached before a number starts.
inline int read()
{
    int c = getchar();  // int, not char: a char can never compare equal to EOF reliably
    while (c != EOF && c != '-' && (c < '0' || c > '9'))
        c = getchar();
    if (c == EOF)
        return 0;  // the old char-based loop spun forever here
    bool negative = false;
    if (c == '-')
    {
        negative = true;
        c = getchar();
    }
    int res = 0;
    while (c >= '0' && c <= '9')
    {
        res = res * 10 + (c - '0');
        c = getchar();
    }
    return negative ? -res : res;
}

int main()
{
    const int n = read();
    const int m = read();
    const int cells = (n > 0 && m > 0) ? n * m : 0;

    // Cells are stored 1-based; index 0 is unused. Sized from the input, not a fixed cap.
    std::vector<node> a(cells + 1);
    int tn = 0;
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j)
        {
            const int v = read();
            a[++tn] = node{i, j, v};
        }
    std::sort(a.begin() + 1, a.end(), comp);

    const int r = read();
    const int c = read();

    // Locate the start cell in sorted order; only cells up to that index are needed
    // to compute its f value. If it is absent, fall back to the whole array.
    int last = cells;
    for (int i = 1; i <= cells; ++i)
        if (a[i].x == r && a[i].y == c)
        {
            last = i;
            break;
        }

    // Linear-time table of modular inverses. inv[0] stays 0; the numerator is also 0
    // when cnt_i = 0, so f[i] comes out as 0 in that case.
    std::vector<int> inv(last + 2, 0);
    inv[1] = 1;
    for (int i = 2; i <= last; ++i)
        inv[i] = static_cast<int>(1LL * (MOD - MOD / i) * inv[MOD % i] % MOD);

    // Running sums (mod MOD) over the prefix of cells whose value is strictly below a[i].val.
    long long sumSq = 0;  // sum of x_j^2 + y_j^2
    long long sumX = 0;   // sum of x_j
    long long sumY = 0;   // sum of y_j
    long long sumF = 0;   // sum of f[j]
    std::vector<int> f(last + 1, 0);
    int p = 1;  // first sorted index not yet folded into the sums
    for (int i = 1; i <= last; ++i)
    {
        // Sorting makes {j : a_j < a_i} a prefix whose end only moves forward with i.
        while (p < i && a[p].val < a[i].val)
        {
            const long long xp = a[p].x;
            const long long yp = a[p].y;
            sumSq = (sumSq + xp * xp + yp * yp) % MOD;
            sumX = (sumX + xp) % MOD;
            sumY = (sumY + yp) % MOD;
            sumF = (sumF + f[p]) % MOD;
            ++p;
        }
        const int cnt = p - 1;
        const long long xi = a[i].x;
        const long long yi = a[i].y;
        // cnt*(x_i^2 + y_i^2) + sum(x_j^2 + y_j^2) - 2 x_i sum(x_j) - 2 y_i sum(y_j) + sum(f_j)
        long long acc = (xi * xi + yi * yi) % MOD * cnt % MOD;
        acc = (acc + sumSq) % MOD;
        acc = (acc - 2 * xi % MOD * sumX % MOD + MOD) % MOD;
        acc = (acc - 2 * yi % MOD * sumY % MOD + MOD) % MOD;
        acc = (acc + sumF) % MOD;
        f[i] = static_cast<int>(acc * inv[cnt] % MOD);
    }

    std::cout << f[last] << '\n';
    return 0;
}