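// dfs两次，dp出每个点作为最后一个点的方案数。
// Two DFS passes: a DP (with rerooting) counts, for every vertex, the valid orderings in which that vertex comes last.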
// Reads a tree (n, then n - 1 edges "u v" with 1-based vertices) and prints,
// modulo 998244353, the sum over every vertex r of the number of orderings of
// all n vertices that start at r and in which every vertex comes after its
// neighbour on the path towards r. Reversing such an ordering makes r the last
// vertex, so this is also the total over "r comes last" for every r.
//
// For a fixed root r the count is the hook-length formula for rooted trees:
//     ways(r) = n! / prod_v size_r(v),
// where size_r(v) is the size of v's subtree when the tree is rooted at r.
// Equivalently, the subtree DP dp[u] = (size(u) - 1)! * prod_c dp[c] / size(c)!
// (c ranging over the children of u) unrolls to exactly this product.
// Two passes compute it for every root:
//   1. bottom-up from vertex 1: subtree sizes and ways(1);
//   2. top-down rerooting: moving the root from p to its child v changes only
//      size(v) (size[v] -> n) and size(p) (n -> n - size[v]), hence
//          ways(v) = ways(p) * size[v] / (n - size[v]).
// Every divisor lies in [1, n] and n < 998244353, so all inverses exist.
// Both passes are iterative, so a path-shaped tree cannot overflow the stack.
#include <cstdio>
#include <vector>

namespace {

constexpr int kMod = 998244353;

// Returns a table with inv[i] = i^(-1) mod kMod for 1 <= i <= limit (limit >= 1).
std::vector<long long> modular_inverses(int limit) {
    std::vector<long long> inv(limit + 2, 0);
    inv[1] = 1;
    for (int i = 2; i <= limit; ++i) {
        // kMod = (kMod / i) * i + kMod % i  =>  i^-1 = -(kMod / i) * (kMod % i)^-1.
        inv[i] = 1LL * (kMod - kMod / i) * inv[kMod % i] % kMod;
    }
    return inv;
}

// Returns the sum over all roots r of ways(r) mod kMod for the graph given by
// the 1-based adjacency lists adj[1..n]. Assumes the graph is a tree, n >= 1.
int sum_over_roots(int n, const std::vector<std::vector<int>>& adj) {
    // Iterative DFS from vertex 1 producing a preorder (parents before children).
    // parent[1] = 0 marks the root; -1 marks a vertex not reached yet, which
    // also guarantees every vertex is visited at most once.
    std::vector<int> parent(n + 1, -1);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> pending{1};
    parent[1] = 0;
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        order.push_back(u);
        for (const int v : adj[u]) {
            if (parent[v] == -1) {
                parent[v] = u;
                pending.push_back(v);
            }
        }
    }

    // Pass 1 (bottom-up): reverse preorder handles every child before its parent.
    std::vector<int> subtree(n + 1, 1);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        if (parent[*it] > 0) {
            subtree[parent[*it]] += subtree[*it];
        }
    }

    const std::vector<long long> inv = modular_inverses(n);
    long long ways_root = 1;
    for (int i = 2; i <= n; ++i) {
        ways_root = ways_root * i % kMod;  // n!
    }
    for (const int v : order) {
        ways_root = ways_root * inv[subtree[v]] % kMod;  // divide by each hook
    }

    // Pass 2 (top-down): preorder guarantees ways[parent] is final before use.
    std::vector<long long> ways(n + 1, 0);
    long long total = 0;
    for (const int v : order) {
        const int p = parent[v];
        ways[v] = (p == 0)
                      ? ways_root
                      : ways[p] * subtree[v] % kMod * inv[n - subtree[v]] % kMod;
        total = (total + ways[v]) % kMod;
    }
    return static_cast<int>(total);
}

}  // namespace

int main() {
    int n = 0;
    if (std::scanf("%d", &n) != 1 || n < 1) {
        return 1;  // malformed input: no vertex count
    }
    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        if (std::scanf("%d%d", &u, &v) != 2 || u < 1 || u > n || v < 1 || v > n) {
            return 1;  // malformed or out-of-range edge
        }
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    std::printf("%d\n", sum_over_roots(n, adj));
    return 0;
}