HDU 5293 Tree chain problem 树形DP+LCA+树链剖分+线段树

3 篇文章 0 订阅
3 篇文章 0 订阅

Tree chain problem


Time Limit: 6000/3000 MS (Java/Others)    Memory Limit: 65536/65536 K (Java/Others)



Problem Description
Coco has a tree, whose vertices are conveniently labeled by 1,2,…,n.
There are m chain on the tree, Each chain has a certain weight. Coco would like to pick out some chains any two of which do not share common vertices.
Find out the maximum sum of the weight Coco can pick
 

Input
The input consists of several test cases. The first line of input gives the number of test cases T (T<=10).
For each tests: 
First line two positive integers n, m.(1<=n,m<=100000)
The following (n - 1) lines contain 2 integers ai bi denoting an edge between vertices ai and bi (1≤ai,bi≤n),
Next m lines each three numbers u, v and val(1≤u,v≤n,0<val<1000), represent the two end points and the weight of a tree chain.
 

Output
For each tests:
A single integer, the maximum number of paths.
 

Sample Input
  
  
1 7 3 1 2 1 3 2 4 2 5 3 6 3 7 2 3 4 4 5 3 6 7 3
 

Sample Output
  
  
6
Hint
Stack expansion program: #pragma comment(linker, "/STACK:1024000000,1024000000")
 
这是这道题的第二篇博客,上一篇使用的是DFS序和树状数组,这一篇使用的是树链剖分和线段树,主要是为了练一练树链剖分。
#pragma comment(linker, "/STACK:1024000000,1024000000") //扩栈
#include <set>
#include <map>
#include <stack>
#include <cmath>
#include <queue>
#include <cstdio>
#include <bitset>
#include <string>
#include <vector>
#include <iomanip>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <functional>
#define maxn 100010
#define ll long long
#define inf 0x7fffffff
using namespace std;
int n, m;
//邻接表
struct edge
{
        int u, v;
        int next;
} e[maxn * 2];
int cnt;
int pre[maxn];
void add_edge(int u, int v) //加边
{
        e[cnt].u = u;
        e[cnt].v = v;
        e[cnt].next = pre[u];
        pre[u] = cnt++;
        e[cnt].u = v;
        e[cnt].v = u;
        e[cnt].next = pre[v];
        pre[v] = cnt++;
}

//树链
struct node
{
        int u, v, w;
        int lca;
} p[maxn];

vector<int> vec[maxn]; //存储以某节点为LCA的链编号


//-------------------树链剖分--------------------

int dep[maxn], fa[maxn], siz[maxn], son[maxn];
// 深度数组   父节点  子节点数目  重儿子
int tid[maxn], ran[maxn], top[maxn], tim;;
// 节点的编号 编号的节点 所在链顶端节点

int f[20][maxn], Lev;

void dfs1(int u, int father, int d)
{
        dep[u] = d;
        fa[u] = father;
        f[0][u] = father;
        siz[u] = 1;
        for (int i = pre[u]; ~i; i = e[i].next)
        {
                int v = e[i].v;
                if (v != father)
                {
                        dfs1(v, u, d + 1);
                        siz[u] += siz[v];
                        if (son[u] == -1 || siz[v] > siz[son[u]])
                        {
                                son[u] = v;
                        }
                }
        }
}

void dfs2(int u, int tp)
{
        top[u] = tp;
        tid[u] = ++tim;
        ran[tid[u]] = u;
        if (son[u] == -1)
        {
                return;
        }
        dfs2(son[u], tp);
        for (int i = pre[u]; ~i; i = e[i].next)
        {
                int v = e[i].v;
                if (v != son[u] && v != fa[u])
                {
                        dfs2(v, v);
                }
        }
}

//--------------LCA倍增在线算法-------------

bool vis[maxn];
void bfs(int rt)//防爆栈,未使用
{
        queue<int> q;
        q.push(rt);
        f[0][rt] = 0, dep[rt] = 1, vis[rt] = 1;
        while (!q.empty())
        {
                int fa = q.front();
                q.pop();
                for (int i = pre[fa] ; ~i ; i = e[i].next)
                {
                        int x = e[i].v;
                        if (!vis[x])
                        {
                                dep[x] = dep[fa] + 1;
                                f[0][x] = fa , vis[x] = 1;
                                q.push(x);
                        }
                }
        }
}

int LCA(int x , int y)
{
        if (dep[x] > dep[y])
        {
                swap(x , y);
        }
        for (int i = Lev ; i >= 0 ; -- i)
                if (dep[y] - dep[x] >> i & 1)
                {
                        y = f[i][y];
                }
        if (x == y)
        {
                return y;
        }
        for (int i = Lev ; i >= 0 ; -- i)
                if (f[i][x] != f[i][y])
                {
                        x = f[i][x] , y = f[i][y];
                }
        return f[0][x];
}

int get_kth_anc(int x , int k) //找x的第k个祖先 未使用
{
        for (int i = 0 ; i <= Lev ; ++ i)
                if (k >> i & 1)
                {
                        x = f[i][x];
                }
        return x;
}


//--------------线段树--------------

#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
const int NV = 100005;
//线段树数组
int sumx[NV << 2];
int dp[NV << 2];
//原数组
int tsumx[NV];
int tdp[NV];
void PushUp(int rt, int * sum) //向上更新
{
        sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}
void build(int l, int r, int rt , int * sum) //建树
{
        if (l == r)
        {
                sum[rt] = 0;//初始化
                return ;
        }
        int mid = (l + r) >> 1;
        build(lson, sum);
        build(rson, sum);
        PushUp(rt, sum);
}
void update(int L, int c, int l, int r, int rt , int * sum) //单点更新
{
        if (L == l && l == r)
        {
                sum[rt] = c;
                return ;
        }
        int mid = (l + r) >> 1;
        if (L <= mid)
        {
                update(L , c , lson, sum);
        }
        else
        {
                update(L , c , rson, sum);
        }
        PushUp(rt, sum);
}
int query(int L, int R, int l, int r, int rt , int * sum) //区间查询
{
        if (L <= l && r <= R)
        {
                return sum[rt];
        }
        int mid = (l + r) >> 1;
        int ret = 0;
        if (L <= mid)
        {
                ret += query(L , R , lson, sum);
        }
        if (mid < R)
        {
                ret += query(L , R , rson, sum);
        }
        return ret;
}
int query_s(int x, int y, int * sum)//树链剖分 区间查询
{
        int ans = 0;
        while (top[x] != top[y])
        {
                if (dep[top[x]] < dep[top[y]])
                {
                        swap(x, y); //x所在链在y链下方
                }
                ans = ans + query(tid[top[x]], tid[x], 1, n, 1, sum); //查询x链
                x = fa[top[x]];
        }
        if (dep[x] > dep[y])
        {
                swap(x, y); //x、y在同一条链且x在y上方
        }
        ans = ans + query(tid[x] , tid[y], 1, n, 1, sum); //查询y链
        return ans;
}

//--------------树形DP与线段树维护------------

void solve(int s, int t, int father)
{
        tdp[s] = tsumx[s] = 0;
        for (int i = pre[s]; ~i; i = e[i].next)
        {
                int v = e[i].v;
                if (v != father)
                {
                        solve(v, t, s);
                        tsumx[s] += tdp[v];
                }
        }
        tdp[s] = tsumx[s];
        for (int i = 0; i < vec[s].size(); i++)
        {
                int u = p[vec[s][i]].u;
                int v = p[vec[s][i]].v;
                //状态转移方程
                int temp = query_s(u, v, sumx) - query_s(u, v, dp) + tsumx[s];
                tdp[s] = max(tdp[s], temp + p[vec[s][i]].w);
        }
        //更新线段树
        update(tid[s], tsumx[s], 1, n, 1, sumx);
        update(tid[s], tdp[s], 1, n, 1, dp);
}
void init()
{
        cnt = 0;
        tim = 0;
        memset(pre, -1, sizeof(pre));
        memset(son, -1, sizeof(son));
        memset(f, 0, sizeof(f));
        memset(top, 0, sizeof(top));
        memset(tsumx, 0, sizeof(tsumx));
        memset(tdp, 0, sizeof(tdp));
        for (int i = 0; i < maxn; i++)
        {
                vec[i].clear();
        }
}
int main()
{
        int t;
        scanf("%d", &t);
        while (t--)
        {
                init();
                scanf("%d %d", &n, &m);
                for (int i = 1; i < n; i++)
                {
                        int u, v;
                        scanf("%d %d", &u, &v);
                        add_edge(u, v);
                }
                //建链
                dfs1(1, 1, 1);
                dfs2(1, 1);
                int js;
                for (js = 1 ; 1 << js < n ; ++ js) //更新lca数组
                        for (int i = 1 ; i <= n ; ++ i)
                        {
                                f[js][i] = f[js - 1][f[js - 1][i]];
                        }
                Lev = js - 1;
                for (int i = 0; i < m; i++)
                {
                        scanf("%d %d %d", &p[i].u, &p[i].v, &p[i].w);
                        p[i].lca = LCA(p[i].u, p[i].v);
                        vec[p[i].lca].push_back(i);
                }
                build(1, n, 1, sumx); //建sumx树
                build(1, n, 1, dp); //建dp树
                solve(1, n, 0);
                printf("%d\n", tdp[1]);
        }
        return 0;
}


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值