目录
预处理:
// Preprocessing for binary-lifting LCA. This is only the first part of
// solve(); the function body continues past this snippet.
void solve()
{
int n, k;
cin >> n >> k; // k is read here but not used in the lines shown
// alist: 1-indexed adjacency list of the tree (n - 1 undirected edges).
vector<vector<int>>alist(n + 1);
for (int i = 1; i <= n - 1; i++)
{
int x, y;
cin >> x >> y;
alist[x].push_back(y);
alist[y].push_back(x);
}
// fa[v][p]: the 2^p-th ancestor of v for p in [0, 22); 0 means "above the root".
vector<vector<int>>fa(n + 1, vector<int>(22));
// dep[v]: depth of v. Vertex 1 is the root; dep[0] stays 0, so the root gets depth 1.
vector<int>dep(n + 1);
// leaf: vertices with exactly one neighbor. Note that this also collects the
// root (vertex 1) when its degree is 1, and collects nothing when n == 1.
vector<int>leaf;
// Recursive DFS; the lambda receives itself as the third argument so it can recurse.
// NOTE(review): recursion depth equals the tree height, so a long chain may overflow the stack.
auto dfs = [&](int cur, int pa, auto dfs)->void {
dep[cur] = dep[pa] + 1;
if (alist[cur].size() == 1)
{
leaf.push_back(cur);
}
for (auto x : alist[cur])
{
if (x != pa)
{
fa[x][0] = cur; // record the direct parent of each child
dfs(x, cur, dfs);
}
}
};
dfs(1, 0, dfs);
// Binary lifting: the 2^p-th ancestor is the 2^(p-1)-th ancestor of the 2^(p-1)-th ancestor.
// fa[0][*] is never written and stays 0, so jumps past the root stay at 0.
for (int p = 1; p < 22; p++)
{
for (int i = 1; i <= n; i++)
{
fa[i][p] = fa[fa[i][p - 1]][p - 1];
}
}
LCA:
// Binary-lifting LCA that also measures the path length between a and b.
// Uses fa (fa[v][i] = 2^i-th ancestor of v, 0 past the root) and dep
// (root depth 1, dep[0] == 0), captured from the enclosing scope.
// Returns {lca(a, b), number of edges on the path a..b}.
// Jumps are computed with integer bit operations instead of floating-point log2/pow.
auto LCA = [&](int a, int b)->pair<int, int>
{
    int ans = 0;  // a tree path has fewer edges than vertices, so int suffices (matches the return type)
    if (dep[a] < dep[b]) swap(a, b);
    // Lift a to b's depth: one 2^bit jump for every set bit of the depth gap.
    int diff = dep[a] - dep[b];
    ans += diff;
    int bit = 0;
    while (diff > 0)
    {
        if (diff & 1) a = fa[a][bit];
        diff >>= 1;
        ++bit;
    }
    // Highest level worth trying: floor(log2(dep[a])), computed in integers.
    int top = 0;
    while ((2 << top) <= dep[a]) ++top;
    for (int i = top; i >= 0; i--)
    {
        if (fa[a][i] != fa[b][i])  // jump only while the two ancestors still differ
        {
            a = fa[a][i];
            b = fa[b][i];
            ans += 2 << i;  // both endpoints climbed 2^i edges
        }
    }
    // If a != b, they are now distinct children of the LCA: one more step each.
    if (a != b)
    {
        a = b = fa[a][0];
        ans += 2;
    }
    return { a, ans };
};
第一步（让 a、b 深度相同）由 while 循环完成：把深度差 dep[a] - dep[b] 按二进制拆分，逐段向上跳。
主函数外版本（写成全局函数，使用全局数组 fa、dep）：
// Lowest common ancestor of a and b by binary lifting (free-function version).
// Relies on global arrays filled beforehand: dep[v] (root depth 1, dep[0] == 0)
// and fa[v][i] (the 2^i-th ancestor of v, 0 once past the root).
// Levels are computed with integer bit operations instead of floating-point log2.
inline int LCA(int a,int b)
{
    if (dep[a] < dep[b]) swap(a, b);
    // Lift a to b's depth: one 2^bit jump for every set bit of the depth gap.
    int diff = dep[a] - dep[b];
    int bit = 0;
    while (diff > 0)
    {
        if (diff & 1) a = fa[a][bit];
        diff >>= 1;
        ++bit;
    }
    if (a == b) return a;  // b was an ancestor of a
    // Highest level worth trying: floor(log2(dep[a])), computed in integers.
    int top = 0;
    while ((2 << top) <= dep[a]) ++top;
    for (int i = top; i >= 0; i--)
    {
        if (fa[a][i] != fa[b][i])  // jump only while the two ancestors still differ
        {
            a = fa[a][i];
            b = fa[b][i];
        }
    }
    // a and b are now distinct children of the LCA.
    return fa[a][0];
}
使用示范（链式前向星存图等；参考 CSDN 博客《POJ 3417 Network 树上差分 LCA 链式前向星 仍超时，加上快读过》）：
// Chained forward star edge storage.
// Edges are numbered from 1; index 0 terminates a vertex's edge list
// (the traversal loops stop at 0), and global head[] starts zero-initialized.
struct node
{
    int b;     // endpoint this edge points to
    int next;  // index of the next edge leaving the same source vertex (0 = none)
};
node edges[maxn * 2];  // two slots per tree edge, presumably one per direction
int head[maxn];        // head[v]: index of the most recently added edge out of v
int k = 0;             // number of edges stored so far (last used index)
inline void add(int a, int b)
{
k++;
edges[k].b = b;
edges[k].next = head[a];
head[a] = k;
}
int fa[maxn][23];  // fa[v][p]: the 2^p-th ancestor of v (0 once past the root); dfs1 fills p < 22
int dep[maxn];     // dep[v]: depth of v; the root gets 1 when entered with pa = 0 (dep[0] == 0)

// Preorder DFS from cur (whose parent is pa): records the depth and fills
// cur's binary-lifting row. Ancestors are processed before descendants, so the
// row of the intermediate ancestor is already complete when it is read.
void dfs1(int cur ,int pa)
{
    dep[cur] = dep[pa] + 1;
    fa[cur][0] = pa;
    for (int p = 1; p < 22; ++p)
    {
        const int half = fa[cur][p - 1];  // the 2^(p-1)-th ancestor
        fa[cur][p] = fa[half][p - 1];
    }
    for (int e = head[cur]; e > 0; e = edges[e].next)
    {
        const int to = edges[e].b;
        if (to == pa) continue;  // do not walk back to the parent
        dfs1(to, cur);
    }
}
// Lowest common ancestor of a and b by binary lifting.
// Relies on the global arrays filled by dfs1: dep[v] (root depth 1, dep[0] == 0)
// and fa[v][i] (the 2^i-th ancestor of v, 0 once past the root).
// Levels are computed with integer bit operations instead of floating-point log2.
inline int LCA(int a,int b)
{
    if (dep[a] < dep[b]) swap(a, b);
    // Lift a to b's depth: one 2^bit jump for every set bit of the depth gap.
    int diff = dep[a] - dep[b];
    int bit = 0;
    while (diff > 0)
    {
        if (diff & 1) a = fa[a][bit];
        diff >>= 1;
        ++bit;
    }
    if (a == b) return a;  // b was an ancestor of a
    // Highest level worth trying: floor(log2(dep[a])), computed in integers.
    int top = 0;
    while ((2 << top) <= dep[a]) ++top;
    for (int i = top; i >= 0; i--)
    {
        if (fa[a][i] != fa[b][i])  // jump only while the two ancestors still differ
        {
            a = fa[a][i];
            b = fa[b][i];
        }
    }
    // a and b are now distinct children of the LCA.
    return fa[a][0];
}
int diff[maxn];  // per-vertex difference marks (tree difference technique)
int pass[maxn];  // after dfs2 (with pass[] starting at zero): sum of diff over v's subtree
int cnt0, cnt1;  // not used in this snippet; presumably counters for the final answer

// Post-order accumulation: adds cur's own mark, then each child's finished
// subtree total, turning the point marks in diff[] into subtree sums in pass[].
void dfs2(int cur,int pa)
{
    pass[cur] += diff[cur];
    for (int e = head[cur]; e > 0; e = edges[e].next)
    {
        const int child = edges[e].b;
        if (child == pa) continue;  // skip the edge back to the parent
        dfs2(child, cur);
        pass[cur] += pass[child];
    }
}