传送门
题目大意
给定一棵
n
n
n个节点,
n
−
1
n-1
n−1 条边的树。你可以在每一条树上的边标上边权,使得:
每个边权都为 正整数;
这
n
−
1
n-1
n−1个边权的乘积等于
k
k
k;
边权为
1
1
1 的边的数量最少。
定义
f
(
u
,
v
)
f(u,v)
f(u,v) 表示节点
u
u
u到节点
v
v
v的简单路径经过的边权总和。你的任务是让
∑
i
=
1
n
−
1
∑
j
=
i
+
1
n
f
(
i
,
j
)
\sum_{i=1}^{n-1}\sum_{j=i+1}^{n}{f(i,j)}
∑i=1n−1∑j=i+1nf(i,j)最大。
最终答案可能很大,对
1
e
9
+
7
1e9+7
1e9+7取模即可。
k
k
k有可能很大,输入数据中包含了
m
m
m个质数
p
i
p_i
pi ,那么
k
k
k为这些质数的乘积。
输入格式
第一行,一个整数
t
(
1
≤
t
≤
100
)
t(1\leq t\leq 100)
t(1≤t≤100),表示多组测试数据个数。对于每一个测试数据:
第一行,一个整数
n
(
2
≤
n
≤
1
0
6
)
n(2\leq n\leq 10^6)
n(2≤n≤106),表示树上节点数;
第
2
2
2至
n
n
n行,每行两个整数
u
i
u_i
ui和
v
i
(
1
≤
u
i
,
v
i
≤
n
,
u
i
≠
v
i
)
(
1
≤
u
i
,
v
i
≤
n
,
u
i
≠
v
i
)
v_i(1\leq u_i,v_i \leq n,u_i\neq v_i)(1≤u_i,v_i ≤n,u_i\neq v_i)
vi(1≤ui,vi≤n,ui=vi)(1≤ui,vi≤n,ui=vi),描述了一条无向边;
第
n
+
1
n+1
n+1行,一个整数
m
(
1
≤
m
≤
6
×
1
0
4
)
m(1\leq m\leq 6\times 10^4)
m(1≤m≤6×104),表示
k
k
k分解成质因子的个数;
第
n
+
2
n+2
n+2行,
m
m
m个质数
p
i
(
2
≤
p
i
<
6
×
1
0
4
,
有
k
=
∏
i
=
1
m
p
i
p_i (2\leq p_i< 6\times 10^4,有 k=\prod\limits_{i=1}^m p_i
pi(2≤pi<6×104,有k=i=1∏mpi
数据保证所有的
n
n
n总和不超过
1
0
5
10^5
105,所有的
m
m
m总和不超过
6
×
1
0
4
6\times 10^4
6×104。
数据给出的边保证能够形成一棵树。
思路
考虑每条边对答案有多少次贡献。一条路径被经过,仅当起点是它左边的一个点,终点是它右边的一个点。贡献=左边点个数
×
\times
×右边点个数
×
\times
×边权。然后把最大的边权分给出现次数最多的就可以了。
当
m
<
n
−
1
m<n-1
m<n−1时,要填上
1
1
1直到
m
=
n
−
1
m=n-1
m=n−1。
当
m
>
n
−
1
m>n-1
m>n−1时,要把多余的最大的合并成一个新的数,值为数的积,直到
m
=
n
−
1
m=n-1
m=n−1。
代码
ll t,n,m,ans;
ll p[maxn],s[maxn];
vector<ll> e[maxn];
priority_queue<ll> q;
void init(){
memset(s,0,sizeof s);
memset(p,0,sizeof p);
ans=0;
for(int i=1;i<=n;i++){
e[i].clear();
}
while(!q.empty()) q.pop();
}
void dfs(ll x,ll f){
s[x]=1;
for(int i=0;i<e[x].size();i++){
ll y=e[x][i];
if(y==f) continue;
dfs(y,x);
s[x]+=s[y];
}
}
void dfs2(ll x,ll f){
for(int i=0;i<e[x].size();i++){
ll y=e[x][i];
if(y==f) continue;
q.push(s[y]*(n-s[y]));
dfs2(y,x);
}
}
int cmp(ll a,ll b){
return a>b;
}
int main(){
int t;
cin>>t;
while(t--){
init();
scanf("%lld",&n);ll x,y;
for(int i=1;i<=n-1;i++){
scanf("%lld%lld",&x,&y);
e[x].push_back(y);
e[y].push_back(x);
}
scanf("%lld",&m);
for(int i=1;i<=m;i++){
scanf("%lld",&p[i]);//质数
}
dfs(1,0);
dfs2(1,0);
sort(p+1,p+1+m,cmp);
if(m<=(n-1)){
ll top=1;
while(!q.empty()){
if(top<=m) ans+=q.top()*p[top];
else ans+=q.top();
ans%=mod;
q.pop();
top++;
}
}
else{
ll top=m-n+2;
ll sum=1;
for(int i=1;i<=(m-(n-1))+1;i++)
{
(sum*=p[i])%=mod;
}
p[top]=sum;
while(q.size())
{
(ans+=q.top()*p[top])%=mod;
q.pop();
top++;
}
}
printf("%lld\n",ans);
}
}