D. Maximum Distributed Tree
You are given a tree consisting of $n$ nodes. You should label each of its $n-1$ edges with an integer in such a way that the following conditions are satisfied:
each integer must be greater than 0;
the product of all n−1 numbers should be equal to k;
the number of 1-s among all n−1 integers must be minimum possible.
Let's define $f(u,v)$ as the sum of the numbers on the simple path from node $u$ to node $v$. Also, let $\sum_{i=1}^{n-1}\sum_{j=i+1}^{n} f(i,j)$ be the distribution index of the tree.
Find the maximum possible distribution index you can get. Since the answer can be too large, print it modulo $10^9+7$.
In this problem, since the number k can be large, the result of the prime factorization of k is given instead.
Input
The first line contains one integer t (1≤t≤100) — the number of test cases.
The first line of each test case contains a single integer $n$ ($2 \le n \le 10^5$) — the number of nodes in the tree.
Each of the next n−1 lines describes an edge: the i-th line contains two integers ui and vi (1≤ui,vi≤n; ui≠vi) — indices of vertices connected by the i-th edge.
Next line contains a single integer $m$ ($1 \le m \le 6\cdot 10^4$) — the number of prime factors of $k$.
Next line contains $m$ prime numbers $p_1, p_2, \ldots, p_m$ ($2 \le p_i < 6\cdot 10^4$) such that $k = p_1 \cdot p_2 \cdot \ldots \cdot p_m$.
It is guaranteed that the sum of $n$ over all test cases doesn't exceed $10^5$, the sum of $m$ over all test cases doesn't exceed $6\cdot 10^4$, and the given edges for each test case form a tree.
Output
Print the maximum distribution index you can get. Since the answer can be too large, print it modulo $10^9+7$.
Example
input
3
4
1 2
2 3
3 4
2
2 2
4
3 4
1 3
3 2
2
3 2
7
6 1
2 3
4 6
7 3
5 1
3 6
4
7 5 13 3
output
17
18
286
解题思路:
思路其实很简单,要求所有路径上的累加和,肯定把最大值放在权重最大的边上。
权重等于一条边 两侧点的数量的乘积。
然后排序乘就是了。
先排序再取模!(从比赛1小时wa到第二天)
AC代码:
#include <cstdio>
#include <vector>
#include <queue>
#include <cstring>
#include <cmath>
#include <map>
#include <set>
#include <string>
#include <iostream>
#include <algorithm>
#include <iomanip>
using namespace std;
// Competitive-programming template: I/O shorthands, loop helpers, type aliases
// and numeric constants shared by the rest of the file.

// scanf/printf shorthands for int-sized ("%d") values.
// NOTE(review): `int` is redefined to `long long` further down, so passing a
// variable declared as `int` after that point to these "%d" macros would
// mismatch the format; the "%lld" variants below match such variables.
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d%d",&n,&m)
#define sddd(n,m,k) scanf("%d%d%d",&n,&m,&k)
#define pd(n) printf("%d\n", n)
#define pc(n) printf("%c", n)
#define pdd(n,m) printf("%d %d", n, m)
// printf/scanf shorthands for long long ("%lld") values.
#define pld(n) printf("%lld\n", n)
#define pldd(n,m) printf("%lld %lld\n", n, m)
#define sld(n) scanf("%lld",&n)
#define sldd(n,m) scanf("%lld%lld",&n,&m)
#define slddd(n,m,k) scanf("%lld%lld%lld",&n,&m,&k)
// scanf shorthands for double ("%lf"), single characters and C strings.
#define sf(n) scanf("%lf",&n)
#define sc(n) scanf("%c",&n)
#define sff(n,m) scanf("%lf%lf",&n,&m)
#define sfff(n,m,k) scanf("%lf%lf%lf",&n,&m,&k)
#define ss(str) scanf("%s",str)
// Inclusive counting loops: rep goes a..n upward, per goes n..a downward.
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,a,n) for(int i=n;i>=a;i--)
// memset over a whole array (sizeof(a) must be the array size, not a pointer size).
#define mem(a,n) memset(a, n, sizeof(a))
// Prints "<expression>: <value>" to cout.
#define debug(x) cout << #x << ": " << x << endl
// Container / pair shorthands.
#define pb push_back
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
// Remainder of x by MOD (MOD is declared below; result keeps the sign of x).
#define mod(x) ((x)%MOD)
// Function-like macro: every later `gcd(x,y)` token sequence, including a
// function header spelled `gcd(...)`, is rewritten to `__gcd(x,y)`.
#define gcd(a,b) __gcd(a,b)
// x & -x; the argument is not parenthesised inside the expansion.
#define lowbit(x) (x&-x)
// Despite the name, this expands to a map type, not a pair.
#define pii map<int,int>
#define mk make_pair
// Child indices 2*rt and 2*rt+1 of a variable named `rt`; the expansions are
// not parenthesised.
#define rtl rt<<1
#define rtr rt<<1|1
// From here on every `int` token in the file is replaced by `long long`;
// `signed` is used where a genuine int is needed (e.g. the type of main).
#define int long long
typedef pair<int,int> PII;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
// Modulus used for the printed answer (1e9 + 7).
const int MOD = 1e9 + 7;
const double eps = 1e-9;
// Large sentinel values built from the repeating byte 0x3f.
const ll INF = 0x3f3f3f3f3f3f3f3fll;
const int inf = 0x3f3f3f3f;
// Reads the next decimal integer from stdin using getchar().
//
// Every non-digit character before the number is skipped; if any of the
// skipped characters is '-', the result is negated (same rule as before).
// The character that terminates the digit run is consumed.
//
// Returns the parsed value, or 0 when end of input is reached before any
// digit is seen. Previously the getchar() result was stored in a `char` and
// never compared with EOF, so the skip loop never terminated at end of input.
inline int read()
{
    int value = 0, sign = 1;
    // Keep the full getchar() result so EOF stays distinguishable from data.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-')
            sign = -1;
        ch = getchar();
    }
    // EOF is outside '0'..'9', so this loop also stops at end of input.
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Writes the decimal representation of `a` to stdout with putchar(),
// without a trailing separator or newline.
//
// Negative values are printed with a leading '-'. The original emitted
// characters below '0' for them, because `a % 10` is negative when `a` is.
inline void Out(int a)
{
    if (a < 0)
    {
        putchar('-');
        // Peel the last digit off before negating, so that the most negative
        // representable value is never negated as a whole.
        if (a <= -10)
            Out(-(a / 10));
        putchar('0' - a % 10); // a % 10 lies in [-9, 0] here
        return;
    }
    if (a > 9)
        Out(a / 10);
    putchar(a % 10 + '0');
}
// Binary exponentiation: returns m^k reduced by `mod`.
//
// m   - base; reduced by `mod` up front so the first multiplication and
//       squaring cannot overflow for large bases.
// k   - exponent; values <= 0 yield 1 % mod (a negative exponent previously
//       never terminated because the loop only tested k != 0).
// mod - modulus, expected to be positive; with mod == 1 the result is 0
//       (previously 1 was returned when k == 0).
int qpow(int m, int k, int mod)
{
    int res = 1 % mod;
    int t = m % mod;
    while (k > 0)
    {
        if (k & 1)
            res = res * t % mod;
        t = t * t % mod;
        k >>= 1;
    }
    return res;
}
// Euclid's algorithm: returns the greatest common divisor of a and b
// (a itself when b is 0).
// NOTE(review): the file-level function-like macro `gcd(a,b)` also expands
// this header, so the function actually defined here is named `__gcd`.
ll gcd(ll a,ll b)
{
    while (b != 0)
    {
        const ll rest = a % b;
        a = b;
        b = rest;
    }
    return a;
}
// Least common multiple of a and b, computed through gcd(a, b).
//
// Divides before multiplying so the intermediate value never exceeds the
// final result (the original formed a*b first, which can overflow even when
// the lcm itself fits). Returns 0 when both arguments are 0 instead of
// dividing by zero.
ll lcm(ll a,ll b)
{
    const ll g = gcd(a,b);
    if (g == 0)
        return 0;
    return a / g * b;
}
// Returns qpow(x, m - 2, m) reduced by m.
// NOTE(review): this equals the modular inverse of x only when m is prime
// and x is not a multiple of m (Fermat's little theorem); callers must
// guarantee that.
ll inv(ll x,ll m)
{
    const ll exponent = m - 2;
    const ll power = qpow(x, exponent, m);
    return power % m;
}
// Capacity of the global arrays. The statement bounds the number of nodes by
// 1e5 and the number of prime factors by 6e4, and every index used in this
// program is below max(n, m), so 1e5 entries plus a little slack suffice.
// (The previous 5e6+10 reserved several hundred megabytes of static storage.)
const int N = 100000 + 10;
// n: number of nodes, m: number of prime factors; q is not used by main.
int n,m,q;
// b: prime factors, later the labels assigned to edges; a is not used by main.
int a[N],b[N];
// cnt[x]: size of the subtree of x as computed by dfs1.
int cnt[N];
// val[x]: cnt[x] * (n - cnt[x]), the number of paths crossing the edge
// between x and its DFS parent (0 for the DFS root).
int val[N];
// Answer of the current test case.
int ans = 0;
// Adjacency lists: edge[x].to holds the neighbours of node x (0-based).
struct node{
vector<int> to;
}edge[N];
int dfs1(int x,int par)
{
int nn = edge[x].to.size();
int tmp = 1;
for(int i = 0 ; i < nn ; i ++)
{
int xx = edge[x].to[i];
if(xx != par)
tmp += dfs1(xx,x);
}
cnt[x] = tmp;
val[x] = cnt[x]*(n-cnt[x]);
return tmp;
}
signed main()
{
signed t;
cin>>t;
while(t--)
{
cin>>n;
for(int i = 0 ; i < n ; i ++)
{
edge[i].to.clear();
cnt[i] = 0;
b[i] = 0;
}
for(int i = 0 ; i < n-1 ; i ++)
{
int x,y;
cin>>x>>y;
x -= 1,y -= 1;
edge[x].to.push_back(y);
edge[y].to.push_back(x);
}
cin>>m;
for(int i = 0 ; i < m ; i ++)
cin>>b[i];
for(int i = m ; i < n-1 ; i ++) // m < n-1 后面的补1
b[i] = 1;
sort(b,b+max(m,n-1));
for(int i = n-1 ; i < m ; i ++) // m > n-1 后面的累乘到 b[n-2]
b[n-2] = (b[n-2]%MOD*b[i]%MOD)%MOD;
for(int i = 0 ; i < n ; i ++)
{
if(edge[i].to.size() == 1)
{
dfs1(i,-1);
break;
}
}
ans = 0;
sort(val,val+n,greater<int>());
for(int i = 0 ; i < n-1 ; i ++)
{
//cout<<val[i]<<" "<<b[i]<<endl;
ans = (val[i]%MOD*b[n-2-i]%MOD+ans)%MOD;
}
cout<<ans<<endl;
}
return 0;
}
贴一个错误示范
#include <cstdio>
#include <vector>
#include <queue>
#include <cstring>
#include <cmath>
#include <map>
#include <set>
#include <string>
#include <iostream>
#include <algorithm>
#include <iomanip>
using namespace std;
// Competitive-programming template: I/O shorthands, loop helpers, type aliases
// and numeric constants shared by the rest of this program.

// scanf/printf shorthands for int-sized ("%d") values.
// NOTE(review): `int` is redefined to `long long` further down, so passing a
// variable declared as `int` after that point to these "%d" macros would
// mismatch the format; the "%lld" variants below match such variables.
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d%d",&n,&m)
#define sddd(n,m,k) scanf("%d%d%d",&n,&m,&k)
#define pd(n) printf("%d\n", n)
#define pc(n) printf("%c", n)
#define pdd(n,m) printf("%d %d", n, m)
// printf/scanf shorthands for long long ("%lld") values.
#define pld(n) printf("%lld\n", n)
#define pldd(n,m) printf("%lld %lld\n", n, m)
#define sld(n) scanf("%lld",&n)
#define sldd(n,m) scanf("%lld%lld",&n,&m)
#define slddd(n,m,k) scanf("%lld%lld%lld",&n,&m,&k)
// scanf shorthands for double ("%lf"), single characters and C strings.
#define sf(n) scanf("%lf",&n)
#define sc(n) scanf("%c",&n)
#define sff(n,m) scanf("%lf%lf",&n,&m)
#define sfff(n,m,k) scanf("%lf%lf%lf",&n,&m,&k)
#define ss(str) scanf("%s",str)
// Inclusive counting loops: rep goes a..n upward, per goes n..a downward.
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,a,n) for(int i=n;i>=a;i--)
// memset over a whole array (sizeof(a) must be the array size, not a pointer size).
#define mem(a,n) memset(a, n, sizeof(a))
// Prints "<expression>: <value>" to cout.
#define debug(x) cout << #x << ": " << x << endl
// Container / pair shorthands.
#define pb push_back
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
// Remainder of x by MOD (MOD is declared below; result keeps the sign of x).
#define mod(x) ((x)%MOD)
// Function-like macro: every later `gcd(x,y)` token sequence, including a
// function header spelled `gcd(...)`, is rewritten to `__gcd(x,y)`.
#define gcd(a,b) __gcd(a,b)
// x & -x; the argument is not parenthesised inside the expansion.
#define lowbit(x) (x&-x)
// Despite the name, this expands to a map type, not a pair.
#define pii map<int,int>
#define mk make_pair
// Child indices 2*rt and 2*rt+1 of a variable named `rt`; the expansions are
// not parenthesised.
#define rtl rt<<1
#define rtr rt<<1|1
// From here on every `int` token in the file is replaced by `long long`;
// `signed` is used where a genuine int is needed (e.g. the type of main).
#define int long long
typedef pair<int,int> PII;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
// Modulus used for the printed answer (1e9 + 7).
const int MOD = 1e9 + 7;
const double eps = 1e-9;
// Large sentinel values built from the repeating byte 0x3f.
const ll INF = 0x3f3f3f3f3f3f3f3fll;
const int inf = 0x3f3f3f3f;
// Reads the next decimal integer from stdin using getchar().
//
// Every non-digit character before the number is skipped; if any of the
// skipped characters is '-', the result is negated (same rule as before).
// The character that terminates the digit run is consumed.
//
// Returns the parsed value, or 0 when end of input is reached before any
// digit is seen. Previously the getchar() result was stored in a `char` and
// never compared with EOF, so the skip loop never terminated at end of input.
inline int read()
{
    int value = 0, sign = 1;
    // Keep the full getchar() result so EOF stays distinguishable from data.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-')
            sign = -1;
        ch = getchar();
    }
    // EOF is outside '0'..'9', so this loop also stops at end of input.
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Writes the decimal representation of `a` to stdout with putchar(),
// without a trailing separator or newline.
//
// Negative values are printed with a leading '-'. The original emitted
// characters below '0' for them, because `a % 10` is negative when `a` is.
inline void Out(int a)
{
    if (a < 0)
    {
        putchar('-');
        // Peel the last digit off before negating, so that the most negative
        // representable value is never negated as a whole.
        if (a <= -10)
            Out(-(a / 10));
        putchar('0' - a % 10); // a % 10 lies in [-9, 0] here
        return;
    }
    if (a > 9)
        Out(a / 10);
    putchar(a % 10 + '0');
}
// Binary exponentiation: returns m^k reduced by `mod`.
//
// m   - base; reduced by `mod` up front so the first multiplication and
//       squaring cannot overflow for large bases.
// k   - exponent; values <= 0 yield 1 % mod (a negative exponent previously
//       never terminated because the loop only tested k != 0).
// mod - modulus, expected to be positive; with mod == 1 the result is 0
//       (previously 1 was returned when k == 0).
int qpow(int m, int k, int mod)
{
    int res = 1 % mod;
    int t = m % mod;
    while (k > 0)
    {
        if (k & 1)
            res = res * t % mod;
        t = t * t % mod;
        k >>= 1;
    }
    return res;
}
// Euclid's algorithm: returns the greatest common divisor of a and b
// (a itself when b is 0).
// NOTE(review): the file-level function-like macro `gcd(a,b)` also expands
// this header, so the function actually defined here is named `__gcd`.
ll gcd(ll a,ll b)
{
    while (b != 0)
    {
        const ll rest = a % b;
        a = b;
        b = rest;
    }
    return a;
}
// Least common multiple of a and b, computed through gcd(a, b).
//
// Divides before multiplying so the intermediate value never exceeds the
// final result (the original formed a*b first, which can overflow even when
// the lcm itself fits). Returns 0 when both arguments are 0 instead of
// dividing by zero.
ll lcm(ll a,ll b)
{
    const ll g = gcd(a,b);
    if (g == 0)
        return 0;
    return a / g * b;
}
// Returns qpow(x, m - 2, m) reduced by m.
// NOTE(review): this equals the modular inverse of x only when m is prime
// and x is not a multiple of m (Fermat's little theorem); callers must
// guarantee that.
ll inv(ll x,ll m)
{
    const ll exponent = m - 2;
    const ll power = qpow(x, exponent, m);
    return power % m;
}
// Capacity of the global arrays below (500010 entries).
const int N = 500000 + 10;

// n: number of nodes, m: number of prime factors; q is not used by main.
int n;
int m;
int q;

// b: prime factors, later the labels assigned to edges; a is not used by main.
int a[N];
int b[N];

// cnt[x]: size of the subtree of x as computed by dfs1.
int cnt[N];

// val[x]: weight stored by dfs1 for the edge between x and its DFS parent.
int val[N];

// Answer of the current test case.
int ans = 0;

// Adjacency lists: edge[x].to holds the neighbours of node x (0-based).
struct node
{
    vector<int> to;
};
node edge[N];
int dfs1(int x,int par)
{
int nn = edge[x].to.size();
int tmp = 1;
for(int i = 0 ; i < nn ; i ++)
{
int xx = edge[x].to[i];
if(xx != par)
tmp += dfs1(xx,x);
}
cnt[x] = tmp;
val[x] = (cnt[x]%MOD*(n-cnt[x])%MOD)%MOD;
return tmp;
}
signed main()
{
signed t;
cin>>t;
while(t--)
{
cin>>n;
for(int i = 0 ; i < n ; i ++)
{
edge[i].to.clear();
cnt[i] = 0;
b[i] = 0;
}
for(int i = 0 ; i < n-1 ; i ++)
{
int x,y;
cin>>x>>y;
x -= 1,y -= 1;
edge[x].to.push_back(y);
edge[y].to.push_back(x);
}
cin>>m;
for(int i = 0 ; i < m ; i ++)
cin>>b[i];
sort(b,b+m);
for(int i = m ; i < n-1 ; i ++)
b[i] = 1;
for(int i = n-1 ; i < m ; i ++)
b[n-2] = (b[n-2]%MOD*b[i]%MOD)%MOD;
for(int i = 0 ; i < n ; i ++)
{
if(edge[i].to.size() == 1)
{
dfs1(i,-1);
break;
}
}
ans = 0;
sort(b,b+n-1,greater<int>());
sort(val,val+n,greater<int>());
for(int i = 0 ; i < n-1 ; i ++)
{
//cout<<val[i]<<" "<<b[i]<<endl;
ans = (val[i]%MOD*b[i]%MOD+ans)%MOD;
}
cout<<ans<<endl;
}
return 0;
}