Description
给出一棵 n n 个节点的无向树,可以任意固定树根来整棵树得到这棵树的所有 dfs d f s 序,给出一排列 B B ,问这棵树字典序小于的 dfs d f s 序个数
Input
第一行一整数 T T 表示用例组数,每组用例首先输入一整数表示点数,之后输入 n n 个整数表示排列,最后 n−1 n − 1 行输入树边
(∑n≤106) ( ∑ n ≤ 10 6 )
Output
输出字典序小于 B B 的个数,结果模 109+7 10 9 + 7
Sample Input
2
5
2 1 3 5 4
1 2
2 3
2 4
4 5
Sample Output
3
9
Solution
假设树根为 u u ,那么从开始的 dfs d f s 序方案数为 du(u)∏v≠u(du(v)−1) d u ( u ) ∏ v ≠ u ( d u ( v ) − 1 ) ,首先统计根节点小于 b1 b 1 的 dfs d f s 序个数,之后以 b1 b 1 为根,依次确定前 i i 位与序列相同,在第 i+1 i + 1 位小于 bi+1 b i + 1 的 dfs d f s 序个数计数即可,每次从一个节点走到一个确定的儿子节点对方案数的影响是除以当前节点的儿子数量,计数过程中会需要统计一个节点的儿子节点中编号小于 b b 序列某元素的值,且每次选择一个儿子往下走的时候需要删掉这个儿子节点,故需要用维护一下每个点的儿子集合
Code
#include<cstdio>
#include<ctime>
#include<algorithm>
using namespace std;
typedef long long ll;
const int INF=0x3f3f3f3f,maxn=1000006;
#define mod 1000000007
template<class Tdata>
class Treap {
public:
#define Treap_size 2000010 /* 记得改这个东西 */
#define NotFound (Tdata)-1
struct node
{
int l,r,num,son,rd;Tdata val;
}t[Treap_size];
int treesz,root[maxn];
void update(int k){t[k].son=t[t[k].l].son+t[t[k].r].son+t[k].num;}
void right_rotate(int &k){int p=t[k].l;t[k].l=t[p].r;t[p].r=k;update(k);update(p);k=p;}
void left_rotate(int &k){int p=t[k].r;t[k].r=t[p].l;t[p].l=k;update(k);update(p);k=p;}
void Insert(int &k,Tdata x)
{
if(!k){k=++treesz;t[k].val=x;t[k].rd=rand();t[k].son=t[k].num=1;t[k].l=t[k].r=0;return ;}
t[k].son++;
if(x==t[k].val){t[k].num++;return ;}
if(x<t[k].val){Insert(t[k].l,x);if(t[t[k].l].rd<t[k].rd)right_rotate(k);}
if(x>t[k].val){Insert(t[k].r,x);if(t[t[k].r].rd<t[k].rd)left_rotate(k);}
}
bool Delete(int &k,Tdata x)
{
if(!k)return 0;
if(x==t[k].val)
{
if(t[k].num>1){t[k].num--;t[k].son--;return 1;}
if(!t[k].l||!t[k].r){k=t[k].l+t[k].r;return 1;}
if(t[t[k].l].rd<t[t[k].r].rd){right_rotate(k);return Delete(k,x);}
else {left_rotate(k);return Delete(k,x);}
}
bool res;
if(x<t[k].val){res=Delete(t[k].l,x);if(res)t[k].son--;return res;}
else {res=Delete(t[k].r,x);if(res)t[k].son--;return res;}
}
Tdata get_pre(int k,Tdata x)//前驱
{
if(!k)return NotFound;
int res;
if(x>t[k].val){res=get_pre(t[k].r,x);return res==NotFound?t[k].val:res;}
return get_pre(t[k].l,x);
}
Tdata get_last(int k,Tdata x)//后继
{
if(!k)return NotFound;
int res;
if(x<t[k].val){res=get_last(t[k].l,x);return res==NotFound?t[k].val:res;}
return get_last(t[k].r,x);
}
Tdata ask_kth_small(int k,int x)//第x小
{
if(!k)return NotFound;
if(x<=t[t[k].l].son)return ask_kth_small(t[k].l,x);
if(x>t[t[k].l].son+t[k].num)return ask_kth_small(t[k].r,x-(t[t[k].l].son+t[k].num));
return t[k].val;
}
int ask_rank(int k,Tdata x)//x的排名
{
if(!k)return -1;
if(x==t[k].val)return t[t[k].l].son+1;
if(x<t[k].val)return ask_rank(t[k].l,x);
return t[t[k].l].son+t[k].num+ask_rank(t[k].r,x);
}
int FIND(Tdata x,int k)
{
if(!k)return 0;
if(x==t[k].val)return 1;
if(x<t[k].val)return FIND(x,t[k].l);
if(x>t[k].val)return FIND(x,t[k].r);
}
Treap(){treesz=0;}
void clear(int n)
{
treesz=0;
for(int i=1;i<=n;i++)root[i]=0;
}
void insert(int pos,Tdata val){if(FIND(val,root[pos])==0)Insert(root[pos],val);}//无if multiset
int erase(int pos,Tdata val){return Delete(root[pos],val);}
int size(int pos){return t[root[pos]].son;}
int empty(int pos){return t[root[pos]].son==0;}
int find(int pos,int x){return FIND(x,root[pos]);}//找不到是0,找到是1
Tdata begin(int pos){return ask_kth_small(root[pos],1);}
Tdata end(){return NotFound;}
Tdata END(int pos){return ask_kth_small(root[pos],t[root[pos]].son);}
Tdata next(int pos,Tdata x){return get_last(root[pos],x);}
Tdata pre(int pos,Tdata x){return get_pre(root[pos],x);}
int lower_bound(int pos,Tdata x){return get_last(root[pos],x-1);} // >=
int upper_bound(int pos,Tdata x){return get_last(root[pos],x);} // >
int distance(int pos,Tdata x){int res=get_pre(root[pos],x);if(res==NotFound)return 0;return ask_rank(root[pos],res);}
#undef Treap_size
#undef NotFound
};
Treap<int>S;
int mul(int x,int y)
{
ll z=1ll*x*y;
return z-z/mod*mod;
}
int add(int x,int y)
{
x+=y;
if(x>=mod)x-=mod;
return x;
}
int inv[maxn],fact[maxn];
void init(int n=1e6)
{
fact[0]=1;
for(int i=1;i<=n;i++)fact[i]=mul(i,fact[i-1]);
inv[1]=1;
for(int i=2;i<=n;i++)inv[i]=mul(mod-mod/i,inv[mod%i]);
}
int T,n,b[maxn],fa[maxn];
struct node
{
int to,next;
}g[2*maxn];
int tot,head[maxn];
void add_edge(int u,int v)
{
g[tot].to=v,g[tot].next=head[u],head[u]=tot++;
}
void dfs(int u,int f)
{
fa[u]=f;
for(int i=head[u];~i;i=g[i].next)
{
int v=g[i].to;
if(v==f)continue;
S.insert(u,v);
dfs(v,u);
}
}
int res,ans,pos,flag;
void DFS(int u)
{
if(!flag)return ;
if(pos==n)return ;
if(!S.size(u))
{
if(fa[u])DFS(fa[u]);
return ;
}
int num=S.distance(u,b[pos]);
res=mul(res,inv[S.size(u)]);
ans=add(ans,mul(num,res));
if(!S.find(u,b[pos]))
{
flag=0;
return ;
}
int v=b[pos];
pos++;
S.erase(u,v);
DFS(v);
}
int main()
{
srand(time(0));
init();
scanf("%d",&T);
while(T--)
{
scanf("%d",&n);
for(int i=1;i<=n;i++)scanf("%d",&b[i]);
S.clear(n);
tot=0;
for(int i=1;i<=n;i++)head[i]=-1;
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add_edge(u,v);add_edge(v,u);
}
dfs(b[1],0);
res=1;
for(int i=1;i<=n;i++)
if(i!=b[1])res=mul(res,fact[S.size(i)]);
else res=mul(res,fact[S.size(i)-1]);
ans=0;
for(int i=1;i<b[1];i++)ans=add(ans,mul(res,S.size(i)+1));
res=mul(res,S.size(b[1]));
pos=2;
flag=1;
DFS(b[1]);
printf("%d\n",ans);
}
return 0;
}