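题目(TopCoder「TreesAnalysis」):http://community.topcoder.com/stat?c=problem_statement&pm=13086&rd=15854
看到树(tree)就应当想到 DFS,这里用两个 DFS 来标记节点:对 tree1 的每条边 (i, tree1[i]),先用第一个 DFS(dfs)在删去该边后的 tree1 中标记与 i 同侧的节点集合 A;再用第二个 DFS(dfs2)以 0 为根遍历 tree2,统计每棵子树中属于 A 和不属于 A 的节点数。这样删去 tree2 的任意一条边时,四个交集的大小都能 O(1) 求出,取其中最大值的平方累加即可,总复杂度 O(n^2)。
代码如下: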
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <iostream>
#include <sstream>
#include <iomanip>
#include <bitset>
#include <string>
#include <vector>
#include <stack>
#include <deque>
#include <queue>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <cctype>
#include <cmath>
#include <cstring>
#include <ctime>
#include <climits>
using namespace std;
// ---- Contest shorthands -------------------------------------------------
// Integer type aliases.
typedef long long llong;
typedef pair<int, int> pii;
typedef pair<llong, llong> pll;

// Prints the processor time consumed so far, in seconds with two decimals.
#define CHECKTIME() printf("%.2lf\n", static_cast<double>(clock()) / CLOCKS_PER_SEC)
// Short alias for make_pair.
#define mkp make_pair
// Iterates `it` from X.begin() to X.end(). Note: __typeof is a GCC/Clang
// compiler extension, not standard C++.
#define FOREACH(it, X) for (__typeof((X).begin()) it = (X).begin(); it != (X).end(); ++it)
/*************** Program Begin **********************/
// Sums, over every pair (edge e1 of tree 1, edge e2 of tree 2), the square of
// the largest of the four sets obtained by intersecting the two sides of
// tree 1 minus e1 with the two sides of tree 2 minus e2.
//
// Trees are given as parent-style arrays: for i in [0, n-1), tree1[i] / tree2[i]
// is the node joined to node i by an edge. Assumes both arrays have the same
// length n-1 and each describes a tree on nodes 0..n-1 (not validated).
// Runs in O(n^2); dfs/dfs2 recurse to depth at most n.
class TreesAnalysis {
public:
    // Scratch state shared by treeSimilarity() and the two DFS helpers. All of
    // it is re-sized to the node count at the start of every treeSimilarity()
    // call, so an instance can be reused for several inputs and there is no
    // fixed upper bound on the number of nodes.
    vector< vector<int> > T1, T2; // adjacency lists of tree 1 / tree 2
    vector<int> cntA, cntB;       // per-node subtree counts in tree 2: nodes on side A / side B
    vector<char> visited;         // DFS visit marks
    vector<char> isA;             // non-zero iff the node is on side A of the cut tree-1 edge

    // Marks (in isA) every node of tree 1 reachable from v without crossing
    // the edge (v1, v2). Called as dfs(v1, v1, v2), it marks exactly the
    // component containing v1 once that edge is removed.
    void dfs(int v, int v1, int v2)
    {
        visited[v] = true;
        isA[v] = true;
        for (size_t i = 0; i < T1[v].size(); ++i) {
            int a = T1[v][i];
            if (visited[a]) {
                continue;
            }
            // Skip the removed edge in either direction.
            if ((v == v1 && a == v2) || (v == v2 && a == v1)) {
                continue;
            }
            dfs(a, v1, v2);
        }
    }

    // Post-order DFS over tree 2 from v: afterwards cntA[u] / cntB[u] hold
    // the number of side-A / side-B nodes in u's subtree, for every u reached.
    void dfs2(int v)
    {
        visited[v] = true;
        if (isA[v]) {
            ++cntA[v];
        } else {
            ++cntB[v];
        }
        for (size_t i = 0; i < T2[v].size(); ++i) {
            int a = T2[v][i];
            if (visited[a]) {
                continue; // the parent of v
            }
            dfs2(a);
            cntA[v] += cntA[a];
            cntB[v] += cntB[a];
        }
    }

    // Returns the sum described in the class comment; 0 when the trees have a
    // single node (empty arrays).
    long long treeSimilarity(vector <int> tree1, vector <int> tree2) {
        long long res = 0;
        const int n = static_cast<int>(tree1.size()) + 1;

        // Rebuild both adjacency lists from scratch; assign() discards edges
        // left over from any previous call on this object.
        T1.assign(n, vector<int>());
        T2.assign(n, vector<int>());
        for (int i = 0; i < n - 1; i++) {
            T1[i].push_back(tree1[i]);
            T1[tree1[i]].push_back(i);
            T2[i].push_back(tree2[i]);
            T2[tree2[i]].push_back(i);
        }

        for (int i = 0; i < n - 1; i++) {
            // Cut tree-1 edge (i, tree1[i]); side A is the part containing i.
            const int v1 = i, v2 = tree1[i];
            visited.assign(n, 0);
            isA.assign(n, 0);
            dfs(v1, v1, v2);
            const int sumA = accumulate(isA.begin(), isA.end(), 0);

            // Count side-A / side-B nodes in every subtree of tree 2 rooted at 0.
            visited.assign(n, 0);
            cntA.assign(n, 0);
            cntB.assign(n, 0);
            dfs2(0);

            for (int j = 0; j < n - 1; j++) {
                // Removing tree-2 edge (j, tree2[j]) detaches the subtree of the
                // child endpoint, i.e. the one with the smaller subtree.
                const int v3 = j, v4 = tree2[j];
                const int child =
                    (cntA[v3] + cntB[v3] > cntA[v4] + cntB[v4]) ? v4 : v3;
                const long long a = cntA[child];   // side A inside the detached subtree
                const long long b = cntB[child];   // side B inside the detached subtree
                const long long c = sumA - a;      // side A in the remainder
                const long long d = n - a - b - c; // side B in the remainder
                const long long x = max(max(a, b), max(c, d));
                res += x * x;
            }
        }
        return res;
    }
};
/************** Program End ************************/