题目链接
分析
我这里用的是 n3logm 的算法。
我们可以吧期望问题转化为计数问题,对于任意一个子矩阵来说,对于同种颜色的格子我们只计算从左到右,从上到下的第一个,那么我们反过来对于每个格子的贡献其实就是这个格子的颜色作为从左到右从上到下第一个的时候这样的矩阵有多少个。
对于格子
(i,j)
颜色为
c
我们先求出从
枚举每一行的边界,对于一个包含
(i,j)
的子矩阵来说,固定它的上两个端点在 第
k
行,找出从
枚举每个格子的时间是 n2
所以可以在 nlogm 以内就解决了.
AC code
//Problem : 6052 ( To my boyfriend ) Judge Status : Accepted
//RunId : 21383023 Language : G++ Author : zouzhitao
//Code Render Status : Rendered By HDOJ G++ Code Render Version 0.01 Beta
#include<bits/stdc++.h>
#define pb push_back
#define mp make_pair
#define PI acos(-1)
#define fi first
#define se second
#define INF 0x3f3f3f3f
#define INF64 0x3f3f3f3f3f3f3f3f
#define random(a,b) ((a)+rand()%((b)-(a)+1))
#define ms(x,v) memset((x),(v),sizeof(x))
using namespace std;
const int MOD = 1e9+7;
const double eps = 1e-8;
typedef long long LL;
typedef long double DB;
typedef pair<int,int> Pair;
const int maxn = 1e2+10;
int a[maxn][maxn];
Pair col[maxn][maxn];
int upper[maxn],lower[maxn];
int n,m;
bool cmp(const Pair & a,const Pair & b){
if(a.fi == b.fi )return a.se < b.se;
else return a.fi < b.fi;
}
int main() {
// std::ios::sync_with_stdio(false);
// std::cin.tie(0);
int T;scanf("%d",&T );;
while (T--) {
scanf("%d%d",&n,&m );
LL ans= 0;
for(int i=1 ; i<=n ; ++i){
for(int j=1 ; j<=m ; ++j){
scanf("%d",&a[i][j] );
col[i][j] = mp(a[i][j],j);
}
sort(col[i]+1,col[i]+m+1);
col[i][m+1] = mp(INF,m+1);
}
// for(int i=1 ; i<=n ; ++i){
// std::cout << i << '\n';
// for(int j =1 ; j<=m ; ++j)
// std::cout << col[i][j].fi <<" "<<col[i][j].se << '\n';
// }
for(int i=1 ; i<=n ; ++i){
for(int j=1; j<=m ; ++j){
int color = a[i][j];
ms(upper,0);
ms(lower,INF);
lower[i] = m;
int p = lower_bound(col[i]+1,col[i]+m+1,mp(color,j),cmp) - col[i];
upper[i] = 1;
p--;
if(p>0 && col[i][p].fi == color)upper[i] = col[i][p].se+1;
for(int k = i-1 ; k>=1 ; --k){
Pair point = *lower_bound(col[k]+1,col[k]+m+1,mp(color,j));
if(point.fi == color)lower[k] = point.se-1;
lower[k] = min(lower[k],lower[k+1]);
Pair * pp = upper_bound(col[k]+1,col[k]+m+1,mp(color,j));
pp--;
point = *pp;
if(point.fi == color)upper[k] = point.se+1;
upper[k] = max(upper[k],upper[k+1]);
}
//std::cout << i<<" ij "<<j << '\n';
for(int k=1 ; k<=i ; ++k){
// std::cout << "k = "<<k << '\n';
// std::cout << upper[k]<<" " << lower[k] << '\n';
if(upper[k]>j || lower[k] < j)continue;
LL left = j-upper[k]+1;
LL right = lower[k] - j +1;
//std::cout << "left "<<left<<"right "<<right <<"add "<<left*right*(n-i+1)<< '\n';
ans += left*right*(n-i+1);
}
}
}
// std::cout << ans << '\n';
LL tot = (LL)n*(n+1)*m*(m+1);
printf("%.9lf\n",ans*4.0/tot );
}
return 0;
}