BZOJ 3252 攻略 dfs序+线段树
题目链接:右转进入题目
题目大意:给定一棵以1为根的n个点的树,树有点权且点权为正整数,可以选择k条以根作为起点的路径,每条路径的价值即这条路径上所有点的点权之和。
但是选择一条路径之后,这条路径上的所有点的点权会变成0。(也就是说,这k条路径中被重复选择的点,其点权只能被计算一次)。
求最大价值之和。n<=200000。
题解:
一开始以为是dp或者费用流之类的,根本没有想到是线段树。
首先维护一个前缀和,每个点x的前缀和即x到根的所有点权之和。
然后很明显就要贪心的来做了,第一次肯定选取前缀和最大的那个点。
但是选了这个点之后,这个点到根的路径上的所有点的点权被“取走(也就是变成0)”了;
那么观察可知,如果把x这个点取走,辣么说它和它的子树中所有点的前缀和都要减去这个点的权值(注意不是前缀和)。
于是我们每次选取一个前缀和最大的点,从这个点开始往上走一直走到不能走为止,其间对于每个走过的点都进行上述“取走”操作。
显然我们的单次修改操作都是针对一颗子树的,所以显然想到是dfs序。
又因为我们需要维护这样一个数据结构,实现区间减法和询问整个区间的最值,那么显然就是来一发线段树就可以啦~
然后看复杂度,显然每次选择x就把x到根所有的点都取走是不划算的(因为有可能这条路径上的点已经取走了,不用再取一遍了)
于是我们的策略是记录每个点是否被删除,这样选择x的话,从x开始一直往上走,走到一个已经被取过的点就停止。
那么由于每个点只可能被取走一次,取走一次的复杂度是在线段树上进行区间操作的O(lgn),所以复杂度就是O(nlgn)。
PS:这道题写错了一点地方调了两个小时……
附上代码:
//BZOJ 3235
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
#define MAXN 200010
#define ull long long
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
vector<int> g[MAXN];
long long val[MAXN],ans;int A[MAXN];
int rt,dfs_clock,father[MAXN],times[MAXN],L[MAXN],R[MAXN];
bool del[MAXN];
struct answer{
ull val;int pos;
answer()
{
val=pos=0;
}
bool operator>(const answer &ans)
{
return this->val>ans.val;
}
bool operator=(const answer &ans)
{
this->val=ans.val;
this->pos=ans.pos;
}
void operator+=(ull v)
{
this->val+=v;
}
};
struct segment{
int lef,rig;
ull plus_tag;
answer maxn;
segment *lc,*rc;
}*root;
void dfs(int x)
{
times[L[x]=++dfs_clock]=x;
for(int i=g[x].size()-1;i>=0;i--)
{
val[g[x][i]]+=val[x];
dfs(g[x][i]);
}
R[x]=dfs_clock;
}
void push_up(segment* &rt)
{
rt->maxn=rt->lc->maxn>rt->rc->maxn?
rt->lc->maxn:rt->rc->maxn;
}
void push_down(segment* &rt)
{
rt->lc->maxn+=rt->plus_tag;
rt->lc->plus_tag+=rt->plus_tag;
rt->rc->maxn+=rt->plus_tag;
rt->rc->plus_tag+=rt->plus_tag;
rt->plus_tag=0;
}
void build_segment(segment* &rt,int lef,int rig)
{
rt=new segment;
rt->plus_tag=0;
rt->lef=lef;rt->rig=rig;
rt->rc=rt->lc=NULL;
if(lef==rig)
{
rt->maxn.val=val[times[lef]];
rt->maxn.pos=lef;
return;
}
int mid=(lef+rig)>>1;
build_segment(rt->lc,lef,mid);
build_segment(rt->rc,mid+1,rig);
push_up(rt);return;
}
void update_segment(segment* &rt,int s,int t,ull v)
{
int l=rt->lef,r=rt->rig;
if(s<=l&&r<=t)
{
rt->maxn+=v;
rt->plus_tag+=v;
return;
}
int mid=(l+r)>>1;
if(rt->plus_tag) push_down(rt);
if(s<=mid) update_segment(rt->lc,s,t,v);
if(mid<t) update_segment(rt->rc,s,t,v);
push_up(rt);return;
}
void debug_tree(int x)
{
debug(x)sp;debug(L[x])sp;debug(R[x])sp;debug(val[x])ln;
for(int i=g[x].size()-1;i>=0;i--)
debug_tree(g[x][i]);
}
void debug_segment(segment *rt)
{
debug(rt->lef)sp;debug(rt->rig)sp;debug(rt->plus_tag)sp;
debug(rt->maxn.val)sp;debug(rt->maxn.pos)ln;
if(rt->lef==rt->rig) return;
debug_segment(rt->lc);debug_segment(rt->rc);
}
int main()
{
int n,k;scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)
scanf("%d",&A[i]),val[i]=A[i];
for(int i=1;i<n;i++)
{
int u,v;scanf("%d%d",&u,&v);
g[u].push_back(v);father[v]=u;
}
dfs_clock=0;dfs(rt=1);father[rt]=0;
build_segment(root,1,n);
// debug_tree(rt);
while(k--)
{
answer Ans=root->maxn;
if(Ans.val==0) break;
int u=times[Ans.pos];ans+=Ans.val;
while(u&&!del[u])
{
del[u]=true;
update_segment(root,L[u],R[u],-A[u]);
u=father[u];
}
// debug_segment(root);cout ln;
}
printf("%lld\n",ans);return 0;
}