Task
给n个节点中有K个入口的树,求有多少个1~n的排列顺序,顺序遍历每个节点时,都存在都入口到该节点的路径,满足路径上不包含之前遍历过的点。答案mod 1e9+7。
Solution
1. K=1。
在树上求方案数,容易想到是树形dp。
以入口S作为根,发现遍历节点x之前,x的子树必须都已经被遍历过了,dp[ i ]表示,以i为根的子树都已经被遍历的方案数。X子树对应的序列为儿子们t1,t2~tn的序列的组合。保持各小序列相对位置不变,穿插组合成一个大序列,再在结尾加上x的序列。用排列组合(乘法逆元和快速幂)来实现。
- K=2序列上的问题
a) 是一条链,且S1,S2在链的两端。
定义dp[ l ][ r ]为[ l , r ]这个区间都被遍历的方案数,如果端点L,R都被遍历了,那么L,R里面的点也被遍历了。
最后一个点取的只可能是l或r.
dp[ l ][ r ]=dp[ l+1 ][ r ]+dp[ l ][ r-1 ]
特别的,当l=r时dp[ l ][ r ]=1
b) 是一棵树。
以S1,S2为两端,把树拉成一条链,链上的每一个点代表一棵树。
对于链上的每一个点,先用K=1的树形dp预处理出这棵树的方案数,dp[i].
最后一个取的只可能是L上的根或者R上的根,记为rt,但是rt下的节点,却可以与之前的序列任意组合。
DP[ l ][ R ]=C( sum( L, R )-1,sum( L,L )-1 )*DP[ L+1 ][ R ]*dp[ id [ L ] ]+
C ( sum(L,R )- 1,sum( R,R )- 1 )*DP[ L+1 ][ R ] *dp[ id [R] ]
const int M=1e5+5,N=2e3+3,P=1e9+7;
int n,m,ecnt,S1,S2;
int head[M],f[M];
struct edge{
int t,nxt;
}e[M<<1];
inline void addedge(int f,int t){
e[++ecnt]=(edge){t,head[f]};
head[f]=ecnt;
}
inline void input(){
int i,j,k,a,b;
rd(n);rd(m);
rd(S1);
if(m==2)rd(S2);
rep(i,1,n-1){
rd(a);rd(b);
addedge(a,b);
addedge(b,a);
}
}
inline void init(){
int i;
f[0]=1;
rep(i,1,n)f[i]=1ll*f[i-1]*i%P;
}
inline int fst_pow(int a,int p){
int ans=1;
while(p){
if(p&1)ans=1ll*ans*a%P;
a=1ll*a*a%P;
p>>=1;
}
return ans;
}
inline int C(int a,int b){return 1ll*f[a]*fst_pow(1ll*f[b]*f[a-b]%P,P-2)%P;}
struct P40{//树形dp[i],i子树内被填满 排序的方案数
int dp[M],sz[M];
inline void dfs(int f,int x){
dp[x]=1;
for(int i=head[x];i;i=e[i].nxt){
if(e[i].t==f)continue;
dfs(x,e[i].t);
sz[x]+=sz[e[i].t];
dp[x]=1ll*dp[x]*dp[e[i].t]%P*C(sz[x],sz[e[i].t])%P;
}
sz[x]++;
}
inline void solve(){
dfs(0,S1);
sc(dp[S1]);
}
}P40;
struct P60{
int dp[N],DP[N][N],sum[N],fa[N],id[N],sz[N];
int num;
bool mark[N];
inline void dfs1(int f,int x){
fa[x]=f;
for(int i=head[x];i;i=e[i].nxt)if(e[i].t!=f)dfs1(x,e[i].t);
}
inline void getarray(){
int i,x=S1;
while(x!=fa[S2]){
mark[x]=1;//标记序列上的点
id[++num]=x;//重新编号
x=fa[x];
}
}
inline void dfs2(int f,int x){
dp[x]=1;
for(int i=head[x];i;i=e[i].nxt){
if(e[i].t==f||mark[e[i].t])continue;
dfs2(x,e[i].t);
sz[x]+=sz[e[i].t];
dp[x]=1ll*dp[x]*dp[e[i].t]%P*C(sz[x],sz[e[i].t])%P;
}
sz[x]++;
}
inline void solve(){
int i,j,k,l,r;
dfs1(0,S2);
getarray();
rep(i,1,num){
dfs2(0,id[i]);//序列上的一个点就是一棵树
sum[i]=sum[i-1]+sz[id[i]];//前缀和
DP[i][i]=dp[id[i]];
}
rep(i,2,num)
rep(l,1,num-i+1){
r=l+i-1;
DP[l][r]=1ll*DP[l+1][r]*DP[l][l]%P*C(sum[r]-sum[l-1]-1,sum[l]-sum[l-1]-1)%P+1ll*DP[l][r-1]*DP[r][r]%P*C(sum[r]-sum[l-1]-1,sum[r]-sum[r-1]-1)%P;
if(DP[l][r]>=P)DP[l][r]-=P;
}
sc(DP[1][num]);
}
}P60;
int main(){
// freopen("0.in","r",stdin);
input();
init();
if(m==1)P40.solve();
else P60.solve();
return 0;
}