题目
题意:
给你一棵数,现在你要将这棵树上附上权值并且满足一下条件:
- 每一个权值都是正整数。
- n − 1 n-1 n−1个数的乘积要等于 k k k
- 第三个条件看不懂 Q A Q QAQ QAQ,但是应该不太重要…
现在设置 f ( u , v ) f(u,v) f(u,v)为节点 u → v u\to v u→v的权值,需要你求出最大的 ∑ i = 1 n − 1 ∑ j = i + 1 n f ( i , j ) \sum^{n-1}_{i=1}\sum^n_{j=i+1}f(i,j) i=1∑n−1j=i+1∑nf(i,j)
思路:
我们先通过
d
f
s
dfs
dfs树求出每一条边会经过多少次,设置成
a
[
c
n
t
]
=
s
z
v
∗
(
n
−
s
z
v
)
a[cnt]=sz_v*(n-sz_v)
a[cnt]=szv∗(n−szv),
s
z
u
sz_u
szu代表
d
f
s
(
x
)
dfs(x)
dfs(x)中以
x
x
x建树,然后此时的
u
u
u为根节点,然后往下有
s
z
u
sz_u
szu个节点。
然后我们给每一条边赋值上权值,权值的话,就是后面的
m
m
m个数,如果
m
m
m个数超过
n
−
1
n-1
n−1条边的话,那么最大的那个权值要乘上后面的值,因为最后
∏
i
=
0
n
−
2
f
(
u
,
v
)
=
k
\prod^{n-2}_{i=0}f(u,v)=k
∏i=0n−2f(u,v)=k,如果小于的话,那么要将后面的
f
(
u
,
v
)
=
1
f(u,v)=1
f(u,v)=1,这样最后的乘积也是
k
k
k。
我这里分析两个坑点,没太注意到细节,然后
w
a
wa
wa了好多次。
- 我们在求 s z sz sz数组的时候,不能当场取模,不然之后的排序就会出现错误。
- 当 m m m超过 n − 1 n-1 n−1的时候,我们必须用最大的那个质数去乘,因为 s z ∗ f ( u , v ) sz*f(u,v) sz∗f(u,v),当 s z sz sz最大的情况下,那么 f ( u , v ) f(u,v) f(u,v)也要最大。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <vector>
#include <string>
#include <cmath>
#include <set>
#include <map>
#include <deque>
#include <stack>
#include <cctype>
using namespace std;
typedef long long ll;
typedef vector<int> veci;
typedef vector<ll> vecl;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
template <class T>
inline void read(T &ret) {
char c;
int sgn;
if (c = getchar(), c == EOF) return ;
while (c != '-' && (c < '0' || c > '9')) c = getchar();
sgn = (c == '-') ? -1:1;
ret = (c == '-') ? 0:(c - '0');
while (c = getchar(), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
ret *= sgn;
return ;
}
inline void outi(int x) {if (x > 9) outi(x / 10);putchar(x % 10 + '0');}
inline void outl(ll x) {if (x > 9) outl(x / 10);putchar(x % 10 + '0');}
const int maxn = 100010;
const int mod = 1e9 + 7;
ll sz[maxn], a[maxn], p[maxn];
vecl edge[maxn];
int cntt = 0, n, m;
void dfs(int u, int pre) {
sz[u] = 1;
for (int i = 0; i < edge[u].size(); i++) {
int v = edge[u][i];
if (v != pre) {
dfs(v, u);
sz[u] += sz[v];
a[cntt++] = (n - sz[v]) * sz[v];
}
}
}
int main() {
int t; read(t); while (t--) {
read(n);
for (int i = 0; i <= n; i++) edge[i].clear();
for (int i = 0; i < n - 1; i++) {
int u, v;
read(u), read(v);
edge[u].push_back(v);
edge[v].push_back(u);
}
cntt = 0;
dfs(1, -1);
sort(a, a + cntt);
read(m);
for (int i = 0; i < m; i++) read(p[i]);
sort(p, p + m);
if (cntt > m) {
for (int i = m; i < cntt; i++) p[i] = 1;
} else {
for (int i = cntt; i < m; i++) p[cntt - 1] = p[cntt - 1] * p[i] % mod;
}
ll ans = 0;
sort(p, p + cntt);
for (int i = cntt - 1; i >= 0; i--) {
ans = (ans + a[i] % mod * p[i]) % mod;
}
printf("%lld\n", ans);
}
return 0;
}