题目描述
梦游中的你来到了一棵 N 个节点的树上. 你一共做了 Q 个梦, 每个梦需要你从点 u 走到 点 v 之后才能苏醒, 由于你正在梦游, 所以每到一个节点后,你会在它连出去的边中等概率地 选择一条走过去, 为了确保第二天能够准时到校, 你要求出每个梦期望经过多少条边才能苏 醒. 为了避免精度误差, 你要输出答案模10^9 + 7的结果.
题目分析
这是经典的期望问题之“无限方案”问题。
这种题如果直接暴力DP会出错(因为有后效性),所以我们要直接玄学地解方程。
我们设f[i]表示从第i个点走向它父亲的期望值,g[i]表示从第i个点的父亲走向它的期望值。
对于f[i],我们知道其状态转移方程是:
f[x]=1c[x]+∑y∈son[x]1+f[x]+f[y]c[x]
f
[
x
]
=
1
c
[
x
]
+
∑
y
∈
s
o
n
[
x
]
1
+
f
[
x
]
+
f
[
y
]
c
[
x
]
其中c[x]表示点x连出去的边的条数
哎?左右两边都有f[x]?什么情况?不怕,我们解方程!
下面是简化过程:
f[x]=1+∑y∈son[x]f[x]+f[y]c[x] f [ x ] = 1 + ∑ y ∈ s o n [ x ] f [ x ] + f [ y ] c [ x ]
f[x]=1+c[x]−1c[x]∗f[x]+∑y∈son[x]f[y]c[x] f [ x ] = 1 + c [ x ] − 1 c [ x ] ∗ f [ x ] + ∑ y ∈ s o n [ x ] f [ y ] c [ x ]
f[x]c[x]=1+∑y∈son[x]f[y]c[x] f [ x ] c [ x ] = 1 + ∑ y ∈ s o n [ x ] f [ y ] c [ x ]
f[x]=c[x]+∑y∈son[x]f[y] f [ x ] = c [ x ] + ∑ y ∈ s o n [ x ] f [ y ]
然后,随便解都可以解得出来。
我们可以发现,对于叶子节点,f[x]=1,这就是初始化。
对于g[i],状态转移方程为:
g[x]=1c[fa[x]]+1+g[x]+g[fa[x]]c[fa[x]]+∑y∈son[fa[x]]∧y≠x1+g[x]+f[y]c[fa[x]]
g
[
x
]
=
1
c
[
f
a
[
x
]
]
+
1
+
g
[
x
]
+
g
[
f
a
[
x
]
]
c
[
f
a
[
x
]
]
+
∑
y
∈
s
o
n
[
f
a
[
x
]
]
∧
y
≠
x
1
+
g
[
x
]
+
f
[
y
]
c
[
f
a
[
x
]
]
开始简化:
g[x]=1+g[x]+g[fa[x]]c[fa[x]]+∑y∈son[fa[x]]∧y≠xg[x]+f[y]c[fa[x]] g [ x ] = 1 + g [ x ] + g [ f a [ x ] ] c [ f a [ x ] ] + ∑ y ∈ s o n [ f a [ x ] ] ∧ y ≠ x g [ x ] + f [ y ] c [ f a [ x ] ]
g[x]=1+c[fa[x]]−1c[fa[x]]∗f[x]+g[fa[x]]c[fa[x]]+∑y∈son[fa[x]]∧y≠xf[y]c[fa[x]] g [ x ] = 1 + c [ f a [ x ] ] − 1 c [ f a [ x ] ] ∗ f [ x ] + g [ f a [ x ] ] c [ f a [ x ] ] + ∑ y ∈ s o n [ f a [ x ] ] ∧ y ≠ x f [ y ] c [ f a [ x ] ]
g[x]c[fa[x]]=1+g[fa[x]]c[fa[x]]+∑y∈son[fa[x]]∧y≠xf[y]c[fa[x]] g [ x ] c [ f a [ x ] ] = 1 + g [ f a [ x ] ] c [ f a [ x ] ] + ∑ y ∈ s o n [ f a [ x ] ] ∧ y ≠ x f [ y ] c [ f a [ x ] ]
g[x]=c[fa[x]]+g[fa[x]]+∑y∈son[fa[x]]∧y≠xf[y] g [ x ] = c [ f a [ x ] ] + g [ f a [ x ] ] + ∑ y ∈ s o n [ f a [ x ] ] ∧ y ≠ x f [ y ]
这个也是随便搞搞即可
初始化:g[1]=0(显然)
最后,对于q个询问,我们倍增处理,最后答案就出来了。
代码
#include<cstdio>
#include<cstring>
using namespace std;
const int max=17;
const int mod=1000000007;
struct node{
int x,y,next;
}a[210000];int len,last[110000];
int f[110000],g[110000],c[110000],fa[110000][19];
bool bk[110000];int deep[110000],s[110000];
int fs[110000][19],gs[110000][19];
void ins(int x,int y){
a[++len].x=x;a[len].y=y;
a[len].next=last[x];last[x]=len;
}
void dfs(int x){
bk[x]=false;bool son=true;
for(int k=last[x];k;k=a[k].next){
int y=a[k].y;
if(bk[y]) son=false,fa[y][0]=x,dfs(y),s[x]=(s[x]+f[y])%mod;
}
if(son) f[x]=1;
else f[x]=(c[x]+s[x])%mod;
}
void dfs2(int x){
for(int i=1;i<=max;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for(int k=last[x];k;k=a[k].next){
int y=a[k].y;
if(fa[y][0]==x){
g[y]=(c[x]+g[x]+s[x]-f[y]+mod)%mod;
deep[y]=deep[x]+1,dfs2(y);
}
}
}
void dfs3(int x){
fs[x][0]=f[x];for(int i=1;i<=max;i++) fs[x][i]=(fs[x][i-1]+fs[fa[x][i-1]][i-1])%mod;
gs[x][0]=g[x];for(int i=1;i<=max;i++) gs[x][i]=(gs[x][i-1]+gs[fa[x][i-1]][i-1])%mod;
for(int k=last[x];k;k=a[k].next){
int y=a[k].y;
if(fa[y][0]==x) dfs3(y);
}
}
int lca(int x,int y){
int ans=0;
for(int i=max;i>=0;i--){
if(deep[fa[x][i]]>=deep[y]) ans=(ans+fs[x][i])%mod,x=fa[x][i];
if(deep[fa[y][i]]>=deep[x]) ans=(ans+gs[y][i])%mod,y=fa[y][i];
}
for(int i=max;i>=0;i--){
if(fa[x][i]!=fa[y][i]){
ans=(ans+fs[x][i])%mod;x=fa[x][i];
ans=(ans+gs[y][i])%mod;y=fa[y][i];
}
}
if(x!=y){
ans=(ans+fs[x][0])%mod;x=fa[x][0];
ans=(ans+gs[y][0])%mod;y=fa[y][0];
}
return ans;
}
int main()
{
int n,q;scanf("%d%d",&n,&q);
len=0;memset(last,0,sizeof(last));
for(int i=1;i<=n-1;i++){
int x,y;scanf("%d%d",&x,&y);
ins(x,y);ins(y,x);
}
memset(c,0,sizeof(c));
for(int i=1;i<=n;i++)
for(int k=last[i];k;k=a[k].next) c[i]++;
memset(s,0,sizeof(s));
memset(bk,true,sizeof(bk));dfs(1);
g[1]=0;deep[1]=1;deep[0]=0;dfs2(1);
dfs3(1);
for(int i=1;i<=q;i++){
int st,ed;scanf("%d%d",&st,&ed);
printf("%d\n",lca(st,ed));
}
return 0;
}