Description
Yixght is a manager of the company called SzqNetwork(SN). Now she's very worried because she has just received a bad news which denotes that DxtNetwork(DN), the SN's business rival, intents to attack the network of SN. More unfortunately, the original network of SN is so weak that we can just treat it as a tree. Formally, there are N nodes in SN's network, N-1 bidirectional channels to connect the nodes, and there always exists a route from any node to another. In order to protect the network from the attack, Yixght builds M new bidirectional channels between some of the nodes.
As the DN's best hacker, you can exactly destory two channels, one in the original network and the other among the M new channels. Now your higher-up wants to know how many ways you can divide the network of SN into at least two parts.
Input
The first line of the input file contains two integers: N (1 ≤ N ≤ 100 000), M (1 ≤ M ≤ 100 000) — the number of the nodes and the number of the new channels.
Following N-1 lines represent the channels in the original network of SN, each pair (a,b) denote that there is a channel between node a and nodeb.
Following M lines represent the new channels in the network, each pair (a,b) denote that a new channel between node a and node b is added to the network of SN.
Output
Output a single integer — the number of ways to divide the network into at least two parts.
Sample Input
4 1 1 2 2 3 1 4 3 4
Sample Output
3
先给出一棵树,然后往里面加边,最后你可以去除一条原本树上的边,一条后来加的边,问能把图分成至少两块的方案有几种。
如果去除原来树上的边就可以分两块,那么就有m种方案,否则要两个合起来才可以,就只有1种方案。
然后考虑怎么删边,可以知道,如果在一棵树上加边,就会形成一个环,可以认为原本链上的边都多了一条,这样就可变成覆盖的问题了,我们先给树定一个根节点,假设为1,然后考虑每次加边会覆盖那些链,这里用lca最近公共祖先就可以了。给每个点一个cnt的值,统计边被覆盖的次数,然后一次dfs可以知道每条边被覆盖的情况,然后按照上面说的统计就行了。
#include<set>
#include<map>
#include<ctime>
#include<cmath>
#include<stack>
#include<queue>
#include<bitset>
#include<cstdio>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<functional>
#define rep(i,j,k) for (int i = j; i <= k; i++)
#define per(i,j,k) for (int i = j; i >= k; i--)
using namespace std;
typedef long long LL;
const int low(int x) { return x&-x; }
const double eps = 1e-8;
const int INF = 0x7FFFFFFF;
const int mod = 1e9 + 7;
const int N = 2e5 + 10;
int T, n, m;
int ft[N], nt[N], u[N], sz, x, y;
int lca[N][20], dp[N], cnt[N];
LL ans;
void dfs(int x, int fa)
{
dp[x] = dp[fa] + 1;
lca[x][0] = fa;
for (int i = 1; (1 << i) <=dp[x]; i++)
{
lca[x][i] = lca[lca[x][i - 1]][i - 1];
}
for (int i = ft[x]; i != -1; i = nt[i])
{
if (u[i] == fa) continue;
dfs(u[i], x);
}
}
int getlca(int a, int b)
{
int i, j;
if (dp[a] < dp[b]) swap(a, b);
for (i = 0; (1 << i) <= dp[a]; i++);
i--;
for (j = i; j >= 0; j--)
if (dp[a] - (1 << j) >= dp[b])
a = lca[a][j];
if (a == b)return a;
for (j = i; j >= 0; j--)
{
if (lca[a][j] && lca[a][j] != lca[b][j])
{
a = lca[a][j];
b = lca[b][j];
}
}
return lca[a][0];
}
void get(int x, int fa)
{
for (int i = ft[x]; i != -1; i = nt[i])
{
if (u[i] == fa) continue;
get(u[i], x);
cnt[x] += cnt[u[i]];
if (cnt[u[i]] == 0) ans += m;
if (cnt[u[i]] == 1) ans++;
}
}
int main()
{
while (scanf("%d%d", &n, &m) != EOF)
{
rep(i, 1, n) ft[i] = -1, cnt[i] = 0;
dp[0] = sz = 0;
rep(i, 1, n - 1)
{
scanf("%d%d", &x, &y);
u[sz] = y; nt[sz] = ft[x]; ft[x] = sz++;
u[sz] = x; nt[sz] = ft[y]; ft[y] = sz++;
}
dfs(1, 0);
rep(i, 1, m)
{
scanf("%d%d", &x, &y);
int k = getlca(x, y);
if (k == x || k == y) cnt[k]--, cnt[x^y^k]++;
else cnt[x]++, cnt[y]++, cnt[k] -= 2;
}
ans = 0;
get(1, 0);
printf("%lld\n", ans);
}
return 0;
}