// Code:
//============================================================================
// Name : test.cpp
// Author : Assassin_upc
// Version :
// Copyright : Your copyright notice
// Description : Tree DP - minimum total weight of edges to cut so that no two bad nodes remain connected
//============================================================================
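// Tree DP.
// Cutting edge (i, j) of weight v costs v; keeping it merges j's component into i's.
// dp[i][0]: minimum cost inside the subtree rooted at i such that the component
//           containing i contains no bad node.
// dp[i][1]: minimum cost inside the subtree rooted at i such that the component
//           containing i contains exactly one bad node.
// Transitions:
// If i is a leaf:
//   i is bad:   dp[i][0] = inf, dp[i][1] = 0;
//   otherwise:  dp[i][0] = 0,   dp[i][1] = inf;
// If i is not a leaf (j ranges over the children of i, v = weight of edge (i, j)):
//   i is bad:   dp[i][0] = inf, dp[i][1] = sum_j min(dp[j][0], dp[j][1] + v);
//   otherwise:  dp[i][0] = sum_j min(dp[j][0], dp[j][1] + v);
//               dp[i][1] = min_j (dp[i][0] - min(dp[j][0], dp[j][1] + v) + dp[j][1]);
// In the last formula, dp[i][0] - min(dp[j][0], dp[j][1] + v) removes child j's term
// from the sum that built dp[i][0]; adding dp[j][1] then keeps edge (i, j) so that i
// joins the single bad node in j's component.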
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>
using namespace std;
// Shared state for the tree DP. Node ids are expected to lie in [0, maxn).
const int maxn = 100000 + 10;
const int mod = 1000000007;            // NOTE(review): not referenced anywhere in this file
typedef long long int LLI;
typedef std::pair<int, int> PII;       // (neighbor id, edge weight)
// "Impossible" cost. Kept well below LLONG_MAX so that inf plus sums of int edge
// weights (as the DP computes them) still fits in a long long.
const LLI inf = ((1ll << 62) - 1);
std::vector<PII> tree[maxn];           // adjacency lists of the undirected weighted tree
// dp[i][0]: min cut cost in i's subtree with no bad node in i's component;
// dp[i][1]: same, with exactly one bad node in i's component.
LLI dp[maxn][2];
int col[maxn];                         // col[i] == 1 iff node i is bad, 0 otherwise
// DFS bookkeeping: false = discovered but not finalized (e.g. the parent of the
// node being processed); true = everything else. main() resets it to all-true.
bool vis[maxn];
// Fills dp[u][0] and dp[u][1] for every node u in the subtree rooted at x.
//
//   dp[u][0]: minimum total weight of cut edges inside u's subtree such that the
//             component containing u has no bad node (col == 1);
//   dp[u][1]: the same, but u's component has exactly one bad node.
// Keeping edge (u, v) merges v's component into u's; cutting it costs its weight.
//
// vis[] convention (shared with the caller): vis[v] == false marks a node that
// has been discovered but not finalized -- in particular the parent of the node
// being processed -- and vis[v] == true marks everything else. Nodes of x's
// subtree are expected to be true on entry; all of them (x included) are true
// again on return.
//
// Implemented iteratively: a recursive DFS nests one frame per tree level, i.e.
// up to ~1e5 frames on a path-shaped tree, which can overflow the call stack.
void dfs(int x){
    // Pass 1: collect the subtree in pre-order. A node is marked false as soon as
    // it is discovered so that its children later skip it as their parent.
    std::vector<int> order;
    std::vector<int> pending(1, x);
    vis[x] = false;
    while(!pending.empty()){
        const int u = pending.back();
        pending.pop_back();
        order.push_back(u);
        for(std::size_t i = 0; i < tree[u].size(); i ++){
            const int v = tree[u][i].first;
            if(vis[v] == false) continue;   // parent (or an already-discovered node)
            vis[v] = false;
            pending.push_back(v);
        }
    }

    // Pass 2: walk the pre-order backwards so every child is finalized before its
    // parent. A node is set back to true once finalized, so while u is processed
    // its children are true and its parent is still false.
    for(std::size_t k = order.size(); k > 0; k --){
        const int u = order[k - 1];
        if(col[u] == 0){
            // u is good: either no child contributes a bad node (dp[u][0]), or
            // exactly one child j joins through dp[j][1] (dp[u][1]).
            dp[u][0] = 0;
            dp[u][1] = inf;
            for(std::size_t i = 0; i < tree[u].size(); i ++){
                const int v = tree[u][i].first;
                if(vis[v] == false) continue;   // skip the parent
                dp[u][0] += std::min(dp[v][0], dp[v][1] + tree[u][i].second);
            }
            for(std::size_t i = 0; i < tree[u].size(); i ++){
                const int v = tree[u][i].first;
                if(vis[v] == false) continue;   // skip the parent
                // Swap child v's best term in dp[u][0] for "keep edge (u, v) and
                // inherit v's single bad node".
                dp[u][1] = std::min(dp[u][1],
                    dp[u][0] - std::min(dp[v][0], dp[v][1] + tree[u][i].second) + dp[v][1]);
            }
        }else{
            // u is bad: its component already holds one bad node, so every child
            // must either join without a bad node or be cut off.
            dp[u][0] = inf;
            dp[u][1] = 0;
            for(std::size_t i = 0; i < tree[u].size(); i ++){
                const int v = tree[u][i].first;
                if(vis[v] == false) continue;   // skip the parent
                dp[u][1] += std::min(dp[v][0], dp[v][1] + tree[u][i].second);
            }
        }
        vis[u] = true;
    }
}
// Entry point. Input: T test cases; each gives n and k, then n - 1 undirected
// weighted edges "x y w" over nodes 0..n-1, then the k bad node ids.
// For each case prints the minimum total weight of edges to cut so that every
// remaining component contains at most one bad node.
int main() {
    int t;
    if(scanf("%d",&t) != 1) return 0;
    while(t --){
        int n,k;
        memset(col,0,sizeof(col));
        memset(vis,true,sizeof(vis));   // dfs() expects every node to start as true
        if(scanf("%d%d",&n,&k) != 2) break;
        for(int i = 1;i < n;i ++){
            int x,y,w;
            if(scanf("%d%d%d",&x,&y,&w) != 3) return 0;
            tree[x].push_back(PII(y,w));
            tree[y].push_back(PII(x,w));
        }
        for(int i = 1;i <= k;i ++){
            int temp;
            if(scanf("%d",&temp) != 1) return 0;
            col[temp] = 1;
        }
        dfs(0);
        // The root's component may end with no bad node or with exactly one; both
        // are valid final states, so take the cheaper. (vis[0] is always true after
        // dfs(0), so it cannot be used to choose between the two.)
        LLI ans = std::min(dp[0][0], dp[0][1]);
        printf("%lld\n",ans);
        for(int i = 0; i < n;i ++){
            tree[i].clear();
        }
    }
    return 0;
}