题目:https://vjudge.net/problem/POJ-3417#author=0
题目大意:在一棵n个点的树上由n-1条强边连接,现在新添m条弱边,需要切掉一条强边以及一条弱边,使得树变为完全无联系的两部分,问一共有多少种不同的切法
对于每一条新加入的弱边,都会跟原来的强边形成环
(绿色为弱边,黑色为强边)
假如需要切除一条强边,需要知道这条强边一共在几个环中,因为一个环代表加入的一条弱边,假如一条强边同时位于2个及以上的环中,那么中,那么切除这一条强边+一条弱边就无法将图分为两个部分
比如ab这条边,无论怎么切都不行
假如那条强边只在一个环中,显然切除该强边+形成环的弱边就可以了,答案+1
假如那条强边不在任何一个环中,那么切掉那条强边本身就可以使得树分为两部分,所以剩下那条弱边就随便切,答案+m(弱边数量)
因此只需要知道一条边在几个环中,累计所有边,就可以算出答案,需要一个算法可以维护强边在几个环中。
使用树上差分可以协助解决此问题。
类似于一维差分,每加入一条弱边就让弱边的两端的点+1(拿个数组存着,比如拿v数组)
然后两端的点的最近公共组先-2
比如:(红色则是该点的v数组)
对于任何一条边,其值等于其子树所有点的权值相加,则对于上面那张图:
ans=1+1+2+1=5
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#define ll long long
#define ull unsigned long long
using namespace std;
const int INF = 0x3f3f3f3f;
const int maxn = 1e5 + 7;
int Head[maxn], Nxt[maxn << 2], To[maxn << 2];
int v[maxn], ans, dp[maxn][30], mx_len;
int n, m, tot;
void add_edge(int fro, int to) {
Nxt[++tot] = Head[fro];
Head[fro] = tot;
To[tot] = to;
}
queue<int> Q;
int dep[maxn];
void bfs() {//bfs预处理,倍增求lca
Q.push(1);
dep[1] = 1;
while (!Q.empty()) {
int now = Q.front();
Q.pop();
for (int i = Head[now]; i; i = Nxt[i]) {
int &to = To[i];
if (dep[to]) continue;
dep[to] = dep[now] + 1;
dp[to][0] = now;
for (int j = 1; j <= mx_len; j++)
dp[to][j] = dp[dp[to][j - 1]][j - 1];
Q.push(to);
}
}
}
int lca(int x, int y) {
if (dep[x] > dep[y]) swap(x, y);
for (int i = mx_len; i>=0 ; i--) {
if (dep[dp[y][i]] >= dep[x]) y = dp[y][i];
}
if (x == y) return x;
for (int i = mx_len; i >= 0; i--) {
if (dp[x][i] != dp[y][i]) {
x = dp[x][i];
y = dp[y][i];
}
}
return dp[x][0];
}
int dfs(int s,int f) {
int sum = 0;
for (int i = Head[s]; i; i = Nxt[i]) {
int &to = To[i];
if (to == f) continue;
int res = dfs(to, s);
if (res == 1) ans++;
else if (res == 0) ans += m;
sum += res;
}
return sum + v[s];
}
int main() {
cin >> n >> m;
mx_len = (int)(log(n) / log(2)) + 1; //倍增求lca
for (int i = 1; i < n; i++) {
int fro, to;
scanf("%d %d", &fro, &to);
add_edge(fro, to);
add_edge(to, fro);
}
bfs();
for (int i = 1; i <= m; i++) {
int x, y;
scanf("%d %d", &x, &y);
v[x] += 1, v[y] += 1;
v[lca(x, y)] -= 2;
}
dfs(1, -1);
cout << ans << endl;
return 0;
}