链接:https://www.nowcoder.com/acm/contest/123/D
来源:牛客网
题意:
在根节点为0的树上(编号:0,1,2...n)。一条边上有两个值 poweri ,numi,相当于网络流:poweri相当于路上的cost,numi相当于容量,相当于汇点为0,源点为所有叶子节点,求最大费用。
样例:
7
0 100 0
1 2 3
2 2 5
1 5 1
2 1 3
3 2 4
4 3 2
输出:
33分析:想到可以贪心优先分配路上总power最大的叶子节点,然后求路上最小值mi,ans加上power*mi,然后把路径上的num都减去最小值mi,重复下去。然后用树链剖分维护就行。(看别人代码不用树链刨分直接修改也能过)。
#include<iostream>
#include<stdio.h>
#include<algorithm>
#include<string.h>
#include<vector>
using namespace std;
typedef long long int ll;
const int inf = 1e9;
const int maxn = 100000;
int n,son[maxn+5],sz[maxn+5],fa[maxn+5],top[maxn+5],tid[maxn+5],Rank[maxn+5],tim,dep[maxn+5];
int num[maxn+5],power[maxn+5],cost[maxn+5],st[maxn+5];
int mi[maxn*4],add[maxn*4];
vector<int>g[maxn+5];
void dfs1(int x,int f)
{
sz[x] = 1, fa[x] = f, son[x]=-1;
int len = g[x].size();
for(int i=0; i<len; i++)
{
int y = g[x][i];
if(y==f) continue;
dep[y] = dep[x] + 1;
cost[y]+=cost[x]+power[y];
dfs1(y,x);
sz[x]+=sz[y];
if(son[x]==-1||sz[son[x]]<sz[y]) son[x] = y;
}
}
void dfs2(int x,int tp)
{
top[x] = tp;
tid[x] = ++tim;
Rank[tim] = x;
if(son[x]!=-1) dfs2(son[x],tp);
for(int i=0,len=g[x].size(); i<len; i++)
{
int y = g[x][i];
if(y!=son[x]&&y!=fa[x])
dfs2(y,y);
}
}
void pushup(int o)
{
mi[o] = min(mi[o*2],mi[o*2+1]);
}
void build(int o,int l,int r)
{
add[o]=0;
if(l==r)
{
mi[o] = num[Rank[l]];
return;
}
int mid = (l+r)>>1;
build(o*2,l,mid);
build(o*2+1,mid+1,r);
pushup(o);
}
void pushdown(int o)
{
if(add[o]!=0)
{
int ls = o*2, rs = o*2+1;
mi[ls]+=add[o], mi[rs]+=add[o];
add[ls]+=add[o], add[rs]+=add[o];
add[o] = 0;
}
}
int query(int o,int l,int r,int L,int R)
{
if(L<=l&&r<=R) return mi[o];
int mid = (l+r)>>1;
pushdown(o);
if(R<=mid) return query(o*2,l,mid,L,R);
else
{
if(mid<L) return query(o*2+1,mid+1,r,L,R);
return min(query(o*2,l,mid,L,R),query(o*2+1,mid+1,r,L,R));
}
}
void updata(int o,int l,int r,int L,int R,int ad)
{
if(L<=l&&r<=R)
{
mi[o]+=ad;
add[o]+=ad;
}
else
{
pushdown(o);
int mid = (l+r)>>1;
if(R<=mid) updata(o*2,l,mid,L,R,ad);
else
{
if(mid<L) updata(o*2+1,mid+1,r,L,R,ad);
else
{
updata(o*2,l,mid,L,R,ad);
updata(o*2+1,mid+1,r,L,R,ad);
}
}
pushup(o);
}
}
int Find(int u,int v)
{
int f1 = top[u], f2 = top[v], tmp = inf;
while(f1!=f2)
{
if(dep[f1]<dep[f2]) swap(f1,f2), swap(u,v);
tmp = min(tmp,query(1,1,tim,tid[f1],tid[u]));
u = fa[f1];
f1 = top[u];
}
if(u==v) return tmp;
if(dep[u]>dep[v]) swap(u,v);
return min(tmp,query(1,1,tim,tid[son[u]],tid[v]));
}
void Updata(int va,int vb,int ad)
{
int f1 = top[va], f2 = top[vb];
while (f1 != f2)
{
if (dep[f1] < dep[f2])
{
swap(f1, f2);
swap(va, vb);
}
updata(1,1,tim,tid[f1],tid[va],ad);
va = fa[f1];
f1 = top[va];
}
if (va == vb) return;
if (dep[va] > dep[vb]) swap(va, vb);
updata(1, 1, tim, tid[son[va]], tid[vb],ad);
}
bool cmp(int x,int y)
{
return cost[x] > cost[y];
}
int main()
{
scanf("%d",&n);
for(int i=0; i<=n; i++) g[i].clear();
for(int i=1,f; i<=n; i++)
{
scanf("%d %d %d",&f,&num[i],&power[i]);
g[f].push_back(i);
}
dep[0] = 0;
dfs1(0,-1);
tim = -1;
dfs2(0,0);
build(1,1,tim);
int cnt = 0;
for(int i=1; i<=n; i++) if(g[i].size()==0) st[++cnt] = i;
sort(st+1,st+cnt+1,cmp);
ll ans = 0;
for(int i=1; i<=cnt; i++)
{
int x = st[i];
int w = Find(x,0);
if(w<=0) continue;
ans += cost[x]*1ll*w;
Updata(x,0,-w);
}
printf("%lld\n",ans);
return 0;
}