题意:给出一个n*m(1<=n,m<=100)的矩阵,每个矩阵元素有一个颜色值ai(0<=ai<=10000),现在定义一个子矩阵的value为子矩阵中不同颜色的数量,求子矩阵value的期望。
题解:期望=所有子矩阵的value总和/子矩阵个数。首先解决一个zz的问题:子矩阵有多少个,二重循环枚举子矩形右下角的点(i,j),那么子矩阵左上角的点一定在(0,0)-(i,j)这个矩阵里,所以有i*j种。
如果你学过小学数学,当然,子矩阵总共有n*(n+1)*m*(m+1)/4=[n*(n+1)/2]*[m*(m+1)/2]=C(n+1,2)*C(m+1,2)种。(感谢 小坏蛋_千千 的纠正)
***************在此感谢UFO__的评论***************
接下来我们讲重点:显然要每个颜色单独考虑,在考虑颜色i的时候,把颜色i的点看作关键点,求出 至少包含一个关键点的子矩阵个数。现在的问题是我们如何不重复不遗漏的统计个数。emmm先排个序(行号升序,列号升序)。把每个合法的矩阵算在序最小的那个关键点头上,这样就可以保证不重复,不遗漏。那么我们再找包含第一个关键点的矩阵的时候,显然没有任何限制,只需要包含这个点就行了。找第二个关键点的矩阵的时候,不能包含第一个点……找第i个关键点决定的矩阵的时候,不能包含1..i-1这i-1个点。
那么假设我们的图长这个样子,且灰色点是已经计数完成的点,白色点是未计数的点,黑色点是正在计数的点。
这个黑点位于(4,4)位置,我们现在要做的就是确定上下左右四个边界分别有多少种选取方式,显然白色点的序大于黑色点,黑色点的矩阵可以包含他们也可以不包含他们,因此下边界无限制,可以取到(4,5,6==n)三种方式,而由于黑点同一行的左边(同行右边的也是白色点)以及上边的行有不能包含的点,因此一个矩阵上边界对应了左右的最远边界。我们需要枚举矩阵的上边界(4,3,2,1),并且因为上边界i-1的时候上边界i要考虑的点仍然要考虑,而且要额外考虑i-1这行的灰色点。因此左右最远边界要持续进行维护。上边界是ii的时候,考虑所有ii行的灰色点,在黑点左边的去更新左边界最远点,在黑店右边的去更新有边界最远点,而刚好在黑点上边的话,说明这一行不可能作为上边界了。这个时候就结束计算黑点名下的子矩形。开始计算下一个白点。
优化:我们看第10列第1行和第3行的两个同列的灰色点,显然在上边界为3的时候,靠下的这个灰色点决定了右边界最远端,而上边界继续向上移动,右边界只可能更小,因此靠上的这个灰色点其实什么作用也没有。因此同一列只有最下边一个点有用。这样让我们在计算黑点的时候,最多只考虑之前出现的m个点,而不是理论最坏im个点,这个优化还是很关键的。因此当我们枚举(i,j)名下的矩形,上边界是xi行的时候,右边界最远是ry,左边界最远是ly,那么组合一下就得到了(n-i+1)*(j-ly+1)*(ry-j+1)个贡献(下边界方案*左边界方案*右边界方案,上边界已经确定是xi)。上边讲的一个跳出条件仍然有效。
整个算法复杂度严格小于 mn*(n/2+m)(点数*平摊上边界枚举数量&最多需要考虑的点),实际上跑起来是快的飞起,因为中间会跳出,而且要考虑的点没那么多。
注意:窝不知道为什么。。。在函数体内部开了100个int的数组内存就爆炸了。。。委屈。。
第二:这个题不需要高精度,蛮好的。。。
还有,我真该回高中好好学学再上大学,现在只能顽强地被高中生摩擦。
Code:
#include<bits/stdc++.h>
using namespace std;
const int MAX = 105;
int mp[MAX][MAX];
vector<pair <int,int>> Color[MAX*MAX];
int m,n;
void input(){
for (int i=0;i<=n*m;i++){
Color[i].clear();
}
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++){
for (int j = 1;j<=m;j++){
scanf("%d",&mp[i][j]);
Color[mp[i][j]].push_back(make_pair(i,j));
}
}
for (int i=0;i<=n*m;i++){
if (!Color[i].empty()){
sort(Color[i].begin(),Color[i].end());
}
}
}
vector<int> yIndex[MAX];
int bottom[MAX];
long long calc(int col){
// cout<<"Calcing "<<col<<endl;
memset(bottom,0,sizeof(bottom));
for (int i = 1;i<=n;i++){
yIndex[i].clear();
}
long long ans = 0;
for (auto now:Color[col]){
int ni = now.first,nj = now.second;
// cout<<"Looping "<<ni<<","<<nj<<endl;
for (int i = 1;i<=m;i++){
yIndex[i].clear();
}
for (int i = 1;i<=m;i++){
if (bottom[i]){
yIndex[bottom[i]].push_back(i);
}
}
int yl=1,yr=m;
bool br = false;
for (int ii = ni;ii>=1;ii--){
// cout<<"Findind "<<ii<<endl;
for (vector<int>::iterator it = yIndex[ii].begin();it!=yIndex[ii].end();it++){
int yy = *it;
if (yy<nj){
yl = max(yl,yy+1);
}else if (yy>nj){
yr = min (yr,yy-1);
}else{
br = true;
break;
}
}
if (br){
break;
// cout<<"Finding Break"<<endl;
}
ans+=(n-ni+1)*(nj-yl+1)*(yr-nj+1);
// cout<<"Finding End:With"<<(n-ii+1)*(nj-yl+1)*(yr-nj+1)<<endl;
}
bottom[nj] = ni;
}
return ans;
}
double work(){
long long ans = 0;
for (int i = 0;i<=n*m;i++){
if (!Color[i].empty()){
ans +=calc(i);
}
}
long long num = n*(n+1)*m*(m+1)/4;//多谢UFO___给我提出的改进建议
// for (int i=1;i<=n;i++){
// for (int j = 1;j<=m;j++){
// num+=i*j;
// }
// }
double anss = ((double)ans)/num;
// cout<<"ANS:"<<ans<<" NUM:"<<num<<" Return:"<<anss<<endl;
return anss;
}
int main(){
int Cas;
scanf("%d",&Cas);
while (Cas--){
input();
printf("%.9f\n",work());
}
return 0;
}