E - Directed Tree
题目描述
给出一棵 n n n 个节点的树,求满足以下要求的 1 ∼ n 1\sim n 1∼n 的排列 a a a 的个数:
- 对于每个 i i i ,保证 a i a_i ai 不能是除 i i i 以外的任何一个 i i i 的祖先。
数据范围与提示
1 ≤ n ≤ 2000 1\le n\le 2000 1≤n≤2000 。
前言
考试时没做到这题。考后先自己想了半个小时,然后想了一个解法感觉没毛病,交上去居然过了。
思路
每个节点上的数字不能是该节点的祖先,转化过后就变成每个数字不能出现在对应节点的子孙节点上。
转化过后,如果我们把整棵树拍在DFS序上,那么每个值不能出现的位置是一段区间。这已经强烈暗示我们用容斥了。
设
S
(
i
)
S(i)
S(i) 表示至少有
i
i
i 个值出现的位置不合法的方案数(注意不是排列数),那么有
A
n
s
=
∑
i
=
0
n
(
−
1
)
i
⋅
S
(
i
)
⋅
(
n
−
i
)
!
Ans=\sum_{i=0}^n(-1)^i\cdot S(i)\cdot (n-i)!
Ans=i=0∑n(−1)i⋅S(i)⋅(n−i)!
由于父节点和儿子节点的不合法区间有包含关系,而兄弟节点的不合法区间不交,所以很容易想到用树形DP做:
设
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 表示节点
i
i
i 的子树中有
j
j
j 个点不合法的方案数,那么先考虑节点
i
i
i 合法的情况:
d
p
′
[
i
]
[
j
]
=
∑
k
1
,
k
2
,
k
3
,
.
.
.
k
1
+
k
2
+
k
3
+
.
.
.
=
j
d
p
[
s
o
n
1
]
[
k
1
]
∗
d
p
[
s
o
n
2
]
[
k
2
]
∗
d
p
[
s
o
n
3
]
[
k
3
]
∗
.
.
.
dp'[i][j]=\sum_{k_1,k_2,k_3,...}^{k_1+k_2+k_3+...=j}dp[son_1][k_1]*dp[son_2][k_2]*dp[son_3][k_3]*...
dp′[i][j]=k1,k2,k3,...∑k1+k2+k3+...=jdp[son1][k1]∗dp[son2][k2]∗dp[son3][k3]∗...
其中
s
o
n
j
son_j
sonj 表示节点
i
i
i 的第
j
j
j 个儿子。
然后再把节点
i
i
i 不合法的情况算进来:
d
p
[
i
]
[
j
]
=
d
p
′
[
i
]
[
j
]
+
d
p
′
[
i
]
[
j
−
1
]
∗
(
s
i
z
[
i
]
−
j
)
dp[i][j]=dp'[i][j]+dp'[i][j-1]*(siz[i]-j)
dp[i][j]=dp′[i][j]+dp′[i][j−1]∗(siz[i]−j)其中
s
i
z
[
i
]
siz[i]
siz[i] 表示节点
i
i
i 的子树的大小。注意上式中
s
i
z
[
i
]
−
j
siz[i]-j
siz[i]−j 其实是
(
s
i
z
[
i
]
−
1
)
−
(
j
−
1
)
(siz[i]-1)-(j-1)
(siz[i]−1)−(j−1) 化简的结果,因为
i
i
i 能填的不合法位置有
s
i
z
[
i
]
−
1
siz[i]-1
siz[i]−1 个,而当前已填了
j
−
1
j-1
j−1 个。
容易发现此时
S
(
i
)
=
d
p
[
1
]
[
i
]
S(i)=dp[1][i]
S(i)=dp[1][i] ,所以答案的计算式为
A
n
s
=
∑
i
=
0
n
(
−
1
)
i
⋅
d
p
[
1
]
[
i
]
⋅
(
n
−
i
)
!
Ans=\sum_{i=0}^n(-1)^i\cdot dp[1][i]\cdot (n-i)!
Ans=i=0∑n(−1)i⋅dp[1][i]⋅(n−i)!这里的树形DP无论是从计算式看还是从代码看都是
O
(
n
3
)
O(n^3)
O(n3) 的,但是真的会算这么多次吗?
我们从代码中分析:
for(uns i=0;i<G[x].size();i++){
int v=G[x][i];dfs(v);
siz[x]+=siz[v];
for(int j=siz[x]-1;j>=0;j--){
ll cg=0;
for(int k=0,lim=min(j,siz[v]);k<=lim;k++)
cg=(cg+dp[x][j-k]*dp[v][k]%MOD)%MOD;
dp[x][j]=cg;
}
}
这里的 v v v 是 x x x 的一个儿子节点,观察枚举的范围,相当于是枚举 v v v 的子树中的节点和当前已算过的子树中的节点进行两两配对的复杂度。容易发现,任意一对点最多只会被枚举一次,因为枚举一次后就会被合并在同一个子树下和其它子树配对。因此此过程均摊复杂度 O ( n 2 ) O(n^2) O(n2) 。
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<stack>
#include<ctime>
#include<map>
#define ll long long
#define MAXN 2005
#define uns unsigned
#define INF 0x7f7f7f7f
#define MOD 998244353ll
#define lowbit(x) ((x)&(-(x)))
using namespace std;
inline ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+s-'0',s=getchar();
return f?x:-x;
}
int n,m,fa[MAXN],siz[MAXN];
vector<int>G[MAXN];
ll dp[MAXN][MAXN],ans,fac[MAXN];
inline void dfs(int x){
siz[x]=1,dp[x][0]=1;
for(uns i=0;i<G[x].size();i++){
int v=G[x][i];dfs(v);
siz[x]+=siz[v];
for(int j=siz[x]-1;j>=0;j--){
ll cg=0;
for(int k=0,lim=min(j,siz[v]);k<=lim;k++)
cg=(cg+dp[x][j-k]*dp[v][k]%MOD)%MOD;
dp[x][j]=cg;
}
}
for(int j=siz[x]-1;j>0;j--)
dp[x][j]=(dp[x][j]+dp[x][j-1]*(siz[x]-j)%MOD)%MOD;
}
int main()
{
n=read();
for(int i=2;i<=n;i++)G[fa[i]=read()].push_back(i);
fac[0]=fac[1]=1;
for(int i=2;i<=n;i++)fac[i]=fac[i-1]*i%MOD;
dfs(1);
for(int i=0;i<=n;i++){
if(i&1)ans=(ans-dp[1][i]*fac[n-i]%MOD+MOD)%MOD;
else ans=(ans+dp[1][i]*fac[n-i]%MOD)%MOD;
}
printf("%lld\n",ans);
return 0;
}