题意:给你一颗有n个节点的树,每个节点都有一个值,问你是否能删去两条边使删后的三块值的和都相等。
分析:首先,要使三块和相等,总和sum%3一定是0的。这样我们先用一个dfs求出每个节点的子树的值的总和,再遍历一遍找值是sum/3的节点,因为删去的两条边的节点可能有两种情况,一种是一个是另一个的祖先,另一种是一个和另一个没有关系,所以我的做法是找到sum/3的节点,先放到一个vector里面,再判断其祖先节点是否有值为sum/3*2的节点如果有直接输出,如果没果没有,搜过的每个点都标记一下,防止重复搜而T掉,这样最后,如果还没输出,就说明不存在第一种情况,那么,我们遍历存下的vector,把每个点的祖先节点都找出来,存到set里,然后标记一下,最后,如果vector的size 减去 set的size是大于等于2的,那我们任意输出没有标记的两个就可以了,其他则是-1.
下面是巨丑的代码:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<queue>
#include<map>
#include<set>
#include<stack>
#include<cstring>
#include<string>
#include<vector>
//#include<unordered_set>
//#include<unordered_map>
#include<cmath>
using namespace std;
#define ll long long
typedef pair<int, int>pii;
typedef pair<ll, ll>pll;
const int MAXN = 1000005;
const int MAXM = 1000005;
const ll LINF = 0x3f3f3f3f3f3f3f3f;
const int INF = 0x3f3f3f3f;
const int MOD = 1000000007;
const double FINF = 1e30;
struct Edge
{
int v, w;
Edge(int _v = 0, int _w = 0) :v(_v), w(_w) {}
};
int cnt[MAXN], fa[MAXN];
bool vis2[MAXN], vis[MAXN], vis3[MAXN];
vector<Edge>vec[MAXN];
int dfs(int x)
{
int sum = 0;
for (int i = 0; i < vec[x].size(); ++i)
{
cnt[vec[x][i].v] += dfs(vec[x][i].v);
sum += cnt[vec[x][i].v];
}
return sum;
}
int ff = 0, pos = -1;
void find(int x, int f)
{
if (fa[x] == x)return;
if (ff == 1)return;
if (cnt[x] == f)
{
pos = x;
ff = 1;
return;
}
else vis2[x] = 1;
if (vis2[fa[x]] == 0)find(fa[x], f);
}
vector<int>uu;
void f(int x)
{
vis[x] = 1;
vis3[x] = 1;
if (fa[x] == x)return;
if(vis[fa[x]] == 0)f(fa[x]);
}
int main()
{
int n, a, b, root, sum = 0;
memset(cnt, 0, sizeof(cnt));
memset(vis, 0, sizeof(vis));
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
{
scanf("%d%d", &a, &b);
cnt[i] += b;
if (a != 0)
{
vec[a].push_back(Edge(i, b));
fa[i] = a;
}
sum += b;
if (a == 0)root = i, fa[i] = i;
}
cnt[root] += dfs(root);
if (sum % 3 != 0)
{
printf("-1\n");
}
else
{
uu.clear();
for (int i = 1; i <= n; ++i)
{
if (cnt[i] == sum / 3)
{
uu.push_back(i);
ff = 0;
if (vis2[fa[i]] == 0)find(fa[i], sum / 3 * 2);
//cout << i << " " << ff << endl;
if (ff == 1)
{
printf("%d %d\n", i, pos);
return 0;
}
}
}
//cout << 1 << endl;
for (int i = 0; i < uu.size(); ++i)
{
if (vis[fa[uu[i]]] == 0)
{
f(fa[uu[i]]);
}
}
int ans[10],pppp = 0;
for (int i = 0; i < uu.size(); ++i)
{
if (vis3[uu[i]] == 0)ans[pppp++] = uu[i];
}
if (pppp >= 2)printf("%d %d\n", ans[0], ans[1]);
else printf("-1\n");
}
}
/*
5
0 5
1 3
1 1
2 6
2 3
*/