题目链接
题目解法
其实我第一反应是 d s u o n t r e e dsu\;on\;tree dsuontree,但那个不好写,就改写点分治了
树上的静态问题,考虑点分治
考虑一条
u
−
>
l
c
a
−
>
v
u->lca->v
u−>lca−>v 的路径,令
u
−
>
l
c
a
u->lca
u−>lca 的路径上不包括
l
c
a
lca
lca 的
s
u
m
U
P
=
∑
v
a
l
i
∗
l
e
n
a
f
t
e
r
i
sumUP=\sum{val_i*len_{after\;i}}
sumUP=∑vali∗lenafteri,总和为
t
o
t
=
∑
v
a
l
i
tot=\sum val_i
tot=∑vali,
l
c
a
−
>
v
lca->v
lca−>v 的路径上不包括
l
c
a
lca
lca 的
s
u
m
D
O
W
N
=
∑
v
a
l
i
∗
l
e
n
a
f
t
e
r
i
sumDOWN=\sum{val_i*len_{after\;i}}
sumDOWN=∑vali∗lenafteri,长度为
l
e
n
len
len
考虑这条路径的答案为
s
u
m
U
P
+
t
o
t
∗
(
l
e
n
+
1
)
+
s
u
m
D
O
W
N
+
v
a
l
l
c
a
∗
(
l
e
n
+
1
)
sumUP+tot*(len+1)+sumDOWN+val_{lca}*(len+1)
sumUP+tot∗(len+1)+sumDOWN+vallca∗(len+1)
考虑已知
l
c
a
−
>
v
lca->v
lca−>v 的路径的信息,问题是如何找到最大的
s
u
m
U
P
+
t
o
t
∗
(
l
e
n
+
1
)
sumUP+tot*(len+1)
sumUP+tot∗(len+1)
考虑到这是一个一次函数的形式,所以想到李超树维护
k
=
t
o
t
,
b
=
s
u
m
U
P
k=tot,b=sumUP
k=tot,b=sumUP
这里清空李超树不能全部遍历一遍,需要把用过的清空
注意到整体修改李超树的时间复杂度是
O
(
l
o
g
n
)
O(log\;n)
O(logn)
点分治自带
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
所以时间复杂度为
O
(
n
l
o
g
2
n
)
O(nlog^2n)
O(nlog2n)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N(150100),inf(1e18);
struct Segment_Tree{
int mxk[N<<2],mxb[N<<2];
int calc(int k,int b,int p){
if(k==-1&&b==-1) return -inf;
return k*p+b;
}
void insert(int l,int r,int x,int k,int b){
if(calc(k,b,l)<=calc(mxk[x],mxb[x],l)&&calc(k,b,r)<=calc(mxk[x],mxb[x],r)) return;
if(calc(k,b,l)>calc(mxk[x],mxb[x],l)&&calc(k,b,r)>calc(mxk[x],mxb[x],r)){ mxk[x]=k,mxb[x]=b;return;}
int mid=(l+r)>>1;
if(calc(k,b,mid)>calc(mxk[x],mxb[x],mid)) swap(mxk[x],k),swap(mxb[x],b);
if(mxk[x]>k) insert(l,mid,x<<1,k,b);
else insert(mid+1,r,x<<1^1,k,b);
}
void clear(int l,int r,int x){
if(mxk[x]==-1&&mxb[x]==-1) return;
mxk[x]=mxb[x]=-1;
if(l==r) return;
int mid=(l+r)>>1;
clear(l,mid,x<<1),clear(mid+1,r,x<<1^1);
}
int query(int l,int r,int x,int p){
if(l==r) return calc(mxk[x],mxb[x],p);
int mid=(l+r)>>1,t=calc(mxk[x],mxb[x],p);
if(mid>=p) return max(t,query(l,mid,x<<1,p));
else return max(t,query(mid+1,r,x<<1^1,p));
}
}T;
struct Node{
int sumUP,sumDOWN,len,tot;
}stk[N];
int top;
int n,val[N],ans;
int e[N<<1],ne[N<<1],h[N],idx;
bool vis[N],cut[N];
inline int read(){
int FF=0,RR=1;
char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
return FF*RR;
}
int get_size(int u,int fa){
int siz=1;
for(int i=h[u];~i;i=ne[i]) if(!vis[e[i]]&&e[i]!=fa) siz+=get_size(e[i],u);
return siz;
}
int get_root(int u,int fa,int &rt,int tot){
if(vis[u]) return 0;
int siz=1,ms=0;
for(int i=h[u];~i;i=ne[i]){
if(e[i]==fa) continue;
int t=get_root(e[i],u,rt,tot);
siz+=t,ms=max(ms,t);
}
ms=max(ms,tot-siz);
if(ms<=tot/2) rt=u;
return siz;
}
void dfs(int u,int fa,int sumUP,int sumDOWN,int tot,int len){
bool leaf=1;
for(int i=h[u];~i;i=ne[i]){
int v=e[i];
if(v!=fa&&!vis[v]) leaf=0,dfs(v,u,sumUP+val[v]*(len+1),sumDOWN+tot+val[v],tot+val[v],len+1);
}
if(leaf) stk[++top]={sumUP,sumDOWN,len,tot};
}
void solve(int rt){
get_root(rt,-1,rt,get_size(rt,-1));
vis[rt]=1;
top=0;
for(int i=h[rt];~i;i=ne[i]) if(!vis[e[i]]){ dfs(e[i],rt,val[e[i]],val[e[i]],val[e[i]],1);cut[top]=1;}
T.clear(1,n,1);
// cout<<rt<<'\n';
// for(int i=1;i<=top;i++) cout<<stk[i].len<<' '<<stk[i].sumDOWN<<' '<<stk[i].sumUP<<' '<<stk[i].tot<<'\n';
// cout<<'\n';
for(int i=1;i<=top;i++) ans=max(ans,max(stk[i].sumUP+stk[i].tot+val[rt],stk[i].sumDOWN+val[rt]*(stk[i].len+1)));
for(int i=1;i<=top;){
int j=i;
while(!cut[j]) ans=max(ans,T.query(1,n,1,stk[j].len+1)+stk[j].sumDOWN+val[rt]*(stk[j].len+1)),j++;
ans=max(ans,T.query(1,n,1,stk[j].len+1)+stk[j].sumDOWN+val[rt]*(stk[j].len+1));
while(!cut[i]) T.insert(1,n,1,stk[i].tot,stk[i].sumUP),i++;
T.insert(1,n,1,stk[i].tot,stk[i].sumUP),i++;
}
T.clear(1,n,1);
for(int i=top;i>=1;){
ans=max(ans,T.query(1,n,1,stk[i].len+1)+stk[i].sumDOWN+val[rt]*(stk[i].len+1));
int j=i-1;
while(j&&!cut[j]) ans=max(ans,T.query(1,n,1,stk[j].len+1)+stk[j].sumDOWN+val[rt]*(stk[j].len+1)),j--;
T.insert(1,n,1,stk[i].tot,stk[i].sumUP);i--;
while(i&&!cut[i]) T.insert(1,n,1,stk[i].tot,stk[i].sumUP),i--;
}
for(int i=1;i<=top;i++) cut[i]=0;
for(int i=h[rt];~i;i=ne[i]) if(!vis[e[i]]) solve(e[i]);
}
void add(int x,int y){ e[idx]=y,ne[idx]=h[x],h[x]=idx++;}
signed main(){
n=read();
memset(h,-1,sizeof(h));
for(int i=1;i<n;i++){
int x=read(),y=read();
add(x,y),add(y,x);
}
for(int i=1;i<=n;i++) val[i]=read();
T.clear(1,n,1);
solve(1);
printf("%lld",ans);
return 0;
}