深搜
本题要点:
1、首先 n <= 2 * 10^5, 所以时间复杂度就是 O(n), 最多是 O(n * logn)。此时,用 链式前向星来存图。
2、假设当前节点是 x,父节点是 fa, x 有若干个 孩子节点 y1, y2, … , yk。
那么 ,以x节点为中间节点,并且节点之间的距离的节点对有 (fa, y1), (fa, y2), …, (fa, yk),
(y1, fa), (y1, y2), …, (y1, yk) ,…
观察到,这些节点之间,都互相组合了一遍,相当于
(fa + y1 + y2 + ... + yk) * (fa + y1 + y2 + ... + yk) - fa * fa - y1 * y1 - y2 * y2 - ... - yk * yk;
(假设这个叫做 大S)
3、 深搜:
sum, 记录这些点的 w 值的和,
squre, 记录这些点 的w值的平方和。
int dfs(int x), 函数返回的就是,以x节点为根几点, 全部距离是2的 点的权值 w 乘积的和。
dfs(x) = dfs(y1) + dfs(y2) + ... + dfs(yk) + 大S
4、至于,求哪两个 w 乘积最大,就在每个 dfs 函数里面,连个变量 max_w = -1, second_w = -2,
记录最大,次大的w值即可。
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int MaxN = 2e5 + 10, mod = 10007, MaxM = 4e5 + 10;
int n, tot, max_ww;
int head[MaxN], ver[MaxM], Next[MaxM], deg[MaxN];
int w[MaxN];
bool vis[MaxN];
void add(int x, int y)
{
ver[++tot] = y, Next[tot] = head[x], head[x] = tot;
}
int dfs(int x, int fa)
{
vis[x] = 1;
int sum = 0, squre = 0, ans = 0;
int max_w = -1, second_w = -2;
if(fa != -1)
{
sum = w[fa], squre = (-w[fa] * w[fa]) % mod;
second_w = max_w, max_w = w[fa];
}
for(int i = head[x]; i; i = Next[i])
{
int y = ver[i];
if(vis[y]) continue;
if(w[y] > max_w)
{
second_w = max_w, max_w = w[y];
}else if(w[y] > second_w){
second_w = w[y];
}
sum = (sum + w[y]) % mod;
squre = (squre - w[y] * w[y]) % mod;
int tmp = dfs(y, x);
ans = (ans + tmp) % mod;
}
if(max_w != -1 && second_w != -2)
max_ww = max(max_ww, max_w * second_w);
if(deg[x] > 1)
{
sum = (sum * sum) % mod;
ans = ((ans + sum + squre) % mod + mod) % mod;
}
return ans;
}
int main()
{
int x, y;
scanf("%d", &n);
for(int i = 1; i < n; ++i)
{
scanf("%d%d", &x, &y);
add(x, y);
add(y, x);
deg[x]++, deg[y]++;
}
for(int i = 1; i <= n; ++i)
{
scanf("%d", &w[i]);
}
max_ww = -3;
printf("%d %d\n", max_ww, dfs(1, -1));
return 0;
}
/*
5
1 2
2 3
3 4
4 5
1 5 2 3 10
*/
/*
20 74
*/