// HDU 6338
// 题解:
// 首先显然的 dp:$dp(u) = |son_u|! \prod_{v \in son_u} dp(v)$,
// 然后换根 dp,求出以任意节点为根的答案。
// 然后考虑根编号小于 $B_1$ 的,直接加入答案。对于根编号为 $B_1$ 的情况,
// 可以按位利用前面求得的 $dp$ 值求得答案。在递归求解的时候,需要离散化并使用树状数组维护当前儿子的编号。
// 注意细节!!
// 时间复杂度 $O(n \log n)$。
#include<bits/stdc++.h>
// Short-hand type and helper macros used throughout the solution.
#define LL long long
#define ull unsigned long long
#define ULL ull
#define mp make_pair
#define pii pair<int,int>
#define piii pair<int, pii >
// Was `pair <ll,ll>`: `ll` is never defined, so any use of pll failed to compile.
#define pll pair<LL,LL>
#define pb push_back
#define big 20160116
#define INF 2147483647
#define pq priority_queue
using namespace std;
// Reads one decimal integer from stdin.
// Any characters before the first digit are skipped; a '-' among them makes
// the result negative. Returns 0 if end of input is reached before a digit.
// `ch` is an int so that EOF stays distinguishable from every real character;
// the original `char` version spun forever once the input ran out.
inline int read(){
    int x=0,f=1;
    int ch=getchar();
    while (ch!=EOF&&(ch<'0'||ch>'9')){
        if (ch=='-') f=-1;
        ch=getchar();
    }
    while (ch>='0'&&ch<='9'){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    return x*f;
}
// Small modular-arithmetic / number-theory helpers.
namespace Mymath{
    // Returns x^p mod `mod`. Requires p >= 0 and 1 <= mod < 2^31 (so products fit in 63 bits).
    // x is reduced into [0, mod) first: the original squared an unreduced x,
    // which overflows for large bases and yields negative results for negative ones.
    long long qp(long long x,long long p,long long mod){
        x%=mod;
        if (x<0) x+=mod;
        long long ans=1%mod;   // 1 % mod keeps the result correct when mod == 1
        while (p>0){           // p > 0 (not p != 0): a negative p would never reach 0
            if (p&1) ans=ans*x%mod;
            x=x*x%mod;
            p>>=1;
        }
        return ans;
    }
    // Modular inverse via Fermat's little theorem.
    // Only valid when mod is prime and x is not a multiple of mod.
    long long inv(long long x,long long mod){
        return qp(x,mod-2,mod);
    }
    // Binomial coefficient N choose K mod `mod`, using a caller-supplied
    // factorial table fact[0..N]. Returns 0 when K is outside [0, N]
    // (the original indexed fact[] with a negative offset in that case).
    long long C(long long N,long long K,long long fact[],long long mod){
        if (K<0||K>N) return 0;
        return fact[N]*inv(fact[K],mod)%mod*inv(fact[N-K],mod)%mod;
    }
    // Greatest common divisor (Euclid's algorithm).
    template <typename Tp> Tp gcd(Tp A,Tp B){
        if (B==0) return A;
        return gcd(B,A%B);
    }
    // Least common multiple. Divides before multiplying so the intermediate
    // never exceeds the result; returns 0 if either argument is 0
    // (the original computed A*B first and divided by gcd(0,0) == 0).
    template <typename Tp> Tp lcm(Tp A,Tp B){
        if (A==0||B==0) return 0;
        return A/gcd(A,B)*B;
    }
}
// ---- Limits and modulus ----
const int Maxn = 1000005;       // array capacity; node ids are expected to be 1..n
const LL mod = 1000000007LL;    // answers are reported modulo this prime
using namespace Mymath;

// ---- Per-test input ----
int T, n;                       // number of test cases, node count of the current tree
int b[Maxn];                    // target sequence b[1..n]
vector<int> G[Maxn];            // adjacency lists of the tree

// ---- Tree data, filled by dfs()/dfs2() with the tree rooted at b[1] ----
int fa[Maxn], siz[Maxn];        // parent (-1 for the root) and subtree size
LL sn[Maxn];                    // number of children
LL dp[Maxn];                    // number of DFS orders of the subtree (mod)
LL Ans[Maxn];                   // number of DFS orders of the whole tree rooted at each node (mod)

// ---- Factorials and inverse factorials modulo mod, filled once in main() ----
LL fact[Maxn], ivf[Maxn];
void dfs(int x,int p){
dp[x]=1;
siz[x]=1;
fa[x]=p;
LL &son=sn[x];
for (int i=0;i<G[x].size();i++){
int v=G[x][i];
if (v!=p) dfs(v,x),dp[x]=dp[x]*dp[v]%mod,son++,siz[x]+=siz[v];
}
dp[x]=dp[x]*fact[son]%mod;
//cout<<x<<' '<<dp[x]<<endl;
}
void dfs2(int x,LL dpfa,int p){
if (p==-1){
Ans[x]=dp[x];
}
else{
Ans[x]=dp[x]*ivf[sn[x]]%mod*fact[sn[x]+1]%mod;
Ans[x]=Ans[x]*dpfa%mod;
}
//cout<<x<<' '<<Ans[x]<<endl;
LL p1=1;
for (int i=0;i<G[x].size();i++){
int v=G[x][i];
if (v==p) continue;
p1*=dp[v];
p1%=mod;
}
for (int i=0;i<G[x].size();i++){
int v=G[x][i];
if (v==p) continue;
LL Rm=p1*inv(dp[v],mod)%mod;
Rm=Rm*dpfa%mod;
//cout<<Rm<<endl;
if (p==-1)Rm=Rm*fact[sn[x]-1]%mod;
else Rm=Rm*fact[sn[x]]%mod;
dfs2(v,Rm,x);
}
}
// Fenwick tree (1-based) point update: adds val at index pos.
// Indices outside [1, bit.size()) are ignored. Index 0 must be rejected
// explicitly: its lowest set bit is 0, so the original loop never advanced.
void Add(std::vector<int>&bit,int pos,int val){
    if (pos<=0) return;
    for (;pos<(int)bit.size();pos+=pos&-pos){
        bit[pos]+=val;
    }
}
// Fenwick tree (1-based) prefix sum: total of all values added at indices 1..pos.
// pos <= 0 yields 0 (the original read bit[-1] for negative pos). pos beyond the
// last slot is clamped, so it returns the total instead of reading out of bounds.
int query(std::vector<int>&bit,int pos){
    if (pos>=(int)bit.size()) pos=(int)bit.size()-1;
    int ret=0;
    for (;pos>0;pos-=pos&-pos){
        ret+=bit[pos];
    }
    return ret;
}
// Compares DFS orders of the subtree of x (entered from parent p, in the tree
// rooted at b[1]) against b, assuming x itself sits at position lvl-1 of b and
// everything before it has matched. lvl is the first position after x.
//
// Returns {count, matched}:
//   count   - number (mod) of DFS orders of x's subtree that compare smaller than
//             b on the positions the subtree occupies;
//   matched - true when b's segment for x's subtree is itself a DFS order of the
//             subtree, so the caller continues comparing after it. false when b
//             deviates inside this subtree; no order equals b there and count is final.
// The `ans` parameter is only forwarded to recursive calls, never read or written.
// NOTE(review): recursion depth equals the depth of the matched path in b (up to n);
// a long path could overflow the call stack.
pair<LL,bool> dfs3(int x,int p,LL &ans,int lvl){
//cout<<x<<p<<lvl<<endl;
// A leaf has a single (empty) order of its descendants, which trivially matches.
if (sn[x]==0){
return mp(0,1);
}
// b's next element is not a child of x, so b deviates right here: an order is
// smaller iff its first child is < b[lvl]. Each choice of first child yields
// (sn[x]-1)! * prod dp[child] orders.
// NOTE(review): the first iteration of the loop below computes the same value,
// so this branch looks redundant.
if (fa[b[lvl]]!=x){
LL fuck=1;
LL sf=sn[x]-1;
LL ff=0;
for (int i=0;i<G[x].size();i++){
int v=G[x][i];
if (v==p) continue;
fuck=fuck*dp[v]%mod;
if (v<b[lvl]) ff++;
}
//cout<<fuck<<' '<<ff<<' '<<sf<<endl;
//ans+=fuck*ff%mod*fact[sf]%mod;
//ans%=mod;
return mp(fuck*ff%mod*fact[sf]%mod,0);
}
// dpm: product of dp over the children of x not yet placed.
LL dpm=1;
for (int i=0;i<G[x].size();i++){
int v=G[x][i];
if (v==p) continue;
dpm=dpm*dp[v]%mod;
}
// nowsons: number of children of x not yet placed.
LL nowsons=sn[x];
vector<int> bit;
int cnt=0;
// VV: children of x sorted by id. bit marks which of them are still unplaced;
// Fenwick index i+1 corresponds to VV[i].
vector<int> VV;
for (int i=0;i<G[x].size();i++){
int v=G[x][i];
if (v==p) continue;
VV.pb(v);
}
sort(VV.begin(),VV.end());
cnt=VV.size();
bit.resize(cnt+10,0);
for (int i=0;i<VV.size();i++){
Add(bit,i+1,1);
}
// tmp accumulates the count of smaller orders found so far.
LL tmp=0;
// Each iteration decides which child of x comes next, compared with b[lvl].
while (1){
// pos = number of children with id < b[lvl]; xx = how many of those are unplaced.
int pos=lower_bound(VV.begin(),VV.end(),b[lvl])-VV.begin();
LL xx=query(bit,pos);
// Placing any of those xx children next gives a smaller order; the remaining
// nowsons-1 children can follow in any order, each with its dp orders.
LL delt=xx*dpm%mod*fact[nowsons-1];
delt%=mod;
tmp+=delt;
tmp%=mod;
//cout<<b[lvl]<<' '<<xx<<' '<<dpm<<' '<<nowsons-1<<endl;
// b's next element is not a child of x: b deviates here, nothing more to count.
if (fa[b[lvl]]!=x){
return mp((tmp)%mod,0);
}
// b places child b[lvl] next. Count orders that do the same but are smaller
// inside b[lvl]'s subtree; the other unplaced children are then free.
pair<LL,bool> r=dfs3(b[lvl],x,ans,lvl+1);
if (r.second==0){
return mp((tmp+r.first*dpm%mod*inv(dp[b[lvl]],mod)%mod*fact[nowsons-1]%mod)%mod,0);
}
else{
tmp+=r.first*dpm%mod*inv(dp[b[lvl]],mod)%mod*fact[nowsons-1]%mod;
tmp%=mod;
}
//tmp+=delt;
//tmp%=mod;
// b[lvl]'s subtree matched b exactly: mark it placed and skip past its positions.
nowsons--;
dpm=dpm*inv(dp[b[lvl]],mod)%mod;
Add(bit,pos+1,-1);
lvl+=siz[b[lvl]];
// Every child matched, so b's segment is a DFS order of x's subtree.
if (!nowsons) return mp(tmp,1);
}
}
void mian(){
n=read();
for (int i=0;i<=n;i++){
G[i].clear();dp[i]=0;sn[i]=0;
}
for (int i=1;i<=n;i++) b[i]=read();
for (int i=0;i<n-1;i++){
int u=read(),v=read();
G[u].pb(v);G[v].pb(u);
}
int rt=b[1];
dfs(rt,-1);
dfs2(rt,1,-1);
LL ans=0;
for (int i=1;i<=n;i++){
if (i<b[1]){
ans+=Ans[i];
ans%=mod;
}
}
//cout<<ans<<endl;
LL t=dfs3(rt,-1,ans,2).first;
ans+=t;
ans%=mod;
printf("%I64d\n",ans);
}
// Precomputes factorials and inverse factorials modulo mod for every index
// below Maxn, then solves T test cases.
int main(){
    //freopen("input.txt","r",stdin);
    //freopen("output.txt","w",stdout);
    fact[0]=1;
    for (int i=1;i<Maxn;i++) fact[i]=fact[i-1]*i%mod;
    // One modular exponentiation, then 1/(i-1)! = i * 1/i! walking downwards.
    // This yields the same values as Maxn separate inv() calls at a fraction of the cost.
    ivf[Maxn-1]=inv(fact[Maxn-1],mod);
    for (int i=Maxn-1;i>0;i--) ivf[i-1]=ivf[i]*i%mod;
    T=read();
    while (T--){
        mian();
    }
    return 0;
}
/*
1
11
1 4 5 2 11 7 8 9 6 3 10
2 1
3 1
4 1
5 4
6 1
7 2
8 7
9 1
10 1
11 2
*/
/*
1
10
7 3 5 1 8 2 10 9 6 4
2 1
3 1
4 1
5 2
6 5
7 5
8 2
9 4
10 4
*/
/*
1
11
2 8 1 4 3 11 7 9 10 5 6
2 1
3 2
4 3
5 1
6 2
7 5
8 2
9 8
10 1
11 10
*/
/*72*/
/*
1
11
2 5 4 6 10 1 8 11 3 7 9
2 1
3 1
4 1
5 2
6 4
7 3
8 3
9 6
10 2
11 7
*/
/*36*/