题意:
树上随机游走,每个节点有权值vi,若存在某个房间未经过2次就走过去,否则停下,最终从S开始走权值的期望
n 1e6
样例
Input
2
5 8
1 2
1
Output
8
发现走过的路形态不多
前两种走完之后蓝点就被堵死了,第三四就停下了
思路就很简单了
设f[x]表示走到x的子树不能回头了,最后的期望
g[x]表示往上到x,并且下一步要往上走或去其他子树或者死在上面,最后的期望
转移貌似有点多
#include<cstring>
#include<cstdio>
#define N 1000100
#define mo 1000000007
#define ll long long
using namespace std;
void read(int &n){
n=0;int c;for(;(c=getchar())>57||c<48;);
for(;c>47&&c<58;c=getchar())n=n*10+c-48;
}
int n,ov[N],bl[N],s,que[N],h,t,p[N],fir[N],nex[N+N],to[N+N],out[N],d[N],top;
int fa[N],t1[N],ny[N],pr[N],ipr[N],pr_[N],ipr_[N],ans,v[N];
#define lnk(x,y) to[++top]=y,nex[top]=fir[x],fir[x]=top,++d[x]
ll qpow(ll a,ll i){
ll r=1;for(;i;i>>=1,a=a*a%mo)if(i&1)r=r*a%mo;return r;
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;++i)read(v[i]);
for(int i=1;i<n;++i){
int x,y;read(x);read(y);
lnk(x,y);lnk(y,x);
}scanf("%d",&s);
ny[0]=1;for(int i=1;i<=n;++i)ny[i]=1ll*ny[i-1]*i%mo;
int NY=qpow(ny[n],mo-2);
for(int i=n;i;--i){
ny[i]=1ll*NY*ny[i-1]%mo;NY=1ll*NY*i%mo;
}
pr[0]=ipr[0]=pr_[0]=ipr_[0]=1;
for(que[t=1]=s;h^t;)
for(int x=que[++h],i=fir[x],y;i;i=nex[i])
if((y=to[i])^fa[x])fa[que[++t]=y]=x,++out[x];
for(int i=1;i<=n;++i){
int x=que[i];
pr[x]=1ll*pr[fa[x]]*d[x]%mo;ipr[x]=1ll*ipr[fa[x]]*ny[d[x]]%mo;
pr_[x]=1ll*pr_[fa[x]]*(d[x]-1)%mo;ipr_[x]=1ll*ipr_[fa[x]]*ny[d[x]-1]%mo;
if(!pr_[x])pr_[x]=1;
}
for(int i=n;i;--i){
int x=que[i];t1[x]=ipr[x];
bl[fa[x]]=(bl[fa[x]]+1ll*ipr[x]*ipr_[fa[fa[x]]]%mo*ny[d[fa[x]]])%mo;
for(int o=fir[x];o;o=nex[o])if(to[o]^fa[x])bl[x]=(bl[x]+bl[to[o]])%mo;
}p[s]=1;
for(int i=1;i<=n;++i){
int x=que[i],sm=0,sm2=0;ov[x]=(ov[fa[x]]+1ll*p[x]*pr[x]%mo*ny[out[x]])%mo;
for(int o=fir[x];o;o=nex[o])if(to[o]^fa[x])sm=(sm+bl[to[o]])%mo,sm2=(sm2+t1[to[o]])%mo;
for(int y,o=fir[x];o;o=nex[o])if((y=to[o])^fa[x]){
p[y]=(p[y]+(1ll*ov[fa[x]]*ny[d[x]-1]+1ll*p[x]*ny[out[x]]%mo*pr[x]%mo*ny[out[x]-1])%mo*pr_[x]%mo*(sm+mo-bl[y]))%mo;
p[y]=(p[y]+(1ll*ov[fa[x]]*ny[d[x]]+1ll*p[x]*ny[out[x]]%mo*pr[x]%mo*ny[out[x]])%mo*(sm2+mo-t1[y]))%mo;
p[y]=(p[y]+1ll*ov[fa[fa[x]]]*ipr[x]%mo*ny[d[x]-1]%mo*ny[d[fa[x]]])%mo;
p[y]=(p[y]+1ll*p[fa[x]]*ny[out[fa[x]]]%mo*ny[out[fa[x]]]%mo*ny[d[x]-1]%mo*ny[d[x]])%mo;
}
}
for(int i=1;i<=n;++i)if(!out[i]){
int x=fa[i];
ans=(ans+1ll*p[i]*v[i]+1ll*ov[fa[x]]*ipr[i]%mo*ny[d[x]]%mo*v[i]+1ll*p[x]*ny[out[x]]%mo*ny[out[x]]%mo*ny[d[i]]%mo*v[i])%mo;
}
for(int i=1;i<=n;++i)if(out[i]==1){
for(int o=fir[i],y;o;o=nex[o])if((y=to[o])^fa[i])
ans=(ans+1ll*p[i]*bl[y]%mo*pr[i]%mo*pr_[i]%mo*ny[out[i]]%mo*v[i])%mo;
}printf("%d",ans);
}