题意:给定一颗n个节点的树,然后是m个询问,每个询问是两个集合,问分别在这两个集合中取一个点,这两个点的LCA 的深度最大为多少。
分析:首先想到的肯定是for两个集合,O(k^2),肯定要爆炸。我们对这棵树进行树链剖分,剖完之后的树最多有logn条链,那么,我们将这两个集合中的点,根据他们所在的链分类,如果两个点位于同一条链,我们取深度更深的那个点,这样处理之后,每个集合中最多只有logn个点,再去for,时间看起来就比较合理了,还有,遍历过程中需要优化以下,就是当for到两个点,其中有一个的深度小于等于当前ans就跳过,(刚开始因为没加这个T了。。) 复杂度大概是O((logn)^3),可能还要乘个大常数。。
以下是代码:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<queue>
#include<map>
#include<set>
#include<stack>
#include<cstring>
#include<string>
#include<vector>
#include<iomanip>
//#include<unordered_set>
//#include<unordered_map>
#include<cmath>
#include<list>
#include<bitset>
using namespace std;
#define _____ ios::sync_with_stdio(false); cin.tie(0);
#define ull unsigned long long
#define ll long long
#define lson l,mid,id<<1
#define rson mid+1,r,id<<1|1
typedef pair<int, int>pii;
typedef pair<ll, ll>pll;
typedef pair<double, double>pdd;
const double eps = 1e-6;
const int MAXN = 100005;
const int MAXM = 5005;
const ll LINF = 0x3f3f3f3f3f3f3f3f;
const int INF = 0x3f3f3f3f;
const double FINF = 1e18;
const ll MOD = 1000000007;
int dep[MAXN];
int siz[MAXN];
int fa[MAXN];
int top[MAXN];
int son[MAXN];
int tot, n, m, x, y, cnt, tmp;
pii cntA[MAXN], cntB[MAXN];
bool visA[MAXN], visB[MAXN];
vector<int>vec[MAXN];
void dfs1(int u, int pre, int deep)
{
siz[u] = 1;
fa[u] = pre;
dep[u] = deep;
son[u] = 0;
for (int i = 0; i < vec[u].size(); ++i)
{
int v = vec[u][i];
if (v == pre)continue;
dfs1(v, u, deep + 1);
siz[u] += siz[v];
if (siz[v] > siz[son[u]] || son[u] == 0)son[u] = v;
}
}
void dfs2(int u, int tp,int pre)
{
top[u] = tp;
if (son[u])dfs2(son[u], tp, u);
for (int i = 0; i < vec[u].size(); ++i)
{
if (vec[u][i] == pre || vec[u][i] == son[u])continue;
dfs2(vec[u][i], vec[u][i], u);
}
}
int findfa(int u, int v)
{
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]])swap(u, v);
u = fa[top[u]];
}
return min(dep[u], dep[v]);
}
int main()
{
while (scanf("%d%d", &n, &m) != EOF)
{
memset(visA, 0, sizeof(visA));
memset(visB, 0, sizeof(visB));
for (int i = 1; i <= n; ++i)vec[i].clear();
for (int i = 0; i < n - 1; ++i)scanf("%d%d", &x, &y), vec[x].push_back(y), vec[y].push_back(x);
dfs1(1, 1, 1);
dfs2(1, 1, 1);
while (m--)
{
vector<int>pointA, pointB;
scanf("%d", &cnt);
for (int i = 0; i < cnt; ++i)
{
scanf("%d", &tmp);
if (visA[top[tmp]])
{
if (cntA[top[tmp]].first < dep[tmp])
{
cntA[top[tmp]].first = dep[tmp];
cntA[top[tmp]].second = tmp;
}
}
else
{
cntA[top[tmp]].first = dep[tmp];
cntA[top[tmp]].second = tmp;
pointA.push_back(top[tmp]);
visA[top[tmp]] = 1;
}
}
scanf("%d", &cnt);
for (int i = 0; i < cnt; ++i)
{
scanf("%d", &tmp);
if (visB[top[tmp]])
{
if (cntB[top[tmp]].first < dep[tmp])
{
cntB[top[tmp]].first = dep[tmp];
cntB[top[tmp]].second = tmp;
}
}
else
{
cntB[top[tmp]].first = dep[tmp];
cntB[top[tmp]].second = tmp;
pointB.push_back(top[tmp]);
visB[top[tmp]] = 1;
}
}
int ans = 0;
for (int i = 0; i < pointA.size(); ++i)
{
for (int j = 0; j < pointB.size(); ++j)
{
int u = cntA[pointA[i]].second, v = cntB[pointB[j]].second;
//cout << pointA[i] << " " << pointB[j] << " | " << u << " " << v << endl;
if (dep[u] <= ans || dep[v] <= ans)continue;
ans = max(ans, findfa(u, v));
}
}
for (int i = 0; i < pointA.size(); ++i)visA[pointA[i]] = 0;
for (int i = 0; i < pointB.size(); ++i)visB[pointB[i]] = 0;
printf("%d\n", ans);
}
}
}