HDU 4822 Tri-war 解题报告(LCA)

Time Limit: 20000/10000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others)
Total Submission(s): 125    Accepted Submission(s): 37

Problem Description
Three countries, Red, Yellow, and Blue, are at war. The battlefield map is a tree: N nodes connected by (N – 1) edges. Each country has its base station at a node, and no two countries place their stations at the same node. Each country expands from its base station to occupy other nodes. A country A occupies a node if and only if the base stations of both other countries are strictly farther from that node than A's base station is. All edges have the same length.

Given the three countries' base stations, your task is to calculate the number of nodes each country occupies (its base station is counted).

The input starts with a single integer T (1 ≤ T ≤ 10), the number of test cases.

Each test case starts with a single integer N (3 ≤ N ≤ 10 ^ 5), the number of nodes in the tree.

Then N - 1 lines follow, each containing two integers u and v (1 ≤ u, v ≤ N, u ≠ v), which means that there is an edge between node u and node v.

Then a single integer M (1 ≤ M ≤ 10 ^ 5) follows, indicating the number of queries.

Each of the next M lines contains a query of three integers a, b, c (1 ≤ a, b, c ≤ N; a, b, c are distinct), giving the base stations of the three countries, respectively.

For each query, output three integers on a single line, separated by spaces: the number of nodes occupied by each country, in the same order as the base stations appear in the query.

Sample Input
1
9
1 2
1 3
1 4
2 5
2 6
2 7
6 8
6 9
2
1 2 8
2 1 4

Sample Output
3 3 1
6 2 1





#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#include <string>
#include <iomanip>
#include <cassert>
using namespace std;
// Asks the MSVC linker for a ~1 GB stack (presumably for the recursive DFS);
// GCC/Clang do not honour this pragma.
#pragma comment(linker, "/STACK:1024000000,1024000000")
// Loop shorthands: ff iterates i over [0, n), fff over [n, m] ascending,
// dff over [n, m] descending.
#define ff(i, n) for(int i=0;i<(n);i++)
#define fff(i, n, m) for(int i=(n);i<=(m);i++)
#define dff(i, n, m) for(int i=(n);i>=(m);i--)
// 64-bit mask with only bit n set.
#define bit(n) (1LL<<(n))
typedef long long LL;
typedef unsigned long long ULL;
// Solves every test case read from stdin; defined further below.
void work();
// Solves every test case read from stdin; defined further below (repeated here
// so this entry point does not depend on the earlier forward declaration).
void work();

// Program entry point. When compiled with -DACM, input is taken from
// "in.txt" instead of the real stdin (local testing convenience).
int main()
{
#ifdef ACM
    freopen("in.txt", "r", stdin);
#endif // ACM
    work();   // the solver was never invoked before
    return 0;
}


// Reads the next non-negative decimal integer from stdin into x.
// Leading non-digit characters are skipped; the character that ends the
// number is consumed. If EOF is reached before any digit, x is set to 0
// (previously this looped forever).
void nextInt(int & x)
{
    // int, not char: getchar() may return EOF, and isdigit() is only defined
    // for EOF and values representable as unsigned char.
    int ch;
    do {
        ch = getchar();
    } while (ch != EOF && !isdigit(ch));

    x = 0;
    // isdigit() returns "non-zero" for digits, not necessarily 1, so its
    // result must be tested for truthiness rather than compared to true.
    while (ch != EOF && isdigit(ch))
    {
        x = 10 * x + (ch - '0');
        ch = getchar();
    }
}


// Capacity limits.
const int maxv = 100000 + 10;   // vertices: N <= 1e5, plus a little slack
const int maxe = 2 * maxv;      // edge slots: every tree edge is stored in both directions
const int maxlog = 20;          // levels of the ancestor table: 2^20 > any depth < 1e5

// Size of the current test case: number of vertices and number of queries.
int n = 0;
int m = 0;

// Adjacency lists kept as singly linked lists of edge ids.
// first[u]: head edge id of u's list (0 terminates a list); ecnt: next free edge id.
int first[maxv];
int ecnt;
// For edge id e: vv[e] is the target vertex, nxt[e] the next edge id in the same list.
int vv[maxe];
int nxt[maxe];

void init()
    memset(first, 0, sizeof(first));
    ecnt = 2;

void addEdge(int u, int v)
    nxt[ecnt] = first[u], vv[ecnt] = v, first[u] = ecnt ++;

// fa[k][u]: the 2^k-th ancestor of u, or -1 when it lies above the root.
int fa[maxlog][maxv];
// dep[u]: depth of u, the root having depth 0.
int dep[maxv];
// size[u]: number of vertices in the subtree rooted at u.
// NOTE: with `using namespace std` under C++17 an unqualified `size` is
// ambiguous with std::size, so code should refer to this array as ::size.
int size[maxv];

void dfs(int u, int f, int d)
    fa[0][u] = f, size[u] = 1, dep[u] = d;

    for(int e = first[u]; e; e = nxt[e]) if(vv[e] != f)
        dfs(vv[e], u, d + 1);
        size[u] += size[vv[e]];

void initFa()
    dfs(1, -1, 0);
    ff(k, maxlog-1) fff(u, 1, n) if(fa[k][u] == -1)
        fa[k+1][u] = -1;
        fa[k+1][u] = fa[k][fa[k][u]];

int upSlope(int u, int p)
    assert(p <= dep[u]);

    ff(k, maxlog) if(p & bit(k))
        u = fa[k][u];
    return u;

int lca(int u, int v)
    if (dep[u] < dep[v]) swap(u, v);
    u = upSlope(u, dep[u] - dep[v]);
    if (u == v) return u;
    dff(k, maxlog-1, 0) if (fa[k][u] != fa[k][v])
        u = fa[k][u], v = fa[k][v];
    return fa[0][u];

// Cut point returned by getMiddle for a pair of stations (first, second):
//   type 1: the subtree rooted at r is exactly the set of vertices strictly
//           closer to the first station;
//   type 2: the subtree rooted at r is exactly the set of vertices NOT
//           strictly closer to the first station.
struct Node
{
    int type, r;
    Node(int type, int r) : type(type), r(r) {}
};

// Describes the vertices strictly closer to a than to b (ab = lca(a, b))
// as a single subtree cut; see struct Node for how `type` is interpreted.
Node getMiddle(int a, int b, int ab)
{
    const int d = dep[a] + dep[b] - 2 * dep[ab];   // dist(a, b), >= 1 since a != b

    if (dep[a] >= dep[b])
    {
        // The cut lies on the a -> ab segment. (d-1)/2 levels above a is the
        // farthest vertex still strictly closer to a; its subtree is a's side.
        return Node(1, upSlope(a, (d - 1) / 2));
    }
    // b is deeper, so the cut lies on the b -> ab segment. d/2 levels above b
    // is the vertex nearest to a that is not strictly closer to a; everything
    // outside its subtree is a's side.
    return Node(2, upSlope(b, d / 2));
}

int calc(int a, int b, int c, int ab, int ac)
    Node bn = getMiddle(a, b, ab);
    Node cn = getMiddle(a, c, ac);

    if (bn.type + cn.type == 2)
        return min(size[bn.r], size[cn.r]);
    else if (bn.type + cn.type == 4)
        if(dep[bn.r] > dep[cn.r]) swap(bn, cn);
        if (lca(bn.r, cn.r) == bn.r)
            return n - size[bn.r];
            return n - size[bn.r] - size[cn.r];
        if (bn.type == 2) swap(bn, cn);
        if (lca(bn.r, cn.r) == bn.r)
            return size[bn.r] - size[cn.r];
            return size[bn.r];

void work()
    int T;
    scanf("%d", &T);
    fff(cas, 1, T)

        scanf("%d", &n);
        ff(i, n-1)
            int u, v;
            scanf("%d%d", &u, &v);
            addEdge(u, v);
            addEdge(v, u);


        scanf("%d", &m);
            int a, b, c;
            scanf("%d%d%d", &a, &b, &c);

            int ab = lca(a, b);
            int ac = lca(a, c);
            int bc = lca(b, c);

            printf("%d %d %d\n", calc(a, b, c, ab, ac), calc(b, a, c, ab, bc), calc(c, a, b, ac, bc));





