WC2019数树

首先由于我们很难直接求解边集之交为S的树的个数,所以考虑容斥,转化成至少交S,这里需要推一波式子,然后利用组合意义,转化成一个连通块内给一个点染色,就可以O(n) DP了。
这类题能够转化成包含的好算一定要尝试着去凑一下,很有可能可以出奇迹。
然后再推一推就可以推出一个多项式exp的式子啦,感觉还是很妙妙的QAQ。
prufer序列那一部分可以用生成函数的exp来推。

#include <bits/stdc++.h>
using namespace std;

const int N = 2e6 + 5;
const int M = N * 2;
const int mod = 998244353;

#define SZ(x) (int) x.size()
#define REP(i, a, b) for(int i = (a); i <= (b); ++ i)
#define PER(i, a, b) for(int i = (a); i >= (b); -- i)
#define lc (no << 1)
#define rc (no << 1 | 1)
#define getmid int mid = (L[no] + R[no]) >> 1

int n, m, x, y, fir[N], ne[M], to[M], cnt, K, dp[N][2], now2, tmp[2], fac[N], inv[N];

const int i2 = (mod + 1) / 2;

int I[N];

const int g = 3;

vector <int> now;

namespace {
  int add(int x) {return (x >= mod) ? x - mod : x;}
  int sub(int x) {return (x < 0) ? x + mod : x;}
  void Add(int &x, int y) {x = add(x + y);}
  void Sub(int &x, int y) {x = sub(x - y);}
  int Pow(int x, long long y = mod - 2) {
    int res = 1;
    for(; y; y >>= 1, x = 1LL * x * x % mod) {
      if(y & 1) {
    res = 1LL * res * x % mod;
      }
    }
    return res;
  }
}

namespace FFT {
  int C[N], D[N], rev[N];

  void DFT(int *A, int up) {
    for(int i = 0; i < (1 << up); ++ i)
      rev[i] = rev[i >> 1] >> 1 | ((i & 1) * (1 << (up - 1)));
    for(int i = 0; i < (1 << up); ++ i)
      if(i < rev[i]) swap(A[i], A[rev[i]]);
    for(int i = 0; i < up; ++ i) {
      int wn = Pow(g, (mod - 1) / (1 << (i + 1)));
      for(int j = 0; j < (1 << up); j += (1 << (i + 1))) {
    int w = 1;
    for(int k = j; k < j + (1 << i); ++ k) {
      int L = A[k], R = 1LL * w * A[k + (1 << i)] % mod;
      A[k] = add(L + R);
      A[k + (1 << i)] = sub(L - R);
      w = 1LL * w * wn % mod;
    }
      }
    }
  }

  void IDFT(int *A, int up) {
    reverse(A + 1, A + (1 << up));
    DFT(A, up);
    int now = Pow(1 << up);
    for(int i = 0; i < (1 << up); ++ i)
      A[i] = 1LL * A[i] * now % mod;
  }

  vector <int> fixlen(vector <int> A, int len) {
    while(A.size() > len) A.pop_back();
    return A;
  }

  vector <int> operator * (vector <int> A, vector <int> B) {
    vector <int> ans; ans.clear();
    ans.resize(A.size() + B.size() - 1);
    int up = 0; while((1 << up) < ans.size()) ++ up;
    for(int i = 0; i < min((1 << (up + 1)), N); ++ i) C[i] = D[i] = 0;
    REP(i, 0, SZ(A) - 1) C[i] = A[i]; REP(i, 0, SZ(B) - 1) D[i] = B[i];
    DFT(C, up); DFT(D, up);
    REP(i, 0, (1 << up) - 1) C[i] = 1LL * C[i] * D[i] % mod;
    IDFT(C, up); REP(i, 0, SZ(ans) - 1) ans[i] = C[i];
    return ans;
  }
  
  vector <int> operator - (vector <int> A, vector <int> B) {
    while(A.size() < B.size()) A.push_back(0);
    REP(i, 0, SZ(B) - 1) A[i] = sub(A[i] - B[i]);
    return A;
  }

  vector <int> operator + (vector <int> A, vector <int> B) {
    while(A.size() < B.size()) A.push_back(0);
    REP(i, 0, SZ(B) - 1) A[i] = add(A[i] + B[i]);
    return A;    
  }
  
  vector <int> cons(int x) {
    vector <int> res; res.clear();
    res.push_back(x);
    return res;
  }

  void out(vector <int> who) {
    for(int i = 0; i < (int) who.size(); ++ i) cerr << who[i] <<" ";
    puts("--");
  }
  
  vector <int> inv(vector <int> A) {
    vector <int> B = cons(Pow(A[0]));
    while(B.size() < A.size()) B = fixlen((cons(2) - A * B) * B, B.size() * 2);
    return fixlen(B, A.size());
  }
  
  vector <int> Sqrt(vector <int> A) {
    vector <int> B; if(A[0] == 1) B.push_back(1); else throw;
    while(B.size() < A.size()) {
      int nlen = B.size() * 2;
      vector <int> tmp1 = fixlen(B * B, nlen);
      tmp1 = fixlen(tmp1 + A, nlen); B = inv(B);
      tmp1 = fixlen(tmp1 * B, nlen);
      for(int i = 0; i < (int) tmp1.size();++ i) tmp1[i] = 1LL * i2 * tmp1[i] % mod;
      swap(B, tmp1);
    }
    return fixlen(B, A.size());
  }

  vector <int> Dao(vector <int> A) {
    vector <int> B; B.clear();
    for(int i = 1; i < (int) A.size(); ++ i) B.push_back(1LL * A[i] * i % mod);
    return B;
  }

  vector <int> Ji(vector <int> A) {
    vector <int> B = cons(0);
    for(int i = 0; i < (int) A.size(); ++ i)
      B.push_back(1LL * I[i + 1] * A[i] % mod);
    return B;
  }

  vector <int> Ln(vector <int> A) {
    return fixlen(Ji(inv(A) * Dao(A)), A.size());
  }

  vector <int> exp(vector <int> A) {
    vector <int> f = cons(1);
    while(f.size() <= A.size() * 2) f = fixlen(f * (cons(1) + A - Ln(f)), 2 * f.size());
    return fixlen(f, A.size());
  }
}

void add(int x, int y) {
  ne[++ cnt] = fir[x];
  fir[x] = cnt;
  to[cnt] = y;
}

void link(int x, int y) {
  add(x, y);
  add(y, x);
}

#define Foreachson(i, x) for(int i = fir[x]; i; i = ne[i])

map <pair <int, int>, int> Map;

void dfs(int x, int f) {
  dp[x][1] = dp[x][0] = 1;
  Foreachson(i, x) {
    int V = to[i];
    if(V == f) continue;
    dfs(V, x);
   memset(tmp, 0, sizeof(tmp));
    for(int a = 0; a < 2; ++ a) {
      for(int b = 0; b < 2; ++ b) {
	if(a && b) continue;
	if(!b) {
	  Add(tmp[a | b], 1LL * now2 * dp[x][a] % mod * dp[V][b] % mod);
	  continue;
	}
	Add(tmp[0], 1LL * dp[x][a] % mod * dp[V][b] % mod);
	Add(tmp[1], 1LL * now2 * dp[x][a] % mod * dp[V][b] % mod);
      }
    }
    //cerr << x <<" " << V <<" " << tmp[0] << " " << tmp[1] << " " << now << " " << dp[x][0] <<" " << dp[x][1] << " " << dp[V][0] <<" " << dp[V][1] <<  endl;
    dp[x][0] = tmp[0]; dp[x][1] = add(1LL * dp[x][1] * dp[V][1] % mod +tmp[1]);
  }
}

int main() {
  freopen("tree.in", "r", stdin);
  freopen("tree.out", "w", stdout);
  fac[0] = 1;
  for(int i = 1; i < N; ++ i) fac[i] = 1LL * fac[i - 1] * i % mod;
  inv[N - 1] = Pow(fac[N - 1]);
  for(int i = N - 2; i >= 0; -- i) inv[i] = 1LL * (i + 1) * inv[i + 1] % mod;
  for(int i = 1; i < N; ++ i) I[i] = 1LL * inv[i] * fac[i - 1] % mod, assert(1ll * I[i] * i % mod == 1);
  int opt;
  cin >> n >> K >> opt;
  if(opt == 0) {
    for(int i = 1; i < n; ++ i) {
      scanf("%d%d", &x, &y); if(x > y) swap(x, y);
      Map[make_pair(x, y)] = 1;
    }
    int tot = n;
    for(int i = 1; i < n; ++ i) {
      scanf("%d%d", &x, &y);
      if(x > y) swap(x, y);
      if(Map[make_pair(x, y)]) -- tot;
    }
    printf("%d\n", Pow(K, tot));
  }
  else if(opt == 1) {
    for(int i = 1; i < n; ++ i) scanf("%d%d", &x, &y), link(x, y);
    now2 = 1LL * sub(1 - K) * Pow(K) % mod * Pow(n) % mod; 
    dfs(1, 0);
    printf("%lld\n", 1LL * dp[1][1] * Pow(K, n) % mod * Pow(n, n - 2) % mod);
  }
  else {
    if(K == 1) {
      cout << Pow(Pow(n, n - 2), 2) <<endl;
      return 0;
    }
    vector <int> A; A.clear();
    A.push_back(0);
    now2 = 1LL * sub(1 - K) * Pow(K) % mod * Pow(n) % mod;
    int Now = Pow(now2);
    //cerr << inv[2] << endl;
    for(int i = 1; i <= n; ++ i) {
      A.push_back(1LL * Now * inv[i] % mod * Pow(i, i) % mod * n % mod);
    }
    //for(int i = 0; i < (int) A.size(); ++ i) cerr << A[i] << " ";
    //cerr << endl;
    A = FFT :: exp(A);
    //for(int i = 0; i < (int) A.size(); ++ i) cerr << A[i] <<" ";
    //cerr << endl;
    int ans = A[n];
    cout << 1LL * ans * fac[n] % mod
      * Pow(sub(1 - K), n) % mod * Pow(Pow(n, 4)) % mod << endl;
  }
}
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值