题意
n<=5000
分析
怎么现在都喜欢出这种题的数树排列数
- 把排列看作圆排列,最后答案*n。
- 考虑排列dp的一般姿势,设
f
[
i
]
f[i]
f[i]表示子树i分成一定块数的圆排列个数。
块与块之间必须要插入子树外的点。
容易发现块数是由其连出去的那条边决定的。 - 转移的话就考虑做一个类似背包的dp,使用你的数数技能即可。
- 方案数就是,你有A块和B块,你要合并成C块,同颜色之间不可合并。
将b放入a的空隙,枚举k表示有k个空隙有至少一个b块:
这个方案数大概是
∑
k
C
(
b
−
1
,
k
−
1
)
C
(
a
,
k
)
C
(
2
k
,
a
+
b
−
c
)
\sum_{k} C(b-1,k-1)C(a,k)C(2k,a+b-c)
∑kC(b−1,k−1)C(a,k)C(2k,a+b−c),当然也可以预处理。
初始时每个点是1块,在1号点可以合并成0块,也就是一个环
- 但是这样只能做 O ( n 3 ) O(n^3) O(n3),看起来很可优化但是我并不会…
- 于是我们套上个容斥。这里的限制是同种类块不能互相合并。是一个经典的容斥问题。
- 假如一颗子树里本来有x块,我们可以视作y(y<=x)块(也就是至少合并了x-y次同种块),这样做的系数是 C ( x , y ) ( − 1 ) x − y C(x,y)(-1)^{x-y} C(x,y)(−1)x−y。
- 将这些块进行排列,然后合并一些块,就完成了容斥过程。
- 具体来说,就是先背包,求出 f ( x ) f(x) f(x)表示有x块的容斥和,然后再对每一个f(x)对其合并成要求块数的容斥和求和。即求出了当前点的答案。
- 这里比较难理解的地方就是容斥的“合并”和原本的“合并”,注意区分这两者的作用就可以了。
O ( n 2 ) O(n^2) O(n2)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5000+10;
int n,mo;
ll jc[2*N],njc[2*N];
int final[N],nex[N*2],to[N*2],tot,w[N*2],sz[N];
void link(int x, int y, int v){
to[++tot] = y, nex[tot] = final[x], final[x] = tot;
w[tot] = v;
}
ll ksm(ll x, ll y) {
ll ret = 1; for (; y; y >>= 1) {
if (y & 1 ) ret = ret * x % mo;
x = x * x % mo;
}
return ret;
}
int fa[N],f[N][N],g[N];
inline void add(int &a,int b){
a+=b;if(a>=mo)a-=mo; else
if(a<=-mo)a+=mo;
}
ll C(ll n, ll m) {
if(n<m)return 0;
return jc[n]*njc[m]%mo*njc[n-m]%mo;
}
int qiu[N];
void dp(int x){
sz[x]=1;f[x][1]=1;
static int tmp[N];
for(int i=final[x];i;i=nex[i]){
int y=to[i];if(y==fa[x])continue;
int b=w[i]>>1;
fa[y]=x; qiu[y]=b;
dp(y);
memset(tmp,0,sizeof tmp);
if(b>sz[y]){
printf("0"); exit(0);
}
for(int i=1;i<=b;i++) f[y][i]=g[y]*(((b-i)&1)==0?1:-1)*C(b,i)%mo;
for(int a=1;a<=sz[x];a++){
for(int j=1;j<=b;j++){
add(tmp[a+j],(ll)j*f[x][a]%mo*f[y][j]%mo*C(a+j-1,j)%mo);
}
}
memcpy(f[x],tmp,sizeof tmp);
sz[x]+=sz[y];
}
for(int a=qiu[x];a<=sz[x];a++){
add(g[x],f[x][a]*C(a,qiu[x])%mo);
}
// cout << "node " << x << endl;
// for(int i = 0; i <= sz[x]; i++) printf("%lld ", f[x][i]);printf("\n");
}
int main() {
freopen("permutation.in","r",stdin);
// freopen("permutation.out","w",stdout);
cin >> n >> mo;
for(int i = 1; i < n; i++) {
int x, y, v; scanf("%d %d %d", &x, &y, &v);
link(x, y, v), link(y, x, v); if (v&1){
printf("0"); return 0;
}
}
jc[0]=1;for(int i=1;i<=2*n;i++)jc[i]=jc[i-1]*i%mo;
njc[2*n]=ksm(jc[2*n],mo-2);
for(int i=2*n-1;~i;i--)njc[i]=njc[i+1]*(i+1)%mo;
dp(1);
cout<<(ll)n*((g[1]+mo)%mo)%mo<<endl;
}