【题目链接】
【思路要点】
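- 对每个 \(k\) 预处理 \(i^k\) 关于深度 \(i\) 的前缀和 \(S_k(d)=\sum_{i=0}^{d} i^k\)。询问 \((x,y,k)\) 时先求 \(z=\mathrm{LCA}(x,y)\)：路径上各点的深度恰好覆盖区间 \([dep_z, dep_x]\) 与 \((dep_z, dep_y]\)，于是问题转化为 \(i^k\) 在这两个区间上的区间和，可由前缀和 \(O(1)\) 求出。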
- 时间复杂度 \(O(NK + N\log N + Q\log N)\)。
【代码】
// Sum of depth(v)^k over all vertices v on the path x..y of a tree rooted at
// vertex 1 (depth(root) = 0), answered mod P for each query.
// Input : n, then n-1 edges "x y", then m, then m queries "x y k".
// Method: per-exponent prefix sums of d^k over depths + binary-lifting LCA.
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

const int MAXN = 300005;
const int MAXLOG = 20;   // 2^20 > MAXN, so 20 lifting levels reach any ancestor
const int MAXK = 55;     // supported exponents: 0 <= k < MAXK
const int P = 998244353;

// Reads an integer from stdin; each '-' seen while skipping non-digits flips the sign.
// On EOF before any digit, x is left as 0 instead of looping forever.
template <typename T> void read(T &x) {
    x = 0;
    int f = 1;
    int c = std::getchar();  // int, not char, so EOF and isdigit() are well-defined
    for (; !std::isdigit(c); c = std::getchar()) {
        if (c == EOF) return;
        if (c == '-') f = -f;
    }
    for (; std::isdigit(c); c = std::getchar()) x = x * 10 + (c - '0');
    x *= f;
}

// Writes an integer to stdout in decimal.
template <typename T> void write(T x) {
    if (x < 0) {
        x = -x;
        std::putchar('-');
    }
    if (x > 9) write(x / 10);
    std::putchar(static_cast<int>(x % 10) + '0');
}

// Writes an integer followed by a newline.
template <typename T> void writeln(T x) {
    write(x);
    std::putchar('\n');
}

std::vector<int> adj[MAXN];  // adjacency lists; vertices are numbered 1..n
int n, m;
// val[k][d] = sum_{i=0}^{d} i^k (mod P), using the convention 0^0 = 1.
int val[MAXK][MAXN];
// father[v][i] = 2^i-th ancestor of v; vertex 0 is a sentinel above the root.
int father[MAXN][MAXLOG];
int depth[MAXN];

// Fills depth[] and father[][] for the tree rooted at `root`.
// Iterative BFS (no recursion) so a path-shaped tree of depth ~n cannot
// overflow the call stack. BFS pops every vertex after its parent, so all
// ancestors' lifting rows are complete when a vertex's row is computed.
void build(int root) {
    std::vector<int> order;
    order.reserve(n);
    depth[0] = -1;  // sentinel depth: lifting checks never step past the root
    depth[root] = 0;
    father[root][0] = 0;
    order.push_back(root);
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int v = order[head];
        for (int i = 1; i < MAXLOG; ++i)
            father[v][i] = father[father[v][i - 1]][i - 1];
        for (int u : adj[v]) {
            if (u == father[v][0]) continue;
            father[u][0] = v;
            depth[u] = depth[v] + 1;
            order.push_back(u);
        }
    }
}

// Lowest common ancestor of x and y via binary lifting.
int lca(int x, int y) {
    if (depth[x] < depth[y]) std::swap(x, y);
    // Lift x up to y's depth; depth[0] = -1 rejects any jump above the root.
    for (int i = MAXLOG - 1; i >= 0; --i)
        if (depth[father[x][i]] >= depth[y]) x = father[x][i];
    if (x == y) return x;
    for (int i = MAXLOG - 1; i >= 0; --i)
        if (father[x][i] != father[y][i]) {
            x = father[x][i];
            y = father[y][i];
        }
    return father[x][0];
}

int main() {
    read(n);
    for (int i = 1; i < n; ++i) {
        int x, y;
        read(x), read(y);
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    // Powers: val[k][d] = d^k for 0 <= d <= n (0^0 = 1, 0^k = 0 for k >= 1).
    for (int d = 0; d <= n; ++d) {
        val[0][d] = 1;
        for (int k = 1; k < MAXK; ++k)
            val[k][d] = static_cast<int>(1LL * val[k - 1][d] * d % P);
    }
    // Prefix sums over depth for every exponent, k = 0 included, so a k = 0
    // query counts the vertices on the path.
    for (int k = 0; k < MAXK; ++k)
        for (int d = 1; d <= n; ++d)
            val[k][d] = (val[k][d] + val[k][d - 1]) % P;

    build(1);

    read(m);
    for (int q = 0; q < m; ++q) {
        int x, y, k;
        read(x), read(y), read(k);  // expects 1 <= x, y <= n and 0 <= k < MAXK
        const int z = lca(x, y);
        // Depths on the path are [depth z, depth x] together with
        // (depth z, depth y], hence S(dx) + S(dy) - S(dz) - S(dz - 1), S(-1) = 0.
        long long ans = 0LL + val[k][depth[x]] + val[k][depth[y]] - val[k][depth[z]];
        if (depth[z] > 0) ans -= val[k][depth[z] - 1];
        ans %= P;
        if (ans < 0) ans += P;
        writeln(ans);
    }
    return 0;
}