Maximum Distributed Tree
题意
为一棵树的边添加权值 要求如下
- 权值大于 0 0 0
- 所有边权值之积等于 k k k
- 边权值中 1 1 1 的个数尽可能少
k k k 以质因数分解的形式给出
求 ∑ i = 1 n − 1 ∑ j = i + 1 n f ( i , j ) \sum\limits_{i=1}^{n-1} \sum\limits_{j=i+1}^n f(i,j) i=1∑n−1j=i+1∑nf(i,j)的值最大是多少 其中 f ( u , v ) f(u,v) f(u,v) 表示从 u u u 到 v v v 的简单路径上边的权值之和
思路
其实这题我们仔细想一下应该可以知道 使经过次数越多的边赋以越大的值可以使得最后结果越大 那么怎么求一条边经过了多少次呢
设 u u u 为 v v v 的父亲节点 s i z [ v ] siz[v] siz[v] 表示 以 v v v 为根的子树中节点的数量 那么 s i z [ v ] ∗ ( n − s i z [ v ] ) siz[v] * (n - siz[v]) siz[v]∗(n−siz[v]) 即为经过 u − v u-v u−v 这条边的次数
最后讨论一下边的条数和 k k k 的质因数数量的关系即可
若质因数数量小于边的数量 则多出来的边权值取1
否则 从大到小将质因数分配给每一个边
代码
#include<bits/stdc++.h>
#define INF 0x3f3f3f3f
#define mod 1000000007
#define endl '\n'
using namespace std;
typedef long long LL;
typedef pair<int, int>PII;
inline LL gcd(LL a, LL b) { return b ? gcd(b, a % b) : a; }
inline LL lowbit(LL x) { return x & -x; }
const int N = 100010;
LL n, m;
vector<LL>v[N];
LL prime[N];
LL siz[N], tot[N];
int cnt = 0;
bool cmp(LL a, LL b) {
return a > b;
}
void dfs(LL u, LL fa) {
siz[u] = 1;
for (LL i = 0; i < v[u].size(); ++i) {
LL j = v[u][i];
if (j != fa) {
dfs(j, u);
siz[u] += siz[j];
tot[++cnt] = siz[j] * (n - siz[j]);
}
}
}
void init() {
for (int i = 0; i <= n + 1; ++i) {
v[i].clear();
}
for (int i = 0; i <= m + 1; ++i) {
prime[i] = 0;
}
cnt = 0;
}
void solve() {
cin >> n;
for (LL i = 1; i <= n - 1; ++i) {
LL a, b; scanf("%lld%lld", &a, &b);
v[a].push_back(b);
v[b].push_back(a);
}
dfs(1, 0);
cin >> m;
for (LL i = 1; i <= m; ++i)
scanf("%lld", &prime[i]);
sort(prime + 1, prime + m + 1, cmp);
sort(tot + 1, tot + cnt + 1, cmp);
LL res = 0;
if (m < cnt) { //质因数少 降序排列 将大的质因数优先给次数多的边 剩余的边的用1
sort(prime + 1, prime + m + 1, cmp);
sort(tot + 1, tot + cnt + 1, cmp);
for (int i = 1; i <= m; ++i) {
res = (res + tot[i] * prime[i] % mod) % mod;
}
for(int i = m + 1;i <= cnt;++i){
res = (res + tot[i]) % mod;
}
}
else { //质因数多 将多出的质因数合并成一个给出现次数最多的边
sort(prime + 1, prime + m + 1);
sort(tot + 1, tot + 1 + cnt);
LL t = 1;
for (int i = cnt; i <= m; ++i) {
t = t * prime[i] % mod;
}
prime[cnt] = t;
for (int i = 1; i <= cnt; ++i) {
res = (res + prime[i] * tot[i] % mod) % mod;
}
}
cout << res % mod << endl;
init();
}
int main() {
int t; cin >> t;
while(t--)
solve();
return 0;
}