学习蓝皮书看到的一个题
题目链接
我们不妨把每个附加边当作是实边的一次“覆盖”。 每个附加边将会覆盖掉从从附加边两个端点开始至两端点的最近公共祖先(LCA)所经过的所有边,如图:
“覆盖”的意义是:
如切开被覆盖的边(且该边只被覆盖一次),能且只能切开把它覆盖的那条附加边,从而将图切成两部分。
而且
如果某实边的覆盖次数为0, 则切开它后,再切开任意附加边就可以将图切成两部分。
我们设数组 dif[N] 记录差分,对于每个端点为 x, y 的附加边,有
dif[x]++;
dif[y]++;
dif[lca(x, y)] -= 2;
再进行一次dfs得到每个边的覆盖值 ans[N] (为了方便,把每个边的覆盖值看做是词边连接的子节点的权值,此时注意排除树根的影响,不妨设1节点为树根)。
下面是ac代码:
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <vector>
#include <queue>
#include <cstdio>
#define ll long long
using namespace std;
const int N = 1e5+5;
int f[N][20], d[N], ans[N], dif[N];
int ne[N*2], he[N], ver[N*2];
int n, m;
int cnt;
queue<int> q;
int t, tot;
void add(int x, int y)
{
ver[++tot] = y;
ne[tot] = he[x];
he[x] = tot;
}
void bfs()
{
d[1] = 1;
q.push(1);
while(q.size())
{
int te = q.front();
q.pop();
for (int i = he[te]; i; i = ne[i])
{
int v = ver[i];
if (d[v]) continue;
d[v] = d[te] + 1;
f[v][0] = te;
for (int j = 1; j <= t; j++)
f[v][j] = f[f[v][j-1]][j-1];
q.push(v);
}
}
}
int lca(int x, int y)
{
if (d[x] > d[y]) swap(x, y);
for (int i = t; i >= 0; i--)
{
if (d[f[y][i]] < d[x]) continue;
y = f[y][i];
}
if (x == y) return x;
for (int i = t; i >= 0; i--)
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
int sum = 0;
void dfs(int cur)
{
int sum = 0;
for (int i = he[cur]; i; i = ne[i])
{
int v = ver[i];
if (d[v] < d[cur]) continue;
dfs(v);
sum += ans[v];
}
sum += dif[cur];
ans[cur] = sum;
if (cur != 1 && ans[cur] == 1) cnt++;
if (cur != 1 && ans[cur] == 0) cnt += m;
}
int main()
{
while(cin >> n >> m)
{
memset(d, 0, sizeof(d));
memset(f, 0, sizeof(f));
t = (log(n) / log(2)) + 1;
memset(ne, 0, sizeof(ne));
memset(dif, 0, sizeof(dif));
memset(he, 0 , sizeof(he));
sum = 0;
cnt = 0;
tot = 0;
while(q.size()) q.pop();
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
add(x, y);
add(y, x);
}
bfs();
for (int i = 0; i < m; i++)
{
int x, y;
scanf("%d%d", &x, &y);
dif[x]++, dif[y]++;
dif[lca(x, y)] -= 2;
}
/* for (int i = 1; i <= n; i++)
{
cout << i << ":" << dif[i] << endl;
}
cout <<endl;*/
dfs(1);
/*for (int i = 1; i <= n; i++)
{
cout << i << ":" << ans[i] << endl;
}*/
printf("%d\n", cnt);
}
return 0;
}
/*
input
9 2
1 2
1 3
1 4
2 5
2 6
4 7
4 8
7 9
6 7
8 9
output
9
*/