题目传送门:Network
题目大意:
给你一棵树, n - 1条主要边, m条附加边, 你的任务是切断这棵树, 即将这棵树分成两棵树. 因为结构是树, 所以无论你切割哪条主要边, 都会将这个数切断, 现在的问题是还需要在切割主要边的基础上再切割一条附加边(附加边只在切割主要边后起连接作用), 若仍然可以切断这棵树, 就认为你完成了任务.
方案1: 切割一条主要边, 若该主要边不被任何附加边覆盖, 那么下一步切割任意一条附加边都算完成任务.
方案2: 切割一条主要边, 若该条主要边仅被一条附加边覆盖, 那么只需切割该附加边即可完成任务.(若存在被两条及以上的附加边覆盖, 那无论下一步切割哪一条,都无力回天了, 只能gg ).
LCA的作用: 可以计算任意两节点路径上的权值和, 因为是树形结构, 所以任意两节点的路径是固定的, dis(x, y) = dis[x] + dis[y] - 2 * dis[LCA(x, y)], dis[]数组记录节点到根节点路径的权值和.
差分的作用: 差分前缀和是原序列, 在该题中我们可以用差分数组标记附加边覆盖主要边的信息, 具体操作为 cnt[x] + 1, cnt[y] + 1, cnt[LCA(x, y)] - 2, x到LCA(x, y)路径固定, y到LCA(x, y)的路径固定, 就相当于在两条序列上做差分运算.
两个最重要的框架已经就位了, 就只剩下如何组合了
首先用bfs求出树上每个节点的老子 爸爸以及每个节点的深度, 然后dfs求出附加边的树上差分前缀和, 之后的差分数组就是以该节点为根节点的子树的权值和, 即该节点到父节点被附加边覆盖的次数.因为题目中没有规定根节点, 所以可以任选一个节点当根节点.
代码如下 :
#include<iostream>
#include<algorithm>
#include<string>
#include<vector>
#include<queue>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> P;
const int MAX_N = 2e5 + 5;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + 7;
//
//
int n, m, t, tot;
int cnt[MAX_N];
int f[MAX_N][30];
int head[MAX_N], d[MAX_N];
bool used[MAX_N];
struct node {int to, next;} G[MAX_N];
inline void add_edge(int u, int v) {
++ tot;
G[tot].to = v;
G[tot].next = head[u];
head[u] = tot;
}
void bfs() {
memset(d, 0, sizeof(d));
queue<int> q;
q.push(1);
d[1] = 1;
while(!q.empty()) {
int p = q.front();
q.pop();
for(int i = head[p]; i != -1; i = G[i].next) {
int u = G[i].to;
if(d[u]) continue;//防止重复计算
d[u] = d[p] + 1;
f[u][0] = p;
for(int j = 1; j <= t; ++ j) {
//f[u][j]代表u向根节点的方向行走2^j距离所达到的节点, 俗称u的2^j级父亲
//lca的核心代码, u的2^j级父亲, 等于u的2^(j-1)级父亲的2^(j-1)父亲
f[u][j] = f[f[u][j - 1]][j - 1];
}
q.push(u);
}
}
}
void dfs(int x) {
used[x] = true;
for(int i = head[x]; i != -1; i = G[i].next) {
int u = G[i].to;
if(used[u]) continue;
dfs(u);
cnt[x] += cnt[u];
}
}
int lca(int x, int y) {
if(d[x] < d[y]) swap(x, y);
//使x和y达到同一深度
for(int i = t; i >= 0; -- i) {
if(d[f[x][i]] >= d[y]) x = f[x][i];
}
if(x == y) return x;
for(int i = t; i >= 0; -- i) {
if(f[x][i] != f[y][i]) {
x = f[x][i];
y = f[y][i];
}
}
return f[x][0];
}
void solve() {
scanf("%d %d", &n, &m);
memset(head, -1, sizeof(head));
t = (int)(log(n) / log(2)) + 1;// lca采用的是倍增思想
for(int i = 1; i < n; ++ i) {
int u, v;
scanf("%d %d", &u, &v);
add_edge(u, v);
add_edge(v, u);
}
bfs();
for(int i = 1; i <= m; ++ i) {
int u, v;
scanf("%d %d", &u, &v);
++ cnt[u];
++ cnt[v];
cnt[lca(u, v)] -= 2;
}
dfs(1);
int ans = 0;
for(int i = 2; i <= n; ++ i) {
if(cnt[i] == 0) ans += m;
if(cnt[i] == 1) ++ ans;
}
printf("%d\n", ans);
}
signed main() {
int test = 1;
//scanf("%d", &test);
while(test -- ) {
solve();
}
return 0;
}