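// Approach (逆推期望): derive the expectation backwards, building each cell's value from cells with smaller values.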
#include<bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand used throughout.
using ll = long long;
#define pb(x) push_back(x)
// Matrix side bound (with slack) and the prime modulus for all arithmetic.
constexpr int maxn = 1000 + 5;
constexpr ll mod = 998244353;
// One matrix cell: its 1-based position (x = row, y = column) and its value.
// Ordering looks at the value only, so sorting groups equal values together.
struct node
{
    ll x, y;  // row, column
    ll val;   // cell value

    bool operator<(const node &other) const { return val < other.val; }
};
// All cells of the matrix; main() sorts them by value.
node a[maxn*maxn];
// Running sums (mod `mod`) over every cell already processed, i.e. over all
// cells whose value is strictly below the run currently being handled.
ll sumr;   // sum of row indices
ll sumr2;  // sum of squared row indices
ll sumc;   // sum of column indices
ll sumc2;  // sum of squared column indices
ll sumdp;  // sum of dp values
// dp[r][c]: expected value computed for cell (r, c), modulo `mod`.
ll dp[maxn][maxn];
// Product a*b reduced modulo `mod`. C++'s % truncates toward zero, so a
// negative product gives a result in (-mod, 0]; add() renormalises it.
// NOTE(review): a*b must fit in long long -- callers keep operands near < mod.
ll mul(ll a,ll b)
{
    ll product = a * b;
    product %= mod;
    return product;
}
// Binary exponentiation: a^b modulo `mod` (returns 1 when b <= 0).
ll ksm(ll a,ll b)
{
    ll acc = 1;
    // Walk the bits of the exponent from least significant upward, squaring
    // the base each step and multiplying it in whenever the bit is set.
    for (ll base = a, e = b; e > 0; e >>= 1, base = mul(base, base))
        if (e & 1)
            acc = mul(acc, base);
    return acc;
}
// Sum a + b brought into [0, mod). The corrections loop rather than fire once,
// so an overshoot of several multiples of `mod` in either direction is fixed.
ll add(ll a,ll b)
{
    ll sum = a + b;
    while (!(sum < mod)) sum -= mod;
    while (!(sum >= 0)) sum += mod;
    return sum;
}
// Modular inverse of a, computed as a^(mod-2) (Fermat; mod is prime).
// The assert double-checks a * inv(a) == 1; it fires e.g. for a == 0
// (and is compiled out under NDEBUG).
ll inv(ll a)
{
    const ll exponent = mod - 2;
    const ll candidate = ksm(a, exponent);
    assert(mul(a, candidate) == 1);
    return candidate;
}
int main()
{
ll n,m;
ll i,j,k;
ll len;
scanf("%lld %lld",&n,&m);
len = 0;
for(i=1;i<=n;++i)
{
for(j=1;j<=m;++j)
{
a[len].x = i;
a[len].y = j;
scanf("%lld",&a[len].val);
len ++;
}
}
sort(a,a+len);
//for(i=0;i<len;++i)
// printf("%lld %lld %lld\n",a[i].x,a[i].y,a[i].val);
memset(dp,0,sizeof(dp));
ll l,r;
l = 0;
sumr = sumr2 = sumc2 = sumc = sumdp = 0;
while(l < n*m)
{
r = l;
while(a[r].val == a[l].val && r < n*m) r ++;
//cout << l << " " << r << endl;
ll il = -1;
if(l != 0) il = inv(l);
for(i=l;i<r;++i)
{
ll rr,cc;
rr = a[i].x; cc = a[i].y;
if(il == -1)
{
dp[rr][cc] = 0;
continue;
}
dp[rr][cc] = add(dp[rr][cc],mul(sumdp,il));
dp[rr][cc] = add(dp[rr][cc],mul(rr,rr));
dp[rr][cc] = add(dp[rr][cc],mul(cc,cc));
dp[rr][cc] = add(dp[rr][cc],mul(sumr2,il));
dp[rr][cc] = add(dp[rr][cc],mul(sumc2,il));
dp[rr][cc] = add(dp[rr][cc],mul(mul(-2*rr,sumr),il));
dp[rr][cc] = add(dp[rr][cc],mul(mul(-2*cc,sumc),il));
}
for(i = l; i < r; ++i)
{
int rr,cc;
rr = a[i].x; cc = a[i].y;
sumdp = add(sumdp,dp[rr][cc]);
sumr2 = add(sumr2,mul(rr,rr));
sumc2 = add(sumc2,mul(cc,cc));
sumr = add(sumr,rr);
sumc = add(sumc,cc);
}
l = r;
}
ll c,b;
scanf("%lld %lld",&c,&b);
// cout << endl;
cout << dp[c][b] << endl;
}
/*
Sample inputs (two separate test cases: "n m", the matrix rows, then the query cell):

1 4
1 1 2 1
1 3

2 3
1 5 7
2 3 1
1 2
*/
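// Notes: this problem was genuinely painful.
// Build each cell's value by pushing from every cell whose val is lower than that cell's val.
// Why prefix sums of x and x^2 (rows and columns) are needed: expand sum_k (x - x_k)^2 = cnt*x^2 - 2x*sum x_k + sum x_k^2;
// only sum x_k and sum x_k^2 over the smaller cells appear, so running sums of those suffice.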