有两颗无根树Tree1与Tree2,这两棵树各有N个节点(2<=N<=4000)。每棵树中,N个节点都被标为0,1,2,..,N-1,这N个序号,不保证两颗树是同构的。
定义一个关于这两棵树的函数S(e1,e2),其中e1为Tree1的一条边,e2为Tree2的一条边,定义如下:
(1)如果将边e1从Tree1中移除(即Tree1在e1处断开),那么Tree1将分裂为两棵小树A1与B1;
(2)如果将边e2从Tree2中移除(即Tree2在e2处断开),那么Tree2将分裂为两棵小树A2与B2;
(3)令Set(Tree)表示Tree中所有节点的序号构成的集合,且令Fun(S1,S2)为集合S1与S2交集的大小(交集元素个数,空集时为0),即|S1 ∩ S2|,例如Fun({1,2,3,4},{7,3,2,8,6}) = 2;
(4)S(e1,e2) = max{ Fun(Set(A1),Set(A2)) , Fun(Set(A1),Set(B2)) , Fun(Set(B1),Set(A2)) ,
Fun(Set(B1),Set(B2)) }.其中 max{...} 为集合{...}中元素的最大值。
简单说,S(e1,e2)为Fun(X,Y)能取到的最大值,其中X=A1或B1,Y=A2或B2.
在这个问题中,你需要求对于所以边对(e1,e2),共计(N-1)*(N-1)对情况下,S(e1,e2)平方的和。即求如下公式的值。
Input
第一行一个整数N,其中2<=N<=4000 之后N-1行,包含Tree1的边的信息 每行两个数xi,yi,表示Tree1中第i条边(xi,yi),其中0<=xi,yi<N 之后N-1行,包含Tree2的边的信息 每行两个数aj,bj,表示Tree2中第j条边(aj,bj),其中0<=aj,bj<N
Output
一个整数,即问题中公式的结果。
Input示例
5 0 2 1 4 2 1 3 0 0 1 1 4 2 4 3 4
Output示例
111
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <tuple>
using namespace std;
typedef long long ll;
typedef tuple<int, int> tii;
const int MAXN = 4e3 + 10;
int N;
int n;
ll res = 0;
vector<int> tree[MAXN];
vector<int> tree_[MAXN];
tii tmp;
struct edge
{
int u;
int v;
} tr[MAXN], tr_[MAXN];
int vis[MAXN];
tii dfs_(int last, int root)
{
int cnt = 0, m = 0;
if (vis[root])
{
cnt++;
}
m++;
int a, b;
for (int i = 0; i < tree_[root].size(); i++)
{
if (tree_[root][i] != last)
{
tmp = dfs_(root, tree_[root][i]);
tie(a, b) = tmp;
cnt += a;
m += b;
}
}
if (m == N)
{
return make_tuple(cnt, m);
}
ll temp1 = max((ll)cnt * cnt, (ll)(n - cnt) * (n - cnt));
ll temp2 = max((ll)(m - cnt) * (m - cnt), (ll)(N - m - n + cnt) * (N - m - n + cnt));
res += max(temp1, temp2);
return make_tuple(cnt, m);
}
void dfs(int last, int root)
{
vis[root] = 1;
n++;
for (int i = 0; i < tree[root].size(); i++)
{
if (tree[root][i] != last)
{
dfs(root, tree[root][i]);
}
}
}
int main(int argc, const char * argv[])
{
cin >> N;
int u, v;
for (int i = 1; i < N; i++)
{
cin >> u >> v;
u++, v++;
tree[u].push_back(v);
tree[v].push_back(u);
tr[i].u = u;
tr[i].v = v;
}
for (int i = 1; i < N; i++)
{
cin >> u >> v;
u++, v++;
tree_[u].push_back(v);
tree_[v].push_back(u);
tr_[i].u = u;
tr_[i].v = v;
}
for (int i = 1; i < N; i++)
{
n = 0;
memset(vis, 0, sizeof(vis));
dfs(tr[i].v, tr[i].u);
dfs_(-1, 1);
}
cout << res << '\n';
return 0;
}