题意:
个节点的树,往每个节点上放1到
的不重复的数,问有多少种情况满足不存在某个节点,满足它的父亲节点上的数比它自己的数大1。答案对998244353取模。
思路:
先考虑容斥。设为这棵树中有
个冲突的情况数(冲突即为父节点比子节点大1)。那么答案就是
解释下就是冲突的情况父节点一定比子节点大1,那这两个点各自被另外一个点的值制约,也就是可以缩成一个点。有个冲突也就等价于1到
的全排列。
明显可以通过树上背包算出,但是背包的复杂度是
。
再让我们来看的特性。为什么
不是在
个存在的限制中找
个冲突,也就是
呢?这是因为对于单个父节点只能存在一个子节点与他有冲突。于是我们便可以想到,对于单个父节点,不存在冲突等同于在冲突情况上唯一,存在冲突则在
个子节点中任意找一个子节点与其产生冲突,情况数为
。构造多项式
,将所有点的多项式相乘,
即为
的系数,这块如果懂生成函数的话会好理解一些。这个可以通过分治NTT来实现,复杂度为
。
吐槽一下这题和20年澳门站的A好像啊,只不过澳门站的分治FFT挺容易想到的,这题比较难想到这方面。
代码:
#include<bits/stdc++.h>
#define IO ios::sync_with_stdio(false);cin.tie(0)
using namespace std;
const int P = 250010;
int a[P+10], jc[P+10];
vector<int> G[P];
namespace math {
const int MOD = 998244353;
inline int add(int x, const int y) { return x += y, x >= MOD ? x - MOD : x; }
inline int sub(int x, const int y) { return x -= y, x < 0 ? x += MOD : x; }
inline int mul(const int x, const int y) { return 1ll * x * y % MOD; }
inline int qpow(int x, int y) {
int res = 1;
for (; y; y >>= 1, x = mul(x, x)) if (y & 1) res = mul(res, x);
return res;
}
}using namespace math;
namespace NTT {
int limit;
int pr = 3; // 能用的ntt的素数原根
vector<int> A, B, rev;
void init(int siz) {
int ed = siz, bit = -1;
for (limit = 1; limit <= ed; limit <<= 1) ++bit;
A.resize(limit); B.resize(limit); rev.resize(limit);
for (int i = 0; i < limit; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit);
}
void ntt(vector<int>& P, int op) {
for (int i = 0; i < limit; ++i) {
if (i < rev[i])swap(P[i], P[rev[i]]);
}
for (int mid = 1; mid < limit; mid <<= 1) {
int euler = qpow(pr, (MOD - 1) / (mid << 1));
if (op < 0) euler = qpow(euler, MOD - 2);
for (int i = 0, pos = mid << 1; i < limit; i += pos) {
int wk = 1;
for (int j = 0; j < mid; ++j, wk = mul(wk, euler)) {
int x = P[i + j], y = mul(wk, P[i + j + mid]);
P[i + j] = add(x, y), P[i + j + mid] = add(x, MOD - y);
}
}
}
if (op > 0) return;
int inv = qpow(limit, MOD - 2);
for (int i = 0; i < limit; ++i) P[i] = mul(P[i], inv);
}
void work() {
ntt(A, 1), ntt(B, 1);
for (int i = 0; i < limit; ++i) A[i] = mul(A[i], B[i]);
ntt(A, -1);
}
};
template <class T=int> T rd()
{
T res=0;T fg=1;
char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') fg=-1;ch=getchar();}
while( isdigit(ch)) res=(res<<1)+(res<<3)+(ch^48),ch=getchar();
return res*fg;
}
void init(){
jc[0]=1;
for(int i=1;i<=P;i++)
jc[i]=mul(jc[i-1],i);
}
void dfs(int u,int fa){
for(auto v:G[u]){
if(v==fa)continue;
dfs(v,u);
++a[u];
}
}
void sol(int l,int r,vector<int>& v){
if(l==r){
v = {1,a[l]};
return;
}
int mid=(l+r)>>1;
vector<int> v1,v2;
sol(l,mid,v1);
sol(mid+1,r,v2);
NTT::A=v1;NTT::B=v2;
NTT::init(r-l+1);NTT::work();
v = NTT::A;
}
int main(){
IO;
init();
int n=rd();
for(int i=1;i<n;i++){
int x=rd(),y=rd();
G[x].emplace_back(y);
G[y].emplace_back(x);
}
vector<int> pol;
dfs(1,0);
sol(1,n,pol);
int ans=0;
for(int i=0,fac=1;i<n;++i){
if(fac==1)ans=add(ans,mul(pol[i],jc[n-i]));
else ans=sub(ans,mul(pol[i],jc[n-i]));
fac^=1;
}
cout<<ans<<endl;
}