Description
Input
Output
Sample Input
4 4
1 3
2 1
4 3
4 3
4 1
1 2
Sample Output
68
Data Constraint
做法: 用树形 dp(换根)分别求出两棵树中每个点 i 到同一棵树内其余所有点的距离和,并在遍历时记录每棵树中的最小距离和。设 n1、n2 为两棵树的点数,a、a1 分别为第一、第二棵树的最小距离和,sum、sum1 分别为两棵树内部所有点对的距离之和,则最后
答案等于 a * n2 + a1 * n1 + n1 * n2 + sum1 + sum
代码如下:
#include <cstdio>
#pragma GCC optimize(2)
#include <iostream>
#include <cstring>
#define rep(i, a, b) for (int i = a; i <= b; i++)
#define eg(i, x) for (int i = ls[x]; i; i = e[i].next)
#define ll long long
#define N 300007
using namespace std;
// One directed edge of a linked adjacency list; entries are created by add().
struct arr
{
// to   : vertex this edge points at.
// next : index in e[] of the following edge in the same vertex's list;
//        0 ends the list (the eg() macro stops iterating when the index is 0).
int to, next;
}e[N * 2]; // edge pool, indexed from 1; main() inserts each input edge in both directions
// f[], ff[] : per-vertex values for the first / second tree; main() seeds index 1
//             with dfs(1, 0) and dp() / dp2() fill in the remaining vertices.
// n1, n2    : the two counts read first from the input (vertices per tree).
// ls[x]     : index in e[] of the first edge of vertex x (0 = no edges).
// g[x]      : value computed by dfs2() (1 plus the g of every newly visited neighbour).
// sum, sum1 : half of the total of f[] / ff[] over all vertices.
// a, a1     : minimum of f[] / ff[] over all vertices.
// cnt       : number of entries of e[] currently in use.
ll f[N], ff[N], n1, n2, ls[N], g[N], sum, sum1, a, a1, cnt;
// b[v] is set once vertex v has been reached by the traversal in progress.
bool b[N];
// Reads the next non-negative decimal integer from stdin.
// Any characters that are not digits are skipped first. Returns the parsed
// value, or 0 if end-of-file is reached before a digit is found.
int read()
{
    // getchar() returns int so that EOF can be told apart from every byte value;
    // storing it in a char loses that distinction.
    int ch = getchar();

    // Skip separators, but stop at EOF instead of looping forever.
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();

    int s = 0;
    while (ch >= '0' && ch <= '9')
    {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s;
}
// Inserts the directed edge x -> y at the front of x's adjacency list.
void add(int x, int y)
{
    ++cnt;                 // claim the next free slot of the edge pool
    e[cnt].next = ls[x];   // the previous head becomes the successor
    e[cnt].to = y;
    ls[x] = cnt;           // the new edge is now the head of x's list
}
// Returns dep plus the results of recursing (with dep + 1) into every
// neighbour of x that has not been marked in b[] yet; i.e. the sum of the
// depths of all vertices reached from x, with x itself at depth dep.
// The caller must mark x in b[] before the call.
ll dfs(int x, int dep)
{
    ll total = dep;
    eg(id, x)
    {
        const int child = e[id].to;
        if (b[child]) continue;   // already reached (this includes the parent)
        b[child] = 1;
        total += dfs(child, dep + 1);
    }
    return total;
}
// Stores in g[x] the number of vertices reached from x through unmarked
// neighbours (x included) and returns that count.
// The caller must mark x in b[] before the call.
ll dfs2(int x)
{
    g[x] = 1;   // x counts itself
    eg(id, x)
    {
        const int child = e[id].to;
        if (b[child]) continue;   // already reached (this includes the parent)
        b[child] = 1;
        g[x] += dfs2(child);
    }
    return g[x];
}
void dp(int x)
{
eg(i, x)
if (!b[e[i].to])
{
b[e[i].to] = 1;
f[e[i].to] = f[x] + n1 - 2 * g[e[i].to];
dp(e[i].to);
}
}
void dp2(int x)
{
eg(i, x)
if (!b[e[i].to])
{
b[e[i].to] = 1;
ff[e[i].to] = ff[x] + n2 - 2 * g[e[i].to];
dp2(e[i].to);
}
}
int main()
{
freopen("unite.in", "r", stdin);
freopen("unite.out", "w", stdout);
n1 = read();
n2 = read();
int x, y;
rep(i, 1, n1 - 1)
{
x = read(), y = read();
add(x, y);
add(y, x);
}
b[1] = 1;
f[1] = dfs(1, 0);
memset(b, 0, sizeof(b));
b[1] = 1;
dfs2(1);
memset(b, 0, sizeof(b));
b[1] = 1;
dp(1);
a = 1e15, a1 = 1e15;
rep(i, 1, n1)
{
sum += f[i];
a = min(a, f[i]);
}
sum /= 2;
memset(ls, 0, sizeof(ls));
memset(e, 0, sizeof(e));
memset(b, 0, sizeof(b));
memset(g, 0, sizeof(g));
cnt = 0;
rep(i, 1, n2 - 1)
{
x = read(), y = read();
add(x, y);
add(y, x);
}
b[1] = 1;
ff[1] = dfs(1, 0);
memset(b, 0, sizeof(b));
b[1] = 1;
dfs2(1);
memset(b, 0, sizeof(b));
b[1] = 1;
dp2(1);
rep(i, 1, n2)
{
sum1 += ff[i];
a1 = min(a1, ff[i]);
}
sum1 /= 2;
cout << a * n2 + a1 * n1 + n1 * n2 + sum1 + sum;
}