题目:http://poj.org/problem?id=3764
题目没说清path是不是只算从root到leaf的path,如果的是的话,直接建树dfs即的结果,看了discuss知道不是这么简单,还是需要考虑任意两点间的xorpath,由于xor两次即抵消的性质,xorpath(a, b) = xorpath(root, a) ^ xorpath(root, b),但是直接枚举O(N*N)肯定TLE,搜了题解才知道这题竟然可以用Trie来做:
(1)设y = xorpath(root, a),将y按照如下方式插入Trie,若该bit等于0,则插入左孩子,若该bit等于1,则插入有孩子,这样将所有的xorpath(root, *)插入Trie之后(高度显然是31);
(2)对于以a为一端的xorpath,我们只需在Trie中找能让y = xorpath(root, a)异或之后最大的路径,即若y的该bit等于0,我们试图继续在右孩子(表示存在该bit等于1的path)中继续找,若y的该bit等于1,我们试图继续在左孩子(表示存在该bit等于0的path)中继续找;
这样建树O(N),插入一共O(31N),查询一共O(31N),整体一共O(63N)
#include <cstdio>
#include <cstring>
const int MAX_NODE = 100005;
const int MAX_EDGE = MAX_NODE - 1;
inline int max(int a, int b){ return a > b ? a : b; }
//binary trie
struct TrieNode{
TrieNode* ch[2];
TrieNode(){
ch[0] = ch[1] = NULL;
}
} node[MAX_NODE * 31];
int nex;
TrieNode* newTrieNode()
{
node[nex].ch[0] = node[nex].ch[1] = NULL;
return node + nex++;
}
//xor path from root to each node
int N, xorPath[MAX_NODE];
TrieNode* root;
//tree edge
struct Edge
{
int to, weight, next;
} edge[MAX_EDGE * 2];
int cnt, pre[MAX_NODE];
inline void addEdge(int x, int y, int w)
{
edge[cnt].to = y;
edge[cnt].weight = w;
edge[cnt].next = pre[x];
pre[x] = cnt++;
}
/***************************** sub functions ****************************/
void build()
{
cnt = 0;
memset(pre, -1, N << 2);
int i = 1, x, y, w;
for(; i < N; ++i){
scanf("%d%d%d", &x, &y, &w);
addEdge(x, y, w);
addEdge(y, x, w);
}
}
void dfs(int x, int fa, int xorSum)
{
xorPath[x] = xorSum;
for(int i = pre[x]; i != -1; i = edge[i].next){
int y = edge[i].to;
if(y == fa) continue;
dfs(y, x, xorSum ^ edge[i].weight);
}
}
void insertTrie(int v)
{
TrieNode* p = root;
for(int i = 30; i >= 0; --i){
int j = !!(1 << i & v);
if(NULL == p->ch[j]) p->ch[j] = newTrieNode();
p = p->ch[j];
}
}
int findTrie(int v)
{
int res = 0;
TrieNode* p = root;
for(int i = 30; i >= 0; --i){
int j = !(1 << i & v);
if(NULL != p->ch[j]){
res |= 1 << i;
p = p->ch[j];
}
else p = p->ch[!j];
}
return res;
}
int main()
{
int i, res;
while(~scanf("%d", &N)){
//build tree
build();
//find each node's xor sum from root
dfs(0, -1, 0);
//build trie
nex = 0;
root = newTrieNode();
for(i = 0; i < N; ++i) insertTrie(xorPath[i]);
//find max xor path between any two nodes
res = 0;
for(i = 0; i < N; ++i) res = max(res, findTrie(xorPath[i]));
printf("%d\n", res);
}
return 0;
}