D. Paths on the Tree
time limit per test: 3 seconds
memory limit per test: 256 megabytes
input: standard input
output: standard output
You are given a rooted tree consisting of $n$ vertices. The vertices are numbered from $1$ to $n$, and the root is vertex $1$. You are also given a score array $s_1, s_2, \ldots, s_n$.
A multiset of $k$ simple paths is called valid if both of the following conditions hold:
- Each path starts at vertex $1$.
- Let $c_i$ be the number of paths covering vertex $i$. For each pair of vertices $(u, v)$ ($2 \le u, v \le n$) that have the same parent, $|c_u - c_v| \le 1$ holds.
The value of the path multiset is defined as $\sum_{i=1}^{n} c_i s_i$.
It can be shown that it is always possible to find at least one valid multiset. Find the maximum value among all valid multisets.
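A brute-force sketch that follows the definition literally can help check tiny cases. It relies on one observation: in a tree, a simple path that starts at vertex 1 is determined by its other endpoint, so a multiset of k paths is the same as a choice of how many paths end at each vertex. Assumptions of this sketch: arrays are 1-indexed with `par[1] = 0`, paths only end at vertices 2..n (one-vertex paths are not enumerated), and the names `coverage`, `isValid`, `distribute` and `bruteForce` are hypothetical. The running time is exponential, so it is meant only for inputs like the samples.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>
using namespace std;

// c[v] = number of paths covering v, given that e[v] paths end at v.
// The path 1 -> v covers v and every ancestor of v.
vector<long long> coverage(const vector<int>& par, const vector<long long>& e) {
    int n = (int)par.size() - 1;
    vector<long long> c(n + 1, 0);
    for (int v = 1; v <= n; v++)
        for (int u = v; u != 0; u = par[u]) c[u] += e[v];
    return c;
}

// Sibling rule: counts of children of the same parent differ by at most 1.
bool isValid(const vector<int>& par, const vector<long long>& c) {
    int n = (int)par.size() - 1;
    vector<long long> lo(n + 1, LLONG_MAX), hi(n + 1, LLONG_MIN);
    for (int v = 2; v <= n; v++) {
        lo[par[v]] = min(lo[par[v]], c[v]);
        hi[par[v]] = max(hi[par[v]], c[v]);
    }
    for (int u = 1; u <= n; u++)
        if (lo[u] != LLONG_MAX && hi[u] - lo[u] > 1) return false;
    return true;
}

// Try every way to let `left` paths end at vertices v..n; keep the best valid value.
void distribute(int v, long long left, vector<long long>& e, const vector<int>& par,
                const vector<long long>& s, long long& best) {
    int n = (int)par.size() - 1;
    if (v == n) {
        e[v] = left;  // the last vertex takes all remaining paths
        vector<long long> c = coverage(par, e);
        if (isValid(par, c)) {
            long long value = 0;
            for (int i = 1; i <= n; i++) value += c[i] * s[i];
            best = max(best, value);
        }
        return;
    }
    for (long long x = 0; x <= left; x++) {
        e[v] = x;
        distribute(v + 1, left - x, e, par, s, best);
    }
}

// par[v] = parent of v (par[0] = par[1] = 0), s[v] = score of v, k = number of paths.
long long bruteForce(const vector<int>& par, const vector<long long>& s, long long k) {
    int n = (int)par.size() - 1;
    vector<long long> e(n + 1, 0);
    long long best = -1;
    distribute(2, k, e, par, s, best);
    return best;
}
```

For the first sample, `bruteForce({0, 0, 1, 2, 1, 3}, {0, 6, 2, 1, 5, 7}, 4)` enumerates all 35 ways to end 4 paths at vertices 2..5.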
Input
Each test contains multiple test cases. The first line contains a single integer $t$ ($1 \le t \le 10^4$), the number of test cases. The description of the test cases follows.
The first line of each test case contains two space-separated integers $n$ ($2 \le n \le 2 \cdot 10^5$) and $k$ ($1 \le k \le 10^9$): the size of the tree and the required number of paths.
The second line contains $n-1$ space-separated integers $p_2, p_3, \ldots, p_n$ ($1 \le p_i \le n$), where $p_i$ is the parent of the $i$-th vertex. It is guaranteed that these values describe a valid tree rooted at vertex $1$.
The third line contains $n$ space-separated integers $s_1, s_2, \ldots, s_n$ ($0 \le s_i \le 10^4$): the scores of the vertices.
It is guaranteed that the sum of $n$ over all test cases does not exceed $2 \cdot 10^5$.
Output
For each test case, print a single integer — the maximum value of a path multiset.
Example
input
2
5 4
1 2 1 3
6 2 1 5 7
5 3
1 2 1 3
6 6 1 4 10
output
54
56
Note
In the first test case, one optimal solution is the four paths 1→2→3→5, 1→2→3→5, 1→4, 1→4; here c=[4,2,2,2,2]. The value equals 4⋅6+2⋅2+2⋅1+2⋅5+2⋅7=54.
In the second test case, one optimal solution is the three paths 1→2→3→5, 1→2→3→5, 1→4; here c=[3,2,2,1,2]. The value equals 3⋅6+2⋅6+2⋅1+1⋅4+2⋅10=56.
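The coverage counts and values in both notes can be re-checked mechanically. A minimal sketch, assuming each path is given by its endpoint (the path runs from 1 down to that endpoint) and arrays are 1-indexed with `par[1] = 0`; the names `Evaluation` and `evaluate` are hypothetical.

```cpp
#include <cassert>
#include <vector>
using namespace std;

struct Evaluation {
    vector<long long> c;  // c[v] = number of paths covering v (index 0 unused)
    long long value;      // sum of c[v] * s[v]
    bool valid;           // true if siblings' counts differ by at most 1
};

// par[v] = parent of v (par[1] = 0), s[v] = score of v, ends = endpoint of each path.
Evaluation evaluate(const vector<int>& par, const vector<long long>& s, const vector<int>& ends) {
    int n = (int)par.size() - 1;
    Evaluation r{vector<long long>(n + 1, 0), 0, true};
    for (int end : ends)
        for (int u = end; u != 0; u = par[u]) r.c[u]++;  // the path covers every ancestor of `end`
    for (int v = 1; v <= n; v++) r.value += r.c[v] * s[v];
    for (int u = 2; u <= n; u++)
        for (int v = u + 1; v <= n; v++)
            if (par[u] == par[v] && (r.c[u] - r.c[v] > 1 || r.c[v] - r.c[u] > 1))
                r.valid = false;
    return r;
}
```

For example, `evaluate({0, 0, 1, 2, 1, 3}, {0, 6, 2, 1, 5, 7}, {5, 5, 4, 4})` reproduces c=[4,2,2,2,2] and the value 54 from the first note. The pairwise sibling check is quadratic, which is fine for inputs of this size.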
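#include<bits/stdc++.h>
using namespace std;
#define int long long   // every "int" below is 64-bit: the answer can reach about 2*10^18
const int N=2e5+10;
int n,k;
vector<int>G[N];        // G[v]: children of vertex v
map<int,int>mp[N];      // mp[v][j]: memoized result of dfs(v,j)
int p[N];               // p[v]: score s_v of vertex v (not the parent)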
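// dfs(i,j): maximum value collected in the subtree of i when j paths pass through i.
// Scores are nonnegative, so all j paths continue into the children, and the sibling
// rule forces an even split: each of the siz children gets f=j/siz paths and
// e=j%siz of them get one extra path.
// Each vertex is only queried with at most two consecutive values of j, so every map
// stays tiny. The recursion depth equals the tree height (up to n).
int dfs(int i,int j)
{
    if(mp[i].count(j))
    {
        return mp[i][j];           // already computed
    }
    mp[i][j]=j*p[i];               // vertex i itself is covered j times
    int siz=G[i].size();
    if(siz==0)return mp[i][j];     // leaf: nothing more to add
    int f=j/siz;
    int e=j%siz;
    if(e==0)
    {
        // even split: every child receives exactly f paths
        for(auto x:G[i])
        {
            int a=dfs(x,f);
            mp[i][j]+=a;
        }
    }
    else
    {
        // every child receives f paths, and e children receive f+1;
        // give the extra paths to the children with the largest gain dfs(x,f+1)-dfs(x,f)
        vector<int>v;
        for(auto x:G[i])
        {
            int a=dfs(x,f);
            int b=dfs(x,f+1);
            mp[i][j]+=a;
            v.push_back(b-a);
        }
        sort(v.begin(),v.end(),greater<int>());
        for(int l=0;l<e;l++)
        {
            mp[i][j]+=v[l];
        }
    }
    return mp[i][j];
}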
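void solve()
{
    cin>>n>>k;
    for(int i=1;i<=n;i++)
    {
        G[i].clear();              // reset data left over from the previous test case
        mp[i].clear();
    }
    for(int i=2;i<=n;i++)
    {
        int x;
        cin>>x;                    // x is the parent of vertex i
        G[x].push_back(i);
    }
    for(int i=1;i<=n;i++)
    {
        cin>>p[i];                 // score s_i
    }
    cout<<dfs(1,k)<<"\n";          // all k paths pass through the root
}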
signed main()                      // "signed" because int is redefined as long long
{
    ios::sync_with_stdio(false);   // fast I/O
    cin.tie(0);
    cout.tie(0);
    int T;
    cin>>T;
    while(T--)
    {
        solve();
    }
    return 0;
}
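The recursive dfs above can recurse about n levels deep on a path-shaped tree (n up to 2·10^5), which may exceed the default stack size on some judges. The sketch below is our own iterative restructuring of the same DP, not the original author's code. It rests on one fact: if a vertex u with m children is covered a or a+1 times, then writing a = qm + r, every child is covered q or q+1 times (a+1 splits as q and q+1 unless r+1 = m, in which case every child gets q+1). So two DP slots per vertex suffice, filled bottom-up in reverse BFS order. The names `maxValue`, `lo` and `best` are hypothetical, and arrays are 1-indexed.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <functional>
#include <vector>
using namespace std;

// Maximum value for one test case.
// par[v] = parent of v for v = 2..n (par[0], par[1] ignored), s[v] = score, k = number of paths.
long long maxValue(int n, long long k, const vector<int>& par, const vector<long long>& s) {
    vector<vector<int>> ch(n + 1);
    for (int v = 2; v <= n; v++) ch[par[v]].push_back(v);

    // BFS order from the root: every parent appears before its children.
    vector<int> order{1};
    for (size_t i = 0; i < order.size(); i++)
        for (int c : ch[order[i]]) order.push_back(c);

    // Top-down: vertex v is covered either lo[v] or lo[v] + 1 times.
    vector<long long> lo(n + 1, 0);
    lo[1] = k;
    for (int u : order)
        for (int c : ch[u]) lo[c] = lo[u] / (long long)ch[u].size();

    // Bottom-up: best[v][t] = best subtree value when v is covered lo[v] + t times.
    vector<array<long long, 2>> best(n + 1);
    for (int i = (int)order.size() - 1; i >= 0; i--) {
        int u = order[i];
        long long m = (long long)ch[u].size();
        for (int t = 0; t < 2; t++) {
            long long j = lo[u] + t;
            long long val = j * s[u];
            if (m > 0) {
                long long q = j / m, e = j % m;  // each child gets q paths, e children get q + 1
                vector<long long> gain;
                for (int c : ch[u]) {
                    val += best[c][q - lo[c]];   // q is lo[c] or lo[c] + 1
                    if (e > 0) gain.push_back(best[c][1] - best[c][0]);  // here q == lo[c]
                }
                sort(gain.begin(), gain.end(), greater<long long>());
                for (long long l = 0; l < e; l++) val += gain[l];
            }
            best[u][t] = val;
        }
    }
    return best[1][0];
}
```

With the input format above, one would read `par` and `s` for each test case and print `maxValue(n, k, par, s)`. Compared with the map-based dfs, this replaces memoization with two fixed slots per vertex and needs no recursion; the slot for k+1 paths at the root is computed but unused.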