HDU 6647 Bracket Sequences on Tree
求一棵树所有可能的括号序方案数(从不同根dfs)
固定根的话,当前点 $x$ 的括号序方案数:
设儿子的子树有 $k$ 类,每类有 $c_i$ 个,总共 $tot$ 个,那么方案数
$$f[x] = \binom{tot}{c_1,c_2,\dots,c_k} \cdot \prod_{i=1}^{tot} f[son_i]$$
其中连乘遍历 $x$ 的全部 $tot$ 个儿子。
然后就需要一个换根dp
需要用到树hash,并且要选在换根dp中也很好写的hash方法。这道题很适合用来检验树hash的质量:以每个点为根时的每棵子树都会被hash,只要有一处冲突答案就会出错。
这里提供两种hash方法
- $h[x] = INIT + \sum_{i=1}^{tot} C^{h[son_i]}$,可取 $INIT = 1240198$,$C = 35224111$
- $h[x] = INIT + \sum_{i=1}^{tot} h[son_i] \cdot prime[sz[son_i] + offset]$,可取 $INIT = 1$ 或 $1240198$,$offset = 5000$,注意不能从第一个质数开始取
分别提供两种hash方法的AC代码
#include<bits/stdc++.h>
using namespace std;
// Shorthand for push_back on the adjacency vectors.
#define PB push_back
// Container size converted to int.
#define SZ(x) (int)x.size()
// Shorthand for the .second member of a map entry.
#define se second
typedef long long ll;
// Modulus for the counting values (fac, inv, f, g, ans); used by power().
const ll mod = 998244353;
// Modulus for the hash values (h1, h2, pw1, pw2); used by mul() and powmod().
const ll MOD = 180143985094819841ll;
//const ll MOD = 1e9 + 7;
// Base that dfs()/dfs2() raise to a subtree hash via powmod().
const ll C = 35224111;
// Starting value of h1[x] before the children's contributions are added.
const ll INIT = 1240198;
// Capacity of every per-vertex array and of the factorial tables.
const int maxn = 1e5 + 10;
// Adjacency lists; main() stores every edge in both directions.
vector <int> e[maxn];
// fac[i] = i! mod `mod`; inv[i] = modular inverse of fac[i]. Filled by init().
ll fac[maxn],inv[maxn];
// num[x] maps a hash value to how many neighbouring subtrees of x carry it.
map <ll,int> num[maxn];
// f[x]  : count for the subtree of x (dfs), later overwritten by the count
//         for the whole tree rooted at x (dfs2).
// g[x]  : count for the component on the parent's side of x (dfs2).
// h1[x] : hash of the subtree of x when the tree is rooted at vertex 1.
// h2[x] : hash of the parent-side component of x (dfs2).
// pw1[x] = C^h1[x], pw2[x] = C^h2[x], both mod MOD.
ll f[maxn],g[maxn],h1[maxn],h2[maxn],pw1[maxn],pw2[maxn];
// Hashes of the rooted trees whose count has already been added to ans.
map <ll,int> vis;
// ans accumulates the result mod `mod`; seed is not referenced by any
// function in this program.
ll ans,seed;
// n: vertex count of the current test case; fa[x]: parent of x for root 1
// (0 for the root itself).
int n,fa[maxn];
// Reset the per-test-case state: the running answer, the set of hashes already
// counted, and every per-vertex table entry for vertices 1..n.
void clear(){
    ans = 0;
    vis.clear();
    for (int v = 1 ; v <= n ; ++v){
        e[v].clear();
        num[v].clear();
        fa[v] = 0;
        f[v] = 0;
        g[v] = 0;
        h1[v] = 0;
        h2[v] = 0;
    }
}
// Returns a * b mod MOD for a, b in [0, MOD) without a 128-bit type.
//
// The quotient floor(a * b / MOD) is estimated in long double; the product and
// quotient * MOD are then subtracted modulo 2^64. That wrap-around is done in
// unsigned arithmetic because overflowing a signed long long is undefined
// behaviour. The true difference is small, so the low 64 bits recover it.
// NOTE(review): the estimate needs a long double with a 64-bit mantissa
// (x87 extended precision); confirm on targets where long double is double.
ll mul(ll a,ll b){
    //return a * b % MOD;
    const unsigned long long ua = (unsigned long long)a;
    const unsigned long long ub = (unsigned long long)b;
    const unsigned long long um = (unsigned long long)MOD;
    const unsigned long long uq = (unsigned long long)(ll)(a / (long double)MOD * b + 1e-3);
    // Adding MOD before the final reduction absorbs a quotient estimate that is
    // one too large.
    const unsigned long long wrapped = ua * ub - uq * um + um;
    return (ll)wrapped % MOD;
}
// Binary exponentiation under the hash modulus: returns x^y mod MOD, with every
// product taken through mul().
ll powmod(ll x,ll y){
    ll acc = 1;
    for (ll base = x , rest = y ; rest ; rest >>= 1){
        if ( rest & 1 ) acc = mul(acc,base);
        base = mul(base,base);
    }
    return acc;
}
// Binary exponentiation under the counting modulus: returns x^y mod `mod`.
// Products are formed directly in long long.
ll power(ll x,ll y){
    ll acc = 1;
    for (ll base = x , rest = y ; rest ; rest >>= 1){
        if ( rest & 1 ) acc = acc * base % mod;
        base = base * base % mod;
    }
    return acc;
}
// Fills fac[i] = i! mod `mod` and inv[i] = (i!)^{-1} mod `mod` for 0 <= i < maxn.
//
// Only the last inverse factorial is obtained by modular exponentiation; the
// others follow from 1/(i-1)! = (1/i!) * i. This yields the same table as
// exponentiating every entry, with a single call to power().
void init(){
    fac[0] = 1;
    for (int i = 1 ; i < maxn ; i++) fac[i] = fac[i - 1] * i % mod;
    // maxn - 1 is far below the prime `mod`, so fac[maxn - 1] is invertible.
    inv[maxn - 1] = power(fac[maxn - 1],mod - 2);
    for (int i = maxn - 1 ; i >= 1 ; i--) inv[i - 1] = inv[i] * i % mod;
}
// First pass, rooted at the vertex it is first called on. For the subtree of x
// it computes
//   h1[x]  = INIT + sum over children y of pw1[y]            (mod MOD)
//   f[x]   = (#children)! / prod(class sizes!) * prod f[y]   (mod `mod`)
//   pw1[x] = C^h1[x]                                         (mod MOD)
// and records in num[x] how many children share each subtree hash. fa[] is
// filled along the way.
void dfs(int x){
    h1[x] = INIT;
    f[x] = 1;
    int children = 0;
    for (size_t i = 0 ; i < e[x].size() ; ++i){
        const int y = e[x][i];
        if ( y == fa[x] ) continue;
        fa[y] = x;
        dfs(y);
        ++children;
        // Children with equal hashes fall into the same class.
        num[x][h1[y]]++;
        f[x] = f[x] * f[y] % mod;
        h1[x] = (h1[x] + pw1[y]) % MOD;
    }
    // Multinomial coefficient: children! divided by every class size factorial.
    f[x] = f[x] * fac[children] % mod;
    for (const auto &group : num[x]){
        f[x] = f[x] * inv[group.second] % mod;
    }
    pw1[x] = powmod(C,h1[x]);
}
// Hash of the whole tree when rooted at x: the hash of x's own subtree plus,
// for a non-root vertex, the contribution pw2[x] of the parent-side component.
ll gethash(int x){
    ll whole = h1[x];
    if ( fa[x] ) whole += pw2[x];
    return whole % MOD;
}
// Second (rerooting) pass. On entry f[x] is the count for x's subtree from
// dfs(), and for a non-root x the parent has already set g[x], h2[x] and pw2[x].
// The function
//   1. turns f[x] into the count for the whole tree rooted at x and adds it to
//      ans unless a tree with the same hash was counted before,
//   2. derives g[y], h2[y], pw2[y] for every child y,
//   3. recurses into the children.
void dfs2(int x){
if ( !fa[x] ){
// Traversal root: f[x] from dfs() already describes the whole tree.
if ( !vis[h1[x]] ){
ans = (ans + f[x]) % mod ,vis[h1[x]] = 1;
//cout<<x<<" "<<f[x]<<endl;
}
}
else{
ll d = gethash(x);
// Combine x's own subtree with the parent-side component.
ll cur = f[x] * g[x] % mod;
// The number of "children" grows from deg-1 to deg: swap (deg-1)! for deg!.
cur = cur * inv[SZ(e[x]) - 1] % mod * fac[SZ(e[x])] % mod;
// c is the size of the class the parent-side component joins (operator[]
// creates the entry with 0 if no child has that hash).
int &c = num[x][h2[x]];
// That class grows from c to c+1: swap the 1/c! factor for 1/(c+1)!.
cur = cur * fac[c] % mod * inv[c + 1] % mod;
if ( !vis[d] ) ans = (ans + cur) % mod , vis[d] = 1;
// Store the rerooted count and register the parent-side component in num[x]
// so the child loop below sees all deg neighbours.
f[x] = cur , ++c;
// cout<<x<<" "<<f[x]<<endl;
}
for (int i = 0 ; i < SZ(e[x]) ; i++){
int y = e[x][i];
if ( y == fa[x] ) continue;
int c = num[x][h1[y]];
assert(c);
// Count for x rooted here with child y removed: divide out f[y], shrink y's
// class from c to c-1, and shrink the child count from deg to deg-1.
g[y] = f[x] * power(f[y],mod - 2) % mod * fac[c] % mod * inv[c - 1] % mod * inv[SZ(e[x])] % mod * fac[SZ(e[x]) - 1] % mod;
// Hash of the same component: x's full hash minus y's contribution.
h2[y] = (h1[x] - pw1[y] + MOD + (fa[x] ? pw2[x] : 0)) % MOD;
pw2[y] = powmod(C,h2[y]);
}
// Recurse in a separate loop, after g/h2/pw2 have been set for every child.
for (auto y : e[x]){
if ( y == fa[x] ) continue;
dfs2(y);
}
}
// Entry point: reads the number of test cases, then for each case the vertex
// count n and n-1 edges, and prints the answer modulo `mod`.
//
// Input comes from "input.txt" when that file exists, otherwise from standard
// input. Reading stops quietly on truncated or malformed input.
int main(){
    // freopen() closes stdin even when it fails, so probe for the file first;
    // without it the program keeps reading the real standard input.
    if ( FILE *probe = fopen("input.txt","r") ){
        fclose(probe);
        if ( !freopen("input.txt","r",stdin) ) return 1;
    }
    init();
    int cases = 0;
    if ( scanf("%d",&cases) != 1 ) return 0;
    while ( cases-- > 0 ){
        if ( scanf("%d",&n) != 1 ) break;
        clear();
        for (int i = 1 ; i < n ; i++){
            int x,y;
            // Never index e[] with values that were not actually read.
            if ( scanf("%d %d",&x,&y) != 2 ) return 0;
            e[x].PB(y) , e[y].PB(x);
        }
        dfs(1);
        dfs2(1);
        printf("%lld\n",ans);
    }
    return 0;
}
#include <bits/stdc++.h>
using namespace std;
// Shorthand for push_back on the adjacency vectors.
#define PB push_back
// Container size converted to int.
#define SZ(x) (int)x.size()
// Shorthand for the .second member of a map entry.
#define se second
typedef long long ll;
// Modulus for the counting values (fac, inv, f, g, ans); used by power().
const ll mod = 998244353;
// Modulus for the hash values (h1, h2); used by mul().
const ll MOD = 2305843009213693951ll;
// Shift added to a subtree size before indexing prime[].
const int offset = 5000;
// Capacity of every per-vertex array and of the factorial tables.
const int maxn = 1e5 + 10;
// Upper bound (exclusive) of the sieve run by init().
const int maxm = maxn * 20;
// Adjacency lists; main() stores every edge in both directions.
vector<int> e[maxn];
// fac[i] = i! mod `mod`; inv[i] = modular inverse of fac[i]. Filled by init().
ll fac[maxn], inv[maxn];
// num[x] maps a hash value to how many neighbouring subtrees of x carry it.
map<ll, int> num[maxn];
// f[x]  : count for the subtree of x (dfs), later overwritten by the count
//         for the whole tree rooted at x (dfs2).
// g[x]  : count for the component on the parent's side of x (dfs2).
// h1[x] : hash of the subtree of x when the tree is rooted at vertex 1.
// h2[x] : hash of the parent-side component of x (dfs2).
ll f[maxn], g[maxn], h1[maxn], h2[maxn];
// Hashes of the rooted trees whose count has already been added to ans.
map<ll, int> vis;
// ans accumulates the result mod `mod`; seed[] is not referenced by any
// function in this program.
ll ans, seed[maxn];
// prime[1..cnt]: primes below maxm in increasing order; tag[v] != 0 marks v as
// composite; cnt: number of primes found.
int prime[maxm], tag[maxm], cnt;
// n: vertex count of the current test case; fa[x]: parent of x for root 1
// (0 for the root itself); sz[x]: size of the subtree of x.
int n, fa[maxn], sz[maxn];
// Reset the per-test-case state: the running answer, the set of hashes already
// counted, and every per-vertex table entry for vertices 1..n.
void clear() {
    ans = 0;
    vis.clear();
    for (int v = 1; v <= n; ++v) {
        e[v].clear();
        num[v].clear();
        fa[v] = 0;
        sz[v] = 0;
        f[v] = 0;
        g[v] = 0;
        h1[v] = 0;
        h2[v] = 0;
    }
}
// Returns a * b mod MOD for a, b in [0, MOD) without a 128-bit type.
//
// The quotient floor(a * b / MOD) is estimated in long double; the product and
// quotient * MOD are then subtracted modulo 2^64. That wrap-around is done in
// unsigned arithmetic because overflowing a signed long long is undefined
// behaviour. The true difference is small, so the low 64 bits recover it.
// NOTE(review): the estimate needs a long double with a 64-bit mantissa
// (x87 extended precision); confirm on targets where long double is double.
ll mul(ll a, ll b) {
    const unsigned long long ua = (unsigned long long)a;
    const unsigned long long ub = (unsigned long long)b;
    const unsigned long long um = (unsigned long long)MOD;
    const unsigned long long uq = (unsigned long long)(ll)(a / (long double)MOD * b + 1e-3);
    // Adding MOD before the final reduction absorbs a quotient estimate that is
    // one too large.
    const unsigned long long wrapped = ua * ub - uq * um + um;
    return (ll)wrapped % MOD;
}
// Binary exponentiation under the counting modulus: returns x^y mod `mod`.
// Products are formed directly in long long.
ll power(ll x, ll y) {
    ll acc = 1;
    for (ll base = x, rest = y; rest; rest >>= 1) {
        if (rest & 1) acc = acc * base % mod;
        base = base * base % mod;
    }
    return acc;
}
// One-time setup:
//   * linear sieve filling prime[1..cnt] with the primes below maxm,
//   * fac[i] = i! mod `mod` and inv[i] = (i!)^{-1} mod `mod` for 0 <= i < maxn.
//
// Only the last inverse factorial is obtained by modular exponentiation; the
// others follow from 1/(i-1)! = (1/i!) * i. This yields the same table as
// exponentiating every entry, with a single call to power().
void init() {
    for (int i = 2; i < maxm; i++) {
        if (!tag[i]) prime[++cnt] = i;
        for (int j = 1; j <= cnt && prime[j] * i < maxm; j++) {
            tag[prime[j] * i] = 1;
            // Stop at i's smallest prime factor so each composite is marked once.
            if (i % prime[j] == 0) break;
        }
    }
    // The hash indexes prime[] with a subtree size plus offset.
    assert(cnt >= maxn + offset);
    fac[0] = 1;
    for (int i = 1; i < maxn; i++) fac[i] = fac[i - 1] * i % mod;
    // maxn - 1 is far below the prime `mod`, so fac[maxn - 1] is invertible.
    inv[maxn - 1] = power(fac[maxn - 1], mod - 2);
    for (int i = maxn - 1; i >= 1; i--) inv[i - 1] = inv[i] * i % mod;
}
// First pass, rooted at the vertex it is first called on. For the subtree of x
// it computes
//   sz[x] = subtree size,
//   h1[x] = 1 + sum over children y of h1[y] * prime[sz[y] + offset]  (mod MOD),
//   f[x]  = (#children)! / prod(class sizes!) * prod f[y]             (mod `mod`),
// and records in num[x] how many children share each subtree hash. fa[] is
// filled along the way.
void dfs(int x) {
    sz[x] = 1;
    f[x] = 1;
    h1[x] = 1;
    int children = 0;
    for (size_t i = 0; i < e[x].size(); ++i) {
        const int y = e[x][i];
        if (y == fa[x]) continue;
        fa[y] = x;
        dfs(y);
        ++children;
        // Everything about y is final once its call returns, so its
        // contribution can be folded in immediately.
        sz[x] += sz[y];
        h1[x] = (h1[x] + mul(h1[y], prime[sz[y] + offset])) % MOD;
        f[x] = f[x] * f[y] % mod;
        num[x][h1[y]]++;
    }
    // Multinomial coefficient: children! divided by every class size factorial.
    f[x] = f[x] * fac[children] % mod;
    for (const auto &group : num[x]) {
        f[x] = f[x] * inv[group.second] % mod;
    }
}
// Hash of the whole tree when rooted at x. The traversal root already has it in
// h1[x]; any other vertex adds the parent-side component, whose size is
// n - sz[x], weighted by the matching prime.
ll gethash(int x) {
    ll whole = h1[x];
    if (fa[x]) {
        whole = (whole + mul(h2[x], prime[n - sz[x] + offset])) % MOD;
    }
    return whole;
}
void dfs2(int x) {
ll d = gethash(x);
if (!fa[x]) {
if (!vis[h1[x]]) {
ans = (ans + f[x]) % mod, vis[h1[x]] = 1;
}
} else {
ll cur = f[x] * g[x] % mod;
cur = cur * inv[SZ(e[x]) - 1] % mod * fac[SZ(e[x])] % mod;
int &c = num[x][h2[x]];
cur = cur * fac[c] % mod * inv[c + 1] % mod;
if (!vis[d]) ans = (ans + cur) % mod, vis[d] = 1;
f[x] = cur, ++c;
}
for (int i = 0; i < SZ(e[x]); i++) {
int y = e[x][i];
if (y == fa[x]) continue;
int c = num[x][h1[y]];
assert(c);
g[y] = f[x] * power(f[y], mod - 2) % mod * fac[c] % mod * inv[c - 1] % mod * inv[SZ(e[x])] % mod *
fac[SZ(e[x]) - 1] % mod;
h2[y] = (d - mul(h1[y], prime[sz[y] + offset]) + MOD) % MOD;
}
for (auto y : e[x]) {
if (y == fa[x]) continue;
dfs2(y);
}
}
// Entry point: reads the number of test cases, then for each case the vertex
// count n and n-1 edges from standard input, and prints the answer modulo
// `mod`. Reading stops quietly on truncated or malformed input.
int main() {
    init();
    // Initialised so a failed read can never leave an indeterminate count.
    int cases = 0;
    if (scanf("%d", &cases) != 1) return 0;
    while (cases-- > 0) {
        if (scanf("%d", &n) != 1) break;
        clear();
        for (int i = 1; i < n; i++) {
            int x, y;
            // Never index e[] with values that were not actually read.
            if (scanf("%d %d", &x, &y) != 2) return 0;
            e[x].PB(y), e[y].PB(x);
        }
        dfs(1);
        dfs2(1);
        printf("%lld\n", ans);
    }
    return 0;
}