解析
感觉是左偏树的神题了.
首先有一个比较显然的结论,一个合法的方案中,两个叶子到它们 lca \text{lca} lca 的距离必须相等.
考虑设计 dp \text{dp} dp :
f i , x f_{i,x} fi,x 表示 i i i 的子树中,所有叶子到它的距离为 x x x 的最小代价.
考虑这个函数如何向父亲合并.
设一个结点到父亲的距离为
c
i
c_i
ci .
朴素
dp
\text{dp}
dp ,就有:
f
i
,
x
=
∑
j
∈
s
o
n
i
min
v
≥
0
x
f
j
,
x
−
v
+
∣
c
j
−
v
∣
f_{i,x}=\sum_{j\in son_i}\min_{v\geq0}^xf_{j,x-v}+|c_j-v|
fi,x=j∈soni∑v≥0minxfj,x−v+∣cj−v∣
这玩意显然复杂度爆炸啊…
换个角度,考虑
f
f
f 函数本身的性质.
不难想到,原来的函数应该是一个下凸的线性函数.
s
o
n
i
son_i
soni 的函数如何向
i
i
i 合并?
一开始,这个函数似乎就是简单的所有子节点函数相加合成.
并且,由于
c
i
c_{i}
ci 的存在,这个函数肯定要往右移
c
i
c_i
ci .
假设移动完长这样:
但是,考虑到有 i i i 连向 f a i fa_i fai 的边,有些修改可以在这里一起改,就不必各自麻烦了,所以肯定函数会变化,准确的说,变得更好.
设斜率为
0
0
0 的区间为
[
L
,
R
]
[L,R]
[L,R] .
然后我们发现
R
R
R 右侧还有好多斜率大于
1
1
1 的地方.
大概的实际意义就是每个儿子都各自修改.
这就很亏.
所以对于
f
i
,
x
f_{i,x}
fi,x 我们干脆在先全部调整成距离为
R
R
R,支付
f
i
,
R
f_{i,R}
fi,R 的代价,再把当前结点连向父亲的边权值增大
x
−
R
x-R
x−R .
这样就把
R
R
R 右侧的斜率全都变成了
1
1
1 .
变成这样:
另一边可以类似的搞吗?
差不多,但是有个问题…
边的权值非负!
所以我们斜率为
1
1
1 的区间往左增加的长度最多为
c
i
c_i
ci .
后面的函数就往左顺延
c
i
c_i
ci .
也就是变成:
别忘了,本来这个函数是整体右移
c
i
c_i
ci .
这里左边的斜率大于1的区间又往左移动了
c
i
c_i
ci .
所以其实根本位置没变.
总结一下的话,这个函数合并到父亲后,就是 把斜率为
0
0
0 的区间
[
L
,
R
]
[L,R]
[L,R] 右移
c
i
c_i
ci ,把
R
R
R 右边的函数斜率全改为
1
1
1 ,斜率为
−
1
-1
−1 的区间往右延长
c
i
c_i
ci 补上
L
L
L 右移的空缺.
实现上,建一个可并堆维护所有的拐点,令相邻拐点斜率差为
1
1
1 (如果两端之间斜率差大于
1
1
1 就插多个横坐标相同的拐点),那么其实就把
R
R
R 右侧的所有拐点弹掉,并把
L
L
L 和
R
R
R 的坐标加上
c
i
c_i
ci 就行了.
最后合并到根之后,我们只有拐点的横坐标,如何求出答案(也就是
f
1
,
L
f_{1,L}
f1,L )呢?
注意到,
f
1
,
0
f_{1,0}
f1,0 其实就是所有边权之和,能很方便的求出来,又因为我们知道相邻两个端点的斜率差均为
1
1
1 ,
[
L
,
R
]
[L,R]
[L,R] 的斜率为
0
0
0 ,那么我们只需要倒着推一遍就行了.
代码
#include<bits/stdc++.h>
const int N=1e6+100;
const int mod=1e9+7;
#define ll long long
using namespace std;
inline ll read() {
ll x(0),f(1);char c=getchar();
while(!isdigit(c)) {if(c=='-')f=-1;c=getchar();}
while(isdigit(c)) {x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
int n,m;
int fa[N],fv[N],son[N];
int rt[N],tot,ls[N],rs[N],dis[N];
ll val[N];
int merge(int x,int y){
//printf("merge x=%d y=%d\n",x,y);
if(!x||!y) return x|y;
if(val[x]<val[y]) swap(x,y);
rs[x]=merge(rs[x],y);
if(dis[ls[x]]<dis[rs[x]]) swap(ls[x],rs[x]);
dis[x]=dis[rs[x]]+1;
return x;
}
inline int pop(int x){
return merge(ls[x],rs[x]);
}
ll sum;
void debug(int x){
if(!x) return;
printf("x=%d ls=%d rs=%d val=%lld\n",x,ls[x],rs[x],val[x]);
debug(ls[x]);debug(rs[x]);
return;
}
int main(){
dis[0]=-1;
n=read();m=read();
for(int i=2;i<=n+m;i++){
fa[i]=read();son[fa[i]]++;sum+=(fv[i]=read());
}
for(int i=n+m;i>=2;i--){
ll l(0),r(0);
//printf("i=%d\n",i);
if(i<=n){
while(--son[i]) rt[i]=pop(rt[i]);
r=val[rt[i]];rt[i]=pop(rt[i]);
l=val[rt[i]];rt[i]=pop(rt[i]);
//printf("ok");
}
val[++tot]=l+fv[i];val[++tot]=r+fv[i];
rt[i]=merge(rt[i],merge(tot-1,tot));//printf("OK rtfa=%d\n",rt[fa[i]]);
printf("\ni=%d:\n",i);debug(rt[i]);
rt[fa[i]]=merge(rt[fa[i]],rt[i]);
}
//debug(rt[1]);
while(son[1]--){
rt[1]=pop(rt[1]);
//printf("\nrt=%d\n",rt[1]);
}
while(rt[1]){
sum-=val[rt[1]];rt[1]=pop(rt[1]);
//printf("\nrt=%d\n",rt[1]);
}
printf("%lld\n",sum);
}
/*
1
281239
*/