题目描述
无向连通图 G 有 n 个点,n-1条边。点从 1 到 n依次编号,编号为 ii 的点的权值为
W
i
W_i
Wi
,每条边的长度均为 1。图上两点 (u, v)(u,v) 的距离定义为 uu 点到 vv 点的最短距离。对于图 G 上的点对 (u,v),若它们的距离为 2,则它们之间会产生
W
v
×
W
u
W_v \times W_u
Wv×Wu的联合权值。
请问图 G 上所有可产生联合权值的有序点对中,联合权值最大的是多少?所有联合权值之和是多少?
输入格式
第一行包含 1个整数 n。
接下来 n-1 行,每行包含 22 个用空格隔开的正整数 u,v表示编号为 u和编号为 v的点之间有边相连。
最后 1行,包含 n个正整数,每两个正整数之间用一个空格隔开,其中第 i 个整数表示图 G 上编号为 i 的点的权值为
W
i
W_i
Wi
输出格式
输出共 11 行,包含 22 个整数,之间用一个空格隔开,依次为图 GG 上联合权值的最大值和所有联合权值之和。由于所有联合权值之和可能很大,输出它时要对1000710007取余。
输入输出样例
输入
5
1 2
2 3
3 4
4 5
1 5 2 3 10
输出
20 74
说明/提示
本例输入的图如上所示,距离为2 的有序点对有( 1,3)、( 2,4) 、( 3,1)、( 3,5) 、( 4,2)、( 5,3)。
其联合权值分别为2 、15、2 、20、15、20。其中最大的是20,总和为74。
【数据说明】
对于30%的数据,
1
<
n
≤
1001
<
n
≤
100
1 < n \leq 1001<n≤100
1<n≤1001<n≤100;
对于60%的数据,
1
<
n
≤
2000
1 < n \leq 2000
1<n≤2000
对于100%的数据,
1
<
n
≤
200000
,
0
<
W
i
≤
10000
。
1 < n \leq 200000, 0 < W_i \leq 10000。
1<n≤200000,0<Wi≤10000。
保证一定存在可产生联合权值的有序点对。
求和的话对于一个点对(a, b)我们要求的是2ab
我们可推出公式2ab = (a,b)2 - a2 - b2
于是我们拓展出来多个点对的求和
我们枚举中间点,每两个子节点都是一个点对
在求和时顺便求出最大值,时间复杂度是O(n)的
#include<bits/stdc++.h>
#define ll long long
#define MAXN 500010
#define N 201
#define INF 0x3f3f3f3f
#define gtc() getchar()
using namespace std;
template <class T>
inline void read(T &s){
s = 0; T w = 1, ch = gtc();
while(!isdigit(ch)){if(ch == '-') w = -1; ch = gtc();}
while(isdigit(ch)){s = s * 10 + ch - '0'; ch = gtc();}
s *= w;
}
template <class T>
inline void write(T x){
if(x < 0) putchar('-'), x = -x;
if(x > 9) write(x/10);
putchar(x % 10 + '0');
}
const int mod = 10007;
struct node{
int y, ne;
}e[MAXN];
int lin[MAXN], len = 0;
inline void add(int x, int y){
e[++len].y = y, e[len].ne = lin[x], lin[x] = len;
}
int n;
int d[MAXN];
int sum[MAXN];
int all = 0, ans = 0;
void bfs(int x){
int fmx = 0, smx = 0;
int sum1 = 0, sum2 = 0;
for(int i = lin[x]; i; i = e[i].ne){
int y = e[i].y;
// printf("%d ", y);
sum1 = (sum1 + d[y]) % mod;
sum2 = (sum2 + d[y] * d[y]) % mod;
if(d[y] > fmx){
smx = fmx, fmx = d[y];
}
else if(d[y] > smx) smx = d[y];
}
// puts("");
ans = max(ans, fmx * smx);
sum1 = sum1 * sum1 % mod;
all += (sum1 - sum2 + mod) % mod;
}
int main()
{
read(n);
int x, y;
for(int i = 1; i < n; ++i){
read(x), read(y);
add(x, y), add(y, x);
}
for(int i = 1; i <= n; ++i) read(d[i]);
for(int i = 1; i <= n; ++i) bfs(i);
all %= mod;
cout << ans << ' ' << all << endl;
return 0;
}