Problem
Input
Output
Hint
Solution
乍一看,外加无穷级数的期望裸题,估计分分钟切掉。
细一想,好像有什么不对——我推不出到底怎么列方程。我试了好几个方法,但在手玩了几个数据之后,发现都是错的。
看了看那惜字如金的题解,绞尽脑汁地理解了一番,终于懂了。
首先,我们可以设点(i,j)到点(n,m)的期望距离为
Xi,j
X
i
,
j
,那么便可以列出n*m个方程。具体来说,若有一个2*2的矩阵如下图所示:
我们则可以列出如下的方程:
1.
a=12(b+1)+12(c+1)
a
=
1
2
(
b
+
1
)
+
1
2
(
c
+
1
)
2.
b=12(a+1)+12(d+1)
b
=
1
2
(
a
+
1
)
+
1
2
(
d
+
1
)
3.
c=12(a+1)+12(d+1)
c
=
1
2
(
a
+
1
)
+
1
2
(
d
+
1
)
4.
d=0
d
=
0
对于方程1而言,你在点a的时候,有一半的概率去点b,又因为还需要走1步才能到点b,所以导致了这个式子出现在其中:
12(b+1)
1
2
(
b
+
1
)
。
于是上高斯消元。
至于高斯消元,我们可以设一个行指针p,指向我们目前还未搞过的方程的首个方程,然后再枚举一个列指针i,指向第i个项。我们先从p到n枚举一个方程q,如果遇到某个方程q的第i个项的系数有值则break。
如果有这样的方程q,我们就把它与p方程换个位,然后从p+1到n枚举一个方程j,如果遇到某个方程j的第i个项的系数有值,那么就让它乘上某个数,再减去此时的p方程,把方程j的第i个项的系数消灭。最后再p++。
如果没有这样的方程q,说明第i个元是自由元,可以为任意值,我们就随便把它赋个值。注意,此时我们并不p++,而是继续在p方程中搞事情。
这样的话我们就可以造出一个系数呈倒三角的矩阵。
由于我们扫了一遍所有的列,所以此时最后一个有系数的方程定然只有一个项有系数,而上面的方程则至多比它多一项有系数。那么我们就可以回代求出所有元的值。
于是我们就可以用
O(n3m3)
O
(
n
3
m
3
)
的时间计算出
X1,1
X
1
,
1
。
这显然会T,但是TJ说:“考虑到这个方程比较稀疏,因此消元可以很快。事实上可以证明是
n2m2
n
2
m
2
级别的。”但是我不相信它的屁话。反正可以更快。
令第一行第k个数走到(n,m)的期望步数为
Xk
X
k
,那么对于每一个位置,我们都可以根据它上面位置的方程解出它关于
Xi
X
i
的多项式。若有一个3*3的矩阵如下图所示:
则我们可以逐行求出每个位置用
Xi
X
i
来表示是怎样的。假设我们求完了第二行,在求点g,我们可以根据其上面的点d的方程来做:
d=13(a+1)+13(e+1)+13(g+1)
d
=
1
3
(
a
+
1
)
+
1
3
(
e
+
1
)
+
1
3
(
g
+
1
)
。将其化简得:
g=3∗d−a−e−3
g
=
3
∗
d
−
a
−
e
−
3
。由于我们已经求得了这个式子里面除了点g以外所有点用
Xi
X
i
的表示方法,我们就可以直接乘、加、减。
最后,我们满足了所有x坐标<n的点的方程,但是没有满足最后一行的点的方程。那么我们就列出最后一行的点的方程(当然,用
Xi
X
i
来表示),然后上高斯消元。
你问我为什么能跑过?我也不知道。
时间复杂度:
O(m3)
O
(
m
3
)
。
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 11
#define M 1001
#define db double
#define Lf long db
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
const int v[4][2]={{0,0},{-1,0},{0,1},{0,-1}};
int T,t,i,j,k,l,n,m,x,y,p,q;
db f[N][M][M];
Lf tmp,ma[M][M],a[M];
bool bz;
db xs(int x,int y)
{
return (n*m>1)+(min(n,m)>1)+(x>1&&x<n)+(y>1&&y<m);
}
int main()
{
freopen("walk.in","r",stdin);
freopen("walk.out","w",stdout);
scanf("%d",&T);
fo(t,1,T)
{
scanf("%d%d",&n,&m);
memset(f,0,sizeof f);
fo(i,1,m)f[1][i][i]=1;
fo(i,1,n-1)
fo(j,1,m)
{
f[i+1][j][0]=-xs(i,j);
fo(k,0,3)
{
x=i+v[k][0];
y=j+v[k][1];
if(!x||!y||y>m)continue;
tmp=k?-1:xs(i,j);
f[i+1][j][0]+=f[x][y][0]*tmp;
fo(l,1,m)f[i+1][j][l]+=f[x][y][l]*tmp;
}
}
memset(ma,0,sizeof ma);
fo(i,1,m-1)
{
ma[i][0]=1;
fo(k,0,3)
{
x=n+v[k][0];
y=i+v[k][1];
if(!x||!y||y>m)continue;
tmp=k?-1/xs(n,i):1;
ma[i][0]-=f[x][y][0]*tmp;
fo(j,1,m)ma[i][j]+=f[x][y][j]*tmp;
}
}
fo(i,0,m)ma[m][i]=f[n][m][i];
ma[m][0]=-ma[m][0];
p=1;
memset(a,0,sizeof a);
fo(i,1,m)
{
fo(q,p,m)
if(ma[q][i])
break;
if(q<=m)
{
fo(j,1,m)swap(ma[p][j],ma[q][j]);
fo(j,p+1,m)
if(ma[j][i])
{
tmp=ma[p][i]/ma[j][i];
fo(l,0,m)ma[j][l]=ma[j][l]*tmp-ma[p][l];
}
p++;
continue;
}
a[i]=0;
fo(j,1,p-1)ma[j][i]=0;
}
bz=0;
fd(i,m,1)
{
fo(p,1,m)
if(ma[i][p])
{
bz=1;
break;
}
if(bz)break;
}
p=m;
fd(i,i,1)
{
while(!ma[i][p])p--;
a[p]=ma[i][0]/ma[i][p];
fo(j,1,i-1)
if(ma[j][p])
{
ma[j][0]-=a[p]*ma[j][p];
ma[j][p]=0;
}
}
printf("%.0Lf\n",a[1]);
}
}