Description
给出一个n*m矩阵A,问所有子矩阵中不同数字个数的期望
Input
第一行一整数T表示用例组数,每组用例首先输入两个整数n和m表示矩阵行列数,之后输入一n*m矩阵A[i][j] (T<=9,1<=n,m<=100,0<=A[i][j]<=n*m)
Output
输出子矩阵中不同数字个数的期望,结果保留到小数点后九位
Sample Input
1
2 3
1 2 1
2 1 2
Sample Output
1.666666667
Solution
问题在于统计所有子矩阵中不同数字个数之和,有两种方法,第一种是容斥原理,对于一个数字,统计该数字对答案的贡献,即统计所有包含该种数字的子矩阵个数,该个数=至少经过一个该种数字的矩阵个数-至少经过两个该种数字的矩阵个数+至少经过三个该种数字的矩阵个数-…,找到所有该种颜色的点,枚举该点集的一个子集,统计所有经过该子集的子矩阵(一个矩阵可以用左上角和右下角确定,我们找到这个点集的最小最大横纵坐标值得到两个点,这里两个点构成的矩阵是包含这个点集的最小矩阵,进而所有包含这个矩阵的矩阵均满足条件),用容斥原理即得到经过该种颜色点的子矩阵个数,对一个颜色该种方法的时间复杂度为O(s*2^s),s为该种颜色点的数量。第二种是用单调栈去找所有不经过该种颜色点的矩阵,从第一行开始每行从左到右统计,对于一个点统计以该点为右下角的不经过该种颜色点的矩阵个数,需要另开一个数组c[i][j]表示第i行第j个元素往上到第一个该种颜色的点的距离,例如三个c[i][j]为1 2 1的话,以第一个点为右下角的矩阵只有1个,以第二个点为右下角的矩阵占据前两列的有1个,占据第二列的有2个,以第三个点为右下角的矩阵占据前三列的有1个,占据第二三列的有1个,占据第三列的有1个,即每次以第i行第j列元素为右下角时,占据第k~j列的矩阵会有min(c[i][k],…,c[i][j])个,用单调栈维护c[i][j]的值和栈和,当前的c[i][j]如果比栈顶元素大则直接进栈,否则把栈中所有大于c[i][j]的元素全部变成c[i][j],每次进栈之后把栈和累计到答案里即为以当前这个元素为右下角的不经过该种元素的矩阵个数,这样做统计一种颜色答案的时间复杂度是O(n^2)。这两种方法单独拿出一种对该题均不适合,但是如果把元素按取值分类,出现次数大于9的值(至多n^2/9种不同取值)用第二种方法,不大于9的用第一种方法,这样时间复杂度为
Code
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef long long ll;
typedef pair<int,int>P;
const int maxn=10001;
vector<P>vec[maxn];
int T,n,m,w[101][101],vis[maxn],a[101],b[101],c[101][101];
int Solve1(int x)
{
int nn=vec[x].size(),N=1<<nn,ans=0;
for(int i=1;i<N;i++)
{
int flag=0,l1=n+1,r1=m+1,l2=0,r2=0;
for(int j=0;j<nn;j++)
if(i&(1<<j))
{
flag^=1;
l1=min(l1,vec[x][j].first),r1=min(r1,vec[x][j].second),
l2=max(l2,vec[x][j].first),r2=max(r2,vec[x][j].second);
}
if(flag)ans+=l1*r1*(n-l2+1)*(m-r2+1);
else ans-=l1*r1*(n-l2+1)*(m-r2+1);
}
return ans;
}
int Solve2(int x)
{
int ans=0;
for(int i=1;i<=n;i++)
{
int p=0,sum=0;
for(int j=1;j<=m;j++)
{
if(w[i][j]!=x)c[i][j]=c[i-1][j]+1;
else c[i][j]=0;
int num=1;
while(p&&a[p]>=c[i][j])
num+=b[p],sum-=a[p]*b[p],p--;
a[++p]=c[i][j],b[p]=num,sum+=a[p]*b[p],ans+=sum;
}
}
return ans;
}
int main()
{
scanf("%d",&T);
while(T--)
{
memset(vis,0,sizeof(vis));
for(int i=1;i<maxn;i++)vec[i].clear();
scanf("%d%d",&n,&m);
int index=0;
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
scanf("%d",&w[i][j]);
if(!vis[w[i][j]])vis[w[i][j]]=++index;
w[i][j]=vis[w[i][j]];
vec[w[i][j]].push_back(P(i,j));
}
double ans=0,sum=n*(n+1)/2*m*(m+1)/2;
for(int i=1;i<=index;i++)
if(vec[i].size()<=9)ans+=Solve1(i)/sum;
else ans+=(sum-Solve2(i))/sum;
printf("%.9f\n",ans);
}
return 0;
}