本来毫无思路,看到网上设了二维数组表示期望后,觉得很妙。
设f[i][j]为聪聪在i,可可在j的聪聪吃到可可的时间期望。
dis[i][j]==1 ||==2 时,f[i][j]=1 ,i==j 时,f[i][j]=0;
由于聪聪能走2步,可可每次只能走一步,所以dis[i][j]一定是越来越小的。
我们可以先处理出g[i][j]表示此时聪聪应该往哪走。
最后就按照期望的定义求就行了。dis[i][j]大的由dis[u][v]小的得到,记忆化一下就行了。
如果有需要的话,可以求出整张图中任意点对的f[i][j]值,我之前也准备这么写,后来发现可以但没必要。
所以vector<path>a本来是要先求出近的点对的值,再求远的点对的值,就这样DP,后来就没用了。。。
#include<bits/stdc++.h>
#define maxl 1010
using namespace std;
int n,m,cnt,cat,mouse;;
double ans;
int ehead[maxl],du[maxl];
int dis[maxl][maxl],g[maxl][maxl];
struct path
{
int u,v;
};
vector <path> a[maxl];
struct ehead
{
int to,nxt;
}e[maxl*2];
double f[maxl][maxl];
struct node
{
int l,ind;
bool operator >(const node &b)const
{
return l>b.l;
}
};
priority_queue<node,vector<node>,greater<node> > q;
bool in[maxl];
inline void add(int u,int v)
{
e[++cnt].to=v;e[cnt].nxt=ehead[u];ehead[u]=cnt;
}
inline void prework()
{
scanf("%d%d",&n,&m);
scanf("%d%d",&cat,&mouse);
int u,v;
for(int i=1;i<=m;i++)
{
scanf("%d%d",&u,&v);
du[u]++;du[v]++;
add(u,v);add(v,u);
}
}
inline void dijstra(int st)
{
while(!q.empty())
q.pop();
memset(in,false,sizeof(in));
in[st]=true;
dis[st][st]=0;
q.push(node{0,st});
int u,v;node d;
while(!q.empty())
{
do
{
d=q.top();q.pop();
}while(dis[st][d.ind]!=d.l);
u=d.ind;in[u]=true;
for(int i=ehead[u];i;i=e[i].nxt)
{
v=e[i].to;
if(!in[v] && dis[st][v]>dis[st][u]+1)
{
dis[st][v]=dis[st][u]+1;
q.push(node{dis[st][v],v});
}
}
}
for(int v=1;v<=n;v++)
{
if(dis[st][v]>n)
continue;
if(v==st)
f[st][v]=0;
if(dis[st][v]==1 || dis[st][v]==2)
f[st][v]=1;
if(dis[st][v]>2)
a[dis[st][v]].push_back(path{st,v});
}
}
inline double solve(int cat,int mouse)
{
if(f[cat][mouse]<n*n*n)
return f[cat][mouse];
if(cat==mouse)
return 0.0;
if(dis[cat][mouse]<=2)
return 1.0;
double ret=1,p=1.0/(du[mouse]+1);
int v=mouse,u=g[g[cat][v]][v];
ret+=p*solve(u,v);
for(int i=ehead[mouse];i;i=e[i].nxt)
{
v=e[i].to;
ret+=p*solve(u,v);
}
f[cat][mouse]=ret;
return ret;
}
inline void mainwork()
{
memset(dis,0x3f,sizeof(dis));
for(int i=1;i<=n;i++)
dijstra(i);
memset(g,0x3f,sizeof(g));
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
{
f[i][j]=n*n*n;
if(i!=j && dis[i][j]<=n)
{
for(int k=ehead[i];k;k=e[k].nxt)
if(dis[e[k].to][j]<dis[i][j] && e[k].to<g[i][j])
g[i][j]=e[k].to;
}
}
for(int i=1;i<=n;i++) g[i][i]=i;
ans=solve(cat,mouse);
}
inline void print()
{
printf("%.3f",ans);
}
int main()
{
prework();
mainwork();
print();
return 0;
}