题解 - 树上游走(二)(上海月赛2024.7甲组T1)

题目描述

给定一个棵由编号为 1~n 这 n 个节点构成的树,及 m 个关键节点,树的第 i 条双向边连接 ui 和 vi ,长度为 wi 。你可以从树上任何一个节点出发开始游走,到任何一个节点结束,但整个游走过程中必须经过给定的 m 个关键节点。
请问,应当如何设计路线,能够使得整个游走过程中,经过的路径长度最短。

输入格式

输入第一行,两个正整数 n,m
接下来 n−1 行,每行 3 个正整数,其中第 i+1 行分别表示第 i 条边的 ui,vi,wi。
输入最后行,m 个正整数,分别表示 m 个关键节点的编号 ci。

输出格式

输出共一行,表示所求最短长度。

数据范围

1≤m≤n≤1e5,1​ ≤ ui,vi,ci ≤n, 1≤w ≤1e4

样例数据

输入:

5 3
1 2 2
2 3 4
2 4 5
1 5 1
3 4 5

输出:

15

说明:

3–>2–>1–>5–>1–>2–>4
4 2 1 1 2 5

分析

树形dp + 换根

  1. 首先考虑对于起点固定,应该怎么做。

假设起点为st,假设m个关键节点分布在st的k个子树内,则我们将只经过其中1个子树一次,然后其余k-1个子树的每条边经过两次。

一个很显然的贪心是,为了使总路程最短,我们应该满足 只经过1次的子树的路程和 最大。

可以dfs来求

其中f[u]代表从u节点访问u子树的所有关键点然后再返回u节点需要的总路程(即所有路径均为两倍),
mx[i][0]和mx[i][1]分别表示i的子树的最大值和次大值。(至于为什么还需要维护次大值,后续会讲到)

void dfs(int u,int fa){
    if(flag[u]) cnt[u] = 1;

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];
        if(j != fa){
            dfs(j,u);
            
            if(cnt[j]){
                cnt[u] += cnt[j];
                f[u] += f[j] + 2 * val;

                int tmp = mx[j][0] + val;
                if(tmp > mx[u][0]){
                    mx[u][1] = mx[u][0],id[u][1] = id[u][0];
                    mx[u][0] = tmp,id[u][0] = j;
                }else if(tmp > mx[u][1]){
                    mx[u][1] = tmp,id[u][1] = j;
                }
            } 
        }
    }
}

最终结果即为 f[st] - mx[st][0]

  1. 然后考虑起点不固定

可以采用换根dp来写

对于 u → j u \rightarrow j uj 这条边,根从 u u u 转移到 j j j

若 j 子树内关键节点数为0,则 f[j] = f[u] + 2 * val
若 j 子树内关键节点数为m,则 f[j] 不变;
否则, f[j] = f[u]

若 j 子树内关键节点数为m,则还需要更新mx[j][0]和mx[j][1]的值。(具体细节见代码)

void dp(int u,int fa){
    res = min(res,f[u] - mx[u][0]);

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];

        if(j != fa){
            if(!cnt[j]) f[j] = f[u] + 2 * val;
            else if(cnt[j] == m) f[j] = f[j];
            else f[j] = f[u];

            if(cnt[j] != m){
                int tmp;
                if(id[u][0] == j) tmp = mx[u][1] + val;
                else tmp = mx[u][0] + val;

                if(tmp > mx[j][0]){
                    mx[j][1] = mx[j][0],id[j][1] = id[j][0];
                    mx[j][0] = tmp,id[j][0] = u;
                }else if(tmp > mx[j][1]){
                    mx[j][1] = tmp,id[j][1] = u;
                }
           }

            dp(j,u);
        }
    }
}

完整代码

#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

const int N = 1e5 + 10;

int n,m;
bool flag[N];
int h[N],e[N * 2],w[N * 2],ne[N * 2],idx;
LL cnt[N],f[N],mx[N][2],id[N][2],res;

void add(int a,int b,int c){
    e[idx] = b,w[idx] = c,ne[idx] = h[a],h[a] = idx++;
}

void dfs(int u,int fa){
    if(flag[u]) cnt[u] = 1;

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];
        if(j != fa){
            dfs(j,u);
            
            if(cnt[j]){
                cnt[u] += cnt[j];
                f[u] += f[j] + 2 * val;

                int tmp = mx[j][0] + val;
                if(tmp > mx[u][0]){
                    mx[u][1] = mx[u][0],id[u][1] = id[u][0];
                    mx[u][0] = tmp,id[u][0] = j;
                }else if(tmp > mx[u][1]){
                    mx[u][1] = tmp,id[u][1] = j;
                }
            } 
        }
    }
}

void dp(int u,int fa){
    res = min(res,f[u] - mx[u][0]);

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];

        if(j != fa){
            if(!cnt[j]) f[j] = f[u] + 2 * val;
            else if(cnt[j] == m) f[j] = f[j];
            else f[j] = f[u];

            if(cnt[j] != m){
                int tmp;
                if(id[u][0] == j) tmp = mx[u][1] + val;
                else tmp = mx[u][0] + val;

                if(tmp > mx[j][0]){
                    mx[j][1] = mx[j][0],id[j][1] = id[j][0];
                    mx[j][0] = tmp,id[j][0] = u;
                }else if(tmp > mx[j][1]){
                    mx[j][1] = tmp,id[j][1] = u;
                }
           }

            dp(j,u);
        }
    }
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);

    memset(h,-1,sizeof h);

    cin >> n >> m;
    for(int i = 0,u,v,w;i < n - 1;i++){
        cin >> u >> v >> w;
        add(u,v,w),add(v,u,w);
    }
    for(int i = 1,j;i <= m;i++){
        cin >> j;
        flag[j] = true;
    }

    dfs(1,-1);

    res = f[1] - mx[1][0];

    dp(1,-1);

    cout << res;

    return 0;
}
  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值