题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4601
题意很简单就是给出一个类似trie的树形结构, 每个节点表示一个单词, 要求查询从某个点u向下走m步可以获得的字典序最大的字符串的hash值。做法就是建出一棵与原树对应的trie树, 对其dfs获得原树每个节点对应单词的rank值, 考虑到查询的特殊性然后对原树进行层次遍历, 使得同一层的点在一段连续的区间, 每次查询深度为dep[u] + m的区间的最值, 但要先确定节点u对应的那个子区间, 因为同一深度的连续区间内节点的dfs时间戳满足单调性, 所以可以二分得到区间的左右端点。
注意:1.数据中有m = 0的情况, 而且m = 0时应输出0。
2.dfs参数多的话可能会爆栈, 需要扩栈。
代码很丑。。。
#include <iostream>
#include <queue>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <vector>
#include <cmath>
using namespace std;
const int N = 100005;
const int M = N << 1;
const int mod = 1000000007;
typedef long long LL;
int fpow(int a, int p) {
LL res = 1;
LL t = a;
while (p) {
if (p & 1) {
res *= t;
res %= mod;
}
p >>= 1;
t *= t;
t %= mod;
}
return res;
}
int Map[256];
int Log[N];
void init() {
for (int i = 0; i < 26; i++)
Map[i + 'a'] = i;
Log[0] = -1;
for (int i = 1; i < N; i++) {
Log[i] = Log[i >> 1] + 1;
}
}
struct ST {
int d[N][17];
int pos[N][17];
int n;
void init(int n, int *A) {
this->n = n;
for (int i = 1; i <= n; i++) {
d[i][0] = A[i];
pos[i][0] = i;
}
for (int j = 1; j <= Log[n]; j++)
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
//d[i][j] = min(d[i][j - 1], d[i + (1 << j - 1)][j - 1]);
if (d[i][j - 1] > d[i + (1 << j - 1)][j - 1]) {
d[i][j] = d[i][j - 1];
pos[i][j] = pos[i][j - 1];
}
else {
d[i][j] = d[i + (1 << j - 1)][j - 1];
pos[i][j] = pos[i + (1 << j - 1)][j - 1];
}
}
}
int query(int st, int ed) {
int m = Log[ed - st + 1];
//return min(d[st][m], d[ed - (1 << m) + 1][m]);
if (d[st][m] > d[ed - (1 << m) + 1][m]) {
return pos[st][m];
}
else {
return pos[ed - (1 << m) + 1][m];
}
}
}rmq;
struct Tree {
int head[N], to[M], next[M];
char chr[M];
int ch[N][26];
LL hash[N];
int rank[N], dfn[N];
int idx[N];
int list[N];
int pos[N];
int key[N];
int last[N], dep[N];
int lev[N];
int levst[N];
queue<int> Q;
int p;
int n, tot, r, tdfn;
int sz;
void init(int n) {
this->n = n;
for (int i = 1; i <= n; i++) {
head[i] = -1;
}
tot = 0;
sz = 1;
memset(ch[0], 0, sizeof(ch[0]));
tdfn = 1;
}
void add(int u, int v, char c) {
chr[tot] = c, to[tot] = v, next[tot] = head[u], head[u] = tot++;
chr[tot] = c, to[tot] = u, next[tot] = head[v], head[v] = tot++;
}
void dfs(int u, int fa, LL h, int u2, int d) {
hash[u] = h;
list[u] = u2;
dfn[u] = tdfn++;
int son = 0;
dep[u] = -1;
for (int i = head[u]; i != -1; i = next[i]) {
int v = to[i];
if (v == fa) continue;
son++;
int t = Map[chr[i]];
if (!ch[u2][t]) {
memset(ch[sz], 0, sizeof(ch[sz]));
ch[u2][t] = sz;
sz++;
}
dfs(v, u, (h * 26 + (chr[i] - 'a')) % mod, ch[u2][t], d + 1);
dep[u] = max(dep[u], dep[v]);
}
last[u] = tdfn - 1;
if (son == 0) {
dep[u] = d;
}
}
void dfs(int u) {
rank[u] = r;
r++;
for (int i = 0; i < 26; i++) {
if (ch[u][i]) {
dfs(ch[u][i]);
}
}
}
void bfs() {
Q.push(1);
for (int i = 1; i <= n; i++) {
pos[i] = -1;
levst[i] = -1;
}
p = 1;
pos[1] = p++;
lev[1] = 0;
levst[0] = 1;
while (!Q.empty()) {
int cur = Q.front();
Q.pop();
if (levst[lev[cur]] == -1) {
levst[lev[cur]] = pos[cur];
}
for (int i = head[cur]; i != -1; i = next[i]) {
int v = to[i];
if (pos[v] == -1) {
lev[v] = lev[cur] + 1;
Q.push(v);
pos[v] = p++;
}
}
}
for (int i = 0; i <= n; i++) {
if (levst[i] == -1) {
levst[i] = p;
break;
}
}
}
void gao() {
dfs(1, -1, 0, 0, 0);
r = 0;
dfs(0);
bfs();
for (int i = 1; i <= n; i++) {
key[pos[i]] = rank[list[i]];
idx[pos[i]] = i;
//cout << i << ' ' << pos[i] << ' ' << key[pos[i]] << endl;
}
rmq.init(n, key);
}
int query(int u, int m) {
if (m == 0) return 0;
int t = lev[u] + m;
if (dep[u] < t) return -1;
int L = levst[t], R = levst[t + 1] - 1;
int st = L, ed;
while (L <= R) {
int mid = L + R >> 1;
if (dfn[idx[mid]] <= last[u])
L = mid + 1;
else
R = mid - 1;
}
ed = R;
L = st;
while (L <= R) {
int mid = L + R >> 1;
if (dfn[idx[mid]] >= dfn[u])
R = mid - 1;
else
L = mid + 1;
}
st = L;
int p = rmq.query(st, ed);
LL res = hash[idx[p]] - ((hash[u] * fpow(26, m)) % mod);
res %= mod;
if (res < 0)
res += mod;
return res;
}
}T;
int main() {
int size = 256 << 20; // 256MB
char *p = (char*)malloc(size) + size;
__asm__("movl %0, %%esp\n" :: "r"(p) );
int test, n, u, m, q, v;
char s[3];
scanf("%d", &test);
init();
while (test--) {
scanf("%d", &n);
T.init(n);
for (int i = 0; i < n- 1; i++) {
scanf("%d%d%s", &u, &v, s);
T.add(u, v, s[0]);
}
T.gao();
scanf("%d", &q);
while (q--) {
scanf("%d%d", &u, &m);
int tmp = T.query(u, m);
if (tmp == -1)
puts("IMPOSSIBLE");
else {
printf("%d\n", tmp);
}
}
}
return 0;
}