题目
传说中的暗之连锁被人们称为 Dark。
Dark 是人类内心的黑暗的产物,古今中外的勇者们都试图打倒它。
经过研究,你发现 Dark 呈现无向图的结构,图中有 N N N 个节点和两类边,一类边被称为主要边,而另一类被称为附加边。
Dark 有 N – 1 N – 1 N–1 条主要边,并且 Dark 的任意两个节点之间都存在一条只由主要边构成的路径。
另外,Dark 还有 M 条附加边。
你的任务是把 Dark 斩为不连通的两部分。
一开始 Dark 的附加边都处于无敌状态,你只能选择一条主要边切断。
一旦你切断了一条主要边,Dark 就会进入防御模式,主要边会变为无敌的而附加边可以被切断。
但是你的能力只能再切断 Dark 的一条附加边。
现在你想要知道,一共有多少种方案可以击败 Dark。
注意,就算你第一步切断主要边之后就已经把 Dark 斩为两截,你也需要切断一条附加边才算击败了 Dark。
输入格式
第一行包含两个整数
N
N
N 和
M
M
M。
之后 N – 1 N – 1 N–1 行,每行包括两个整数 A 和 B,表示 A 和 B 之间有一条主要边。
之后 M M M 行以同样的格式给出附加边。
输出格式
输出一个整数表示答案。
数据范围
N
≤
100000
,
M
≤
200000
N≤100000,M≤200000
N≤100000,M≤200000,数据保证答案不超过
2
31
−
1
2^{31}−1
231−1
输入样例:
4 1
1 2
2 3
1 4
3 4
输出样例:
3
题解
- 我们可以将每条附加边 ( x , y ) (x,y) (x,y)定义为将树上 ( x , y ) (x,y) (x,y)之间的路径覆盖了一次。
- 我们只要统计每条边被覆盖了几次即可计算答案。
- 覆盖 0 0 0次答案加上 m m m
- 覆盖 1 1 1次答案加上 1 1 1
- 覆盖两次或两次以上不用计算答案
- 计算覆盖次数只需将 c [ x ] + + , c [ y ] + + , c [ l c a ( x , y ) ] − 2 c[x]++, c[y]++, c[lca(x,y)]-2 c[x]++,c[y]++,c[lca(x,y)]−2,最后 d f s dfs dfs求和即可
code
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 100000 + 100;
const int maxm = 200000 + 100;
template <class T>
inline void read(T &s) {
s = 0; T w = 1, ch = getchar();
while (!isdigit(ch)) { if (ch == '-') w = -1; ch = getchar(); }
while (isdigit(ch)) { s = (s << 1) + (s << 3) + (ch ^ 48); ch = getchar(); }
s *= w;
}
int n, m, tot;
int lin[maxn], dep[maxn], f[maxn];
int fa[maxn][21];
struct node {
int next, to;
} e[maxn << 1];
inline void add(int from, int to) {
e[++tot].to = to;
e[tot].next = lin[from];
lin[from] = tot;
}
void dfs(int u, int fat) {
dep[u] = dep[fat] + 1;
fa[u][0] = fat;
for (int i = 1; i <= 20; ++i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = lin[u]; i; i = e[i].next) {
int v = e[i].to;
if (v != fat) dfs(v, u);
}
}
int lca(int x, int y) {
if (dep[x] < dep[y]) x ^= y ^= x ^= y;
for (int i = 20; i >= 0; --i)
if (dep[fa[x][i]] >= dep[y]) x = fa[x][i];
if (x == y) return x;
for (int i = 20; i >= 0; --i)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
void get_f(int u, int fat) {
for (int i = lin[u]; i; i = e[i].next) {
int v = e[i].to;
if (v == fat) continue;
get_f(v, u);
f[u] += f[v];
}
}
int main() {
read(n), read(m);
for (int i = 1; i < n; ++i) {
int x, y; read(x), read(y);
add(x, y); add(y, x);
}
dfs(1, 0);
for (int i = 1; i <= m; ++i) {
int x, y; read(x), read(y);
f[x]++; f[y]++;
f[lca(x, y)] -= 2;
}
get_f(1, 0);
LL ans = 0ll;
for (int i = 2; i <= n; ++i) {
ans += (f[i] == 0 ? m : f[i] == 1 ? 1 : 0);
}
printf("%lld\n", ans);
return 0;
}