树上连通有关背包:【BZOJ4182】shopping &【HDU6566】The Hanged Man

选这两道题是因为这两道题都是树上背包,而且选的点的要求都与连通性有关,而且都是按 dfs 序 DP 来模拟不断加入物品,而且都能用树剖和点分治优化(不过优化的点一个跟子树大小有关一个跟深度有关),比较相似。

【BZOJ4182】shopping

题意:树上多重背包,要求选了的点是一个连通块。

暴力想法设 f u , i f_{u,i} fu,i 表示选了以 u u u 为根且在 u u u 子树内的连通块,花费为 i i i 的最大收益。

如果使用暴力合并子树的方法的话,时间复杂度 O ( n m 2 ) O(nm^2) O(nm2),而且看起来没什么可以优化的地方。

变换思路,我们不考虑合并,而是像普通序列 DP 一样不断添加物品。

添加物品需要按 dfs 序来 DP,设 dfs 了到点 u u u,用 g i g_i gi 表示 dfs 到点 u u u 时花费为 i i i 的最大收益(即到点 u u u 之前的 DP 状态)。

我们考虑点 u u u 的物品我们选不选。如果我们选了点 u u u(设 g 1 g_1 g1 表示选了点 u u u 的 DP 状态),显然先令 g 1 = g g_1=g g1=g ,然后用点 u u u 的物品更新 g 1 g_1 g1(注意这个点的物品至少得选一个),然后进这个点的子树 dfs 更新 g 1 g_1 g1;如果我们不选点 u u u(设 g 2 g_2 g2 表示不选点 u u u 的 DP 状态),那么它的子树也不能选,所以直接令 g 2 = g g_2=g g2=g。最后从 u u u 回溯时对 g 1 g_1 g1 g 2 g_2 g2 取个 max ⁡ \max max 来更新 g g g 即可。

如果我们使用二进制分组优化多重背包,一次加入物品时间就能降到 O ( m log ⁡ D ) O(m\log D) O(mlogD)。(当然也可以不用二进制分组而是用单调队列优化来把 log ⁡ D \log D logD 去掉)

注意如果连通块的根不同,那么进入一个点 u u u 时的 g g g 就不同,出来的值也不同,所以需要枚举连通块的根。

那么如果暴力地枚举每个点作为连通块的根并在其子树内 dfs 求出 f u f_{u} fu,总时间复杂度就为 O ( n 2 m log ⁡ D ) O(n^2m\log D) O(n2mlogD)

接下来有两种优化方式:

  • 树剖。我们不能暴力合并子树的 f f f,但我们可以选择一棵子树继承它的 f f f。所以我们求 f u f_u fu 时先继承重儿子 s s s f s f_s fs,再往其他的轻儿子 dfs。每个点只会被 dfs 共 log ⁡ n \log n logn 次。总时间复杂度 O ( n m log ⁡ n log ⁡ D ) O(nm\log n\log D) O(nmlognlogD)
  • 点分治。我们枚举每一个 r t rt rt,以点分树上以它为根的的子树为范围,在原树上 dfs。显然原树上的每一个连通块都可以被唯一一个 r t rt rt(这个连通块中在点分树上深度最小的那个点)DP 到。总时间复杂度 O ( n m log ⁡ n log ⁡ D ) O(nm\log n\log D) O(nmlognlogD)

点分治做法代码:

#include<bits/stdc++.h>
 
#define N 510
#define M 4010
#define INF 0x7fffffff
 
using namespace std;
 
inline int read()
{
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(x<<1)+(x<<3)+(ch^'0');
        ch=getchar();
    }
    return x*f;
}
 
int T,n,m,w[N],c[N],d[N];
int cnt,head[N],to[N<<1],nxt[N<<1];
int nn,maxn,rt,size[N];
int ans;
bool vis[N];
 
void adde(int u,int v)
{
    to[++cnt]=v;
    nxt[cnt]=head[u];
    head[u]=cnt;
}
 
void getsize(int u,int fa)
{
    size[u]=1;
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa||vis[v]) continue;
        getsize(v,u);
        size[u]+=size[v];
    }
}
 
void getroot(int u,int fa)
{
    int maxs=0;
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        maxs=max(maxs,size[v]);
    }
    maxs=max(maxs,nn-1-size[u]);
    if(maxs<maxn) rt=u,maxn=maxs;
}
 
int f[M],g[N][M];
int tot,ww[35],cc[35];
 
void divide(int u)//二进制分组
{
    tot=0;
    int x=d[u]-1;
    for(int j=1;j<=x;j<<=1)
    {
        x-=j;
        ww[++tot]=j*w[u];
        cc[tot]=j*c[u];
    }
    if(x)
    {
        ww[++tot]=x*w[u];
        cc[tot]=x*c[u];
    }
}
 
void insert(int u)
{
    if(d[u]==1) return;
    divide(u);
    for(int i=1;i<=tot;i++)
        for(int j=m;j>=0;j--)
            if(j-cc[i]>=0&&f[j-cc[i]]!=-1)
                f[j]=max(f[j],f[j-cc[i]]+ww[i]);
}
 
void dfs(int u,int fa)
{
    memcpy(g[u],f,sizeof(g[u]));
    memset(f,-1,sizeof(f));
    for(int j=m;j>=0;j--)//在当前点强制先选一个
        if(j-c[u]>=0&&g[u][j-c[u]]!=-1)
            f[j]=g[u][j-c[u]]+w[u];
    insert(u);
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa||vis[v]) continue;
        dfs(v,u);
    }
    for(int i=0;i<=m;i++) f[i]=max(f[i],g[u][i]);
}
 
void work(int u)
{
    memset(f,-1,sizeof(f));
    f[0]=0;
    dfs(u,0);
    for(int i=0;i<=m;i++)
        ans=max(ans,f[i]);
}
 
void solve(int u)
{
    vis[u]=1;
    work(u);
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(vis[v]) continue;
        getsize(v,0);
        nn=size[v],maxn=INF,getroot(v,0);
        solve(rt);
    }
}
 
int main()
{
    T=read();
    while(T--)
    {
        ans=cnt=0;
        memset(head,0,sizeof(head));
        memset(vis,0,sizeof(vis));
        n=read(),m=read();
        for(int i=1;i<=n;i++) w[i]=read();
        for(int i=1;i<=n;i++) c[i]=read();
        for(int i=1;i<=n;i++) d[i]=read();
        for(int i=1;i<n;i++)
        {
            int u=read(),v=read();
            adde(u,v),adde(v,u);
        }
        getsize(1,0);
        nn=size[1],maxn=INF,getroot(1,0);
        solve(rt);  
        printf("%d\n",ans);
    }
    return 0;
}

【HDU6566】The Hanged Man

题意:树上 01 背包,要求选的点是一个独立集,除了输出最大收益还要输出最大收益的方案数。

暴力想法设 f u , 0 / 1 , i f_{u,0/1,i} fu,0/1,i 表示考虑完以 u u u 为根的子树, u u u 选没选,代价为 i i i 的最大收益。

暴力合并子树是 O ( n m 2 ) O(nm^2) O(nm2) 的,而且也没有什么可优化的地方。

变换思路,我们不考虑合并,而是像普通序列 DP 一样不断添加物品。

按 dfs 序 DP,注意祖先点的选择状态对于后续点的选择有影响,所以需要记录一下祖先的选择状态。

暴力想法设 g s t a , i g_{sta,i} gsta,i 表示当前祖先选择状态为 s t a sta sta,代价为 i i i 的最大收益。然后直接 dfs 并更新 g g g 即可。时间复杂度 O ( n 2 n m ) O(n2^nm) O(n2nm)

接下来有两种优化方法:

  • 树剖。我们优先 dfs 轻儿子,那么 dfs 完一个点后会一直回溯到重链顶端,所以重链上的点的选择状态对后续点的选择是没有影响的,于是 DP 时只需记录每个轻边父亲的选择状态,时间复杂度 O ( n 2 log ⁡ n m ) = O ( n 2 m ) O(n2^{\log n}m)=O(n^2m) O(n2lognm)=O(n2m)
  • 点分治。显然对于一个点 u u u,在原树中和它相邻的点只可能是点分树中 u u u 的祖先或者是点分树中 u u u 的子树,于是 DP 时只需记录点分树上祖先的选择状态,时间复杂度 O ( n 2 log ⁡ n m ) = O ( n 2 m ) O(n2^{\log n}m)=O(n^2m) O(n2lognm)=O(n2m)

点分治做法代码:

#include<cstring>
#include<iostream>
#include<assert.h>

#define N 55
#define M 5010
#define ll long long
#define INF 0x7fffffff

using namespace std;

inline int read()
{
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(x<<1)+(x<<3)+(ch^'0');
        ch=getchar();
    }
    return x*f;
}

int T,n,m,a[N],b[N],fa[N];
int cnt,head[N],nxt[N<<1],to[N<<1];
int nn,rt,maxn,size[N];
int f[N<<2][M];
ll g[N<<2][M];
bool vis[N];

void adde(int u,int v)
{
    to[++cnt]=v;
    nxt[cnt]=head[u];
    head[u]=cnt;
}

void dfs(int u)
{
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa[u]) continue;
        fa[v]=u;
        dfs(v);
    }
}

void getsize(int u,int fa)
{
    size[u]=1;
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa||vis[v]) continue;
        getsize(v,u);
        size[u]+=size[v];
    }
}

void getroot(int u,int fa)
{
    int nmax=0;
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        nmax=max(nmax,size[v]);
    }
    nmax=max(nmax,nn-size[u]);
    if(nmax<maxn) rt=u,maxn=nmax;
}

int top,sta[7];

bool beside(int u,int v)
{
    return fa[u]==v||fa[v]==u;
}

void solve(int u,int dep)
{
    vis[u]=1;
    int fms=(1<<(dep-1))-1;
    for(int fs=0;fs<=fms;fs++)
    {
        bool choose=1;
        for(int i=1;i<=top;i++)
        {
            if(((fs>>(i-1))&1)&&beside(u,sta[i]))
            {
                choose=0;
                break;
            }
        }
        if(choose)
        {
            int us=fs|(1<<(dep-1));
            for(int i=0;i<=m;i++)
            {
                if(i-a[u]>=0&&f[fs][i-a[u]]!=-1)
                {
                    f[us][i]=f[fs][i-a[u]]+b[u];
                    g[us][i]=g[fs][i-a[u]];
                }
            }
        }
    }
    sta[++top]=u;
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(vis[v]) continue;
        getsize(v,0);
        nn=size[v],maxn=INF,getroot(v,0);
        solve(rt,dep+1);
    }
    top--;
    int tmp=(1<<(dep-1));
    for(int fs=0;fs<=fms;fs++)
    {
        for(int i=0;i<=m;i++)
        {
            if(f[fs][i]<f[fs|tmp][i])
            {
                f[fs][i]=f[fs|tmp][i];
                g[fs][i]=g[fs|tmp][i];
            }
            else if(f[fs][i]==f[fs|tmp][i])
                g[fs][i]+=g[fs|tmp][i];
            f[fs|tmp][i]=-1,g[fs|tmp][i]=0;
        }
    }
}

int main()
{
    T=read();
    for(int Case=1;Case<=T;Case++)
    {
        cnt=0;
        memset(head,0,sizeof(head));
        memset(f,-1,sizeof(f));
        memset(g,0,sizeof(g));
        memset(vis,0,sizeof(vis));
        n=read(),m=read();
        for(int i=1;i<=n;i++)
            a[i]=read(),b[i]=read();
        for(int i=1;i<n;i++)
        {
            int u=read(),v=read();
            adde(u,v),adde(v,u);
        }
        dfs(1);
        getsize(1,0);
        nn=size[1],maxn=INF,getroot(1,0);
        f[0][0]=0,g[0][0]=1;
        solve(rt,1);
        printf("Case %d:\n",Case);
        for(int i=1;i<m;i++)
            printf("%lld ",g[0][i]);
        printf("%lld\n",g[0][m]);
    }
    return 0;
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值