题目描述
思路:
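对每条附加边 (x, y) 做树上差分:d[x]++、d[y]++、d[lca(x, y)] -= 2,再自底向上求子树和,得到每条树边被多少条附加边形成的环覆盖。若某条树边被覆盖 0 次,切断它之后再任意切一条附加边都可以,贡献为 m;若恰好被覆盖 1 次,只能再切覆盖它的那条附加边,贡献为 1;被覆盖 2 次及以上则贡献为 0。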
code
#include<iostream>
#include<cstdio>
#include<vector>
#define ll long long
using namespace std;
// ---- Global state shared by dfs / lca / dp / main ----
const ll MAXN = 200010;   // array capacity for 1-based node ids (2e5 + 10)

ll n;                     // number of tree nodes (n - 1 edges are read)
ll m;                     // number of additional node pairs read after the tree
ll ans;                   // answer accumulated by dp()

ll dep[MAXN];             // dep[v]: depth of v as set by dfs()
ll f[MAXN][21];           // f[v][i]: 2^i-th ancestor of v (0 = none)
ll d[MAXN];               // tree-difference counters, summed over subtrees by dp()

vector<ll> b[MAXN];       // b[v]: adjacency list of v in the tree
// Fills depth and binary-lifting ancestors for every node in the subtree
// rooted at x, whose parent is fa (0 is used as the "no parent" sentinel):
//   dep[v]  = dep[parent(v)] + 1
//   f[v][0] = parent(v)
//   f[v][i] = f[f[v][i-1]][i-1]   (the 2^i-th ancestor; 0 once past the
//                                  root, relying on f[0][*] staying 0)
// Uses an explicit stack instead of recursion: on a path-shaped tree with
// ~2e5 nodes a recursive walk nests ~2e5 calls deep and can overflow the
// default call stack. Every node's values depend only on its ancestors, and
// a parent is always popped before its children, so the result is the same
// as a recursive preorder walk.
void dfs(ll x, ll fa) {
    std::vector<ll> stk(1, x);
    f[x][0] = fa;
    while(!stk.empty()) {
        ll u = stk.back();
        stk.pop_back();
        ll p = f[u][0];            // parent recorded when u was pushed
        dep[u] = dep[p] + 1;
        // All ancestors of u are already finalized at this point.
        for(ll i = 1; i <= 20; i ++) f[u][i] = f[f[u][i - 1]][i - 1];
        for(size_t i = 0; i < b[u].size(); i ++) {
            ll v = b[u][i];
            if(v == p) continue;   // do not walk back up to the parent
            f[v][0] = u;
            stk.push_back(v);
        }
    }
}
// Returns the lowest common ancestor of x and y, using dep[] and the
// binary-lifting table f[][] built by dfs().
// Step 1 lifts the deeper node to the shallower node's depth.
// Step 2 lifts both nodes in lockstep while their ancestors differ.
// Assumes depth differences fit in 21 bits (levels 0..20), which holds for
// the MAXN-sized arrays used here.
ll lca(ll x, ll y) {
    if(dep[x] > dep[y]) swap(x, y);   // make y the deeper node
    ll k = dep[y] - dep[x];
    // Jump y up by exactly k levels by testing each bit of k. Testing bits
    // directly (instead of maintaining t = 1 << j and decrementing j) never
    // produces a negative shift count, which would be undefined behaviour.
    for(ll i = 20; i >= 0; i --)
        if((k >> i) & 1) y = f[y][i];
    if(x == y) return x;              // x was an ancestor of y
    for(ll i = 20; i >= 0; i --)
        if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}
// Folds the difference counters d[] over the subtree rooted at x, whose
// parent is fa (fa itself is not visited). Afterwards d[v] holds the sum of
// the original d over v's subtree. Given main's updates
// d[x]++, d[y]++, d[lca(x, y)] -= 2, that sum is the number of read pairs
// whose tree path uses the edge (v, parent(v)).
// For each tree edge strictly below x, it adds to ans:
//   m  if the edge is used by 0 pairs,
//   1  if it is used by exactly 1 pair,
//   0  otherwise.
// Iterative instead of recursive, so deep (path-like) trees cannot overflow
// the call stack.
void dp(ll x, ll fa) {
    // Collect the subtree in DFS preorder, with each node's parent.
    std::vector<ll> order, parent;
    std::vector<ll> stkNode(1, x), stkPar(1, fa);
    while(!stkNode.empty()) {
        ll u = stkNode.back(), p = stkPar.back();
        stkNode.pop_back();
        stkPar.pop_back();
        order.push_back(u);
        parent.push_back(p);
        for(size_t i = 0; i < b[u].size(); i ++) {
            ll v = b[u][i];
            if(v == p) continue;
            stkNode.push_back(v);
            stkPar.push_back(u);
        }
    }
    // In reverse preorder, every node comes after all of its descendants,
    // so d[y] is final when y is reached. Index 0 is x itself: the edge
    // (x, fa) is not counted, matching a walk that only looks at x's children.
    for(ll k = (ll)order.size() - 1; k >= 1; k --) {
        ll y = order[k];
        if(d[y] == 1) ans ++;
        if(d[y] == 0) ans += m;
        d[parent[k]] += d[y];
    }
}
int main() {
scanf("%lld%lld", &n, &m);
for(ll i = 1; i < n; i ++) {
ll x, y;
scanf("%lld%lld", &x, &y);
b[x].push_back(y);
b[y].push_back(x);
}
dfs(1, 0);
for(ll i = 1; i <= m; i ++) {
ll x, y;
scanf("%lld%lld", &x, &y);
d[x] ++, d[y] ++, d[lca(x, y)] -= 2;
}
dp(1, 0);
printf("%lld", ans);
return 0;
}