Description
有一棵n个点的树,边有边权。
有m条额外边,第i条的边权为ai,对于每个i∈[0,m],问加入前i条额外边后,从1出发经过所有树边至少一次最后回到1的最短路径的长度。
n,m<=10^5
Solution
容易发现这题分为两部分,先求出an[i]表示选择树上i条不相交的链的长度最大值,然后对于每个前缀i,求一个j,使得an[j]-前i个a的前j小最大
an的话可以考虑类似反悔的操作:每次取一条直径,然后将直径取反
后面的话由于两个都是凸函数可以直接二分,当然也可以猜出单调性直接指针扫过去
问题变成了,支持路径取反,动态询问树上直径(最长链)
考虑用LCT维护,每条链维护往上走和往下走的最长链以及子树内的最长链,以及权值取反之后的这两个东西,然后取反就相当于打标记。每个点再维护两个set表示虚儿子往下走的最长链和虚儿子字数内的最长链然后就可以了
写起来可能有点自闭
Code
#include <set>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
typedef long long ll;
int read() {
char ch;
for(ch=getchar();ch<'0'||ch>'9';ch=getchar());
int x=ch-'0';
for(ch=getchar();ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x;
}
const int N=2e5+5;
const ll inf=1e16;
struct Path{
int u,v;ll w;
Path(){}
Path(int u,int v,ll w):
u(u),v(v),w(w){}
Path operator + (const ll x)const{return Path(u,v,w+x);}
Path operator - (const ll x)const{return Path(u,v,w-x);}
Path operator + (const Path& x)const{
Path ret;
ret.u=u;ret.v=x.u;
if (ret.u<ret.v) swap(ret.u,ret.v);
ret.w=w+x.w;
return ret;
}
bool operator < (const Path& x)const{
if (w!=x.w) return w<x.w;
if (u!=x.u) return u<x.u;
return v<x.v;
}
};
struct Data{
Path pre,suf,ans;ll sum;
Data(){}
Data(Path pre,Path suf,Path ans,ll sum):
pre(pre),suf(suf),ans(ans),sum(sum){}
Data operator + (const Data& x)const{
Data ret;
ret.pre=max(pre,x.pre+sum);
ret.suf=max(suf+x.sum,x.suf);
ret.ans=max(ans,x.ans);
ret.ans=max(ret.ans,suf+x.pre);
ret.sum=sum+x.sum;
return ret;
}
};
namespace LCT{
int f[N],t[N][2],p[N],sta[N],top;
multiset<Path> pre[N],ans[N];
ll val[N];
Data a[N][2],ra[N][2],dp[N];
bool rev[N],sig[N];
void upd(int x) {
int ls=t[x][0],rs=t[x][1];
a[x][0]=ra[x][0]=Data(dp[x].pre+val[x],dp[x].pre+val[x],dp[x].ans,val[x]);
a[x][1]=ra[x][1]=Data(dp[x].pre-val[x],dp[x].pre-val[x],dp[x].ans,-val[x]);
if (ls) {
fo(i,0,1) {
a[x][i]=a[ls][i]+a[x][i];
ra[x][i]=ra[x][i]+ra[ls][i];
}
}
if (rs) {
fo(i,0,1) {
a[x][i]=a[x][i]+a[rs][i];
ra[x][i]=ra[rs][i]+ra[x][i];
}
}
}
int son(int x) {return t[f[x]][1]==x;}
void rotate(int x) {
int y=f[x],z=son(x);
if (f[y]) t[f[y]][son(y)]=x;
else p[x]=p[y],p[y]=0;
if (t[x][1-z]) f[t[x][1-z]]=y;
t[y][z]=t[x][1-z];t[x][1-z]=y;
f[x]=f[y];f[y]=x;
upd(y);upd(x);
}
void reverse(int x) {
if (!x) return;
swap(t[x][0],t[x][1]);
swap(a[x][0],ra[x][0]);
swap(a[x][1],ra[x][1]);
rev[x]^=1;
}
void flip(int x) {
if (!x) return;
swap(a[x][0],a[x][1]);
swap(ra[x][0],ra[x][1]);
val[x]=-val[x];sig[x]^=1;
}
void down(int x) {
if (sig[x]) {
flip(t[x][0]);flip(t[x][1]);
sig[x]=0;
}
if (rev[x]) {
reverse(t[x][0]);reverse(t[x][1]);
rev[x]=0;
}
}
void remove(int x,int y) {
do {sta[++top]=x;x=f[x];} while (x!=y);
for(;top;down(sta[top--]));
}
void splay(int x,int y) {
remove(x,y);
while (f[x]!=y) {
if (f[f[x]]!=y)
if (son(x)==son(f[x])) rotate(f[x]);
else rotate(x);
rotate(x);
}
}
void get(int x) {
dp[x].pre=*pre[x].rbegin();
dp[x].ans=*ans[x].rbegin();
if (pre[x].size()>=2) {
auto smx=pre[x].rbegin(),mx=smx;smx++;
dp[x].ans=max(dp[x].ans,*smx+*mx+val[x]);
}
}
void add(int x,int y) {
pre[x].insert(a[y][0].pre);
ans[x].insert(a[y][0].ans);
get(x);
}
void del(int x,int y) {
pre[x].erase(pre[x].find(a[y][0].pre));
ans[x].erase(ans[x].find(a[y][0].ans));
get(x);
}
void access(int x) {
int y=0;
for(;x;y=x,x=p[x]) {
splay(x,0);
if (t[x][1]) {
add(x,t[x][1]);
f[t[x][1]]=0;p[t[x][1]]=x;
}
if (y) {
del(x,y);
f[y]=x;p[y]=0;
}
t[x][1]=y;
upd(x);
}
}
void make_root(int x) {access(x);splay(x,0);reverse(x);}
void link(int u,int v) {
make_root(u);make_root(v);
p[v]=u;add(u,v);upd(u);
}
bool check(int u,int v) {
make_root(u);access(v);
splay(v,0);
for(;u&&u!=v;u=f[u]);
return u==v;
}
void modify(int u,int v) {make_root(u);access(v);splay(v,0);flip(v);}
}
int n,m,a[N];
ll tot,an[N];
multiset<int> s,t;
void Ins(int x) {
t.insert(x);
if (t.size()&&s.size()) {
int x=*s.rbegin(),y=*t.begin();
if (x<=y) return;
s.erase(s.find(x));t.erase(t.find(y));
s.insert(y);t.insert(x);
tot+=y-x;
}
}
int main() {
freopen("love.in","r",stdin);
freopen("love.out","w",stdout);
n=read();m=read();
fo(i,1,n) {
LCT::pre[i].insert(Path(i,i,0));
LCT::ans[i].insert(Path(i,i,0));
LCT::get(i);LCT::upd(i);
}
ll sum=0;
fo(i,1,n-1) {
int x=read(),y=read(),z=read(),v=i+n;sum+=z<<1;
LCT::val[v]=z;
LCT::pre[v].insert(Path(0,0,-inf));
LCT::ans[v].insert(Path(0,0,-inf));
LCT::get(v);LCT::upd(v);
if (LCT::check(x,y)) continue;
LCT::link(x,v);LCT::link(v,y);
}
s.clear();
fo(i,1,m) a[i]=read();
an[0]=sum;
fo(i,1,m) {
LCT::make_root(n);
Path p=LCT::a[n][0].ans;
sum-=p.w;an[i]=sum;
LCT::modify(p.u,p.v);
}
printf("%lld ",an[0]);
int j=0;
fo(i,1,m) {
Ins(a[i]);
for(;j<i&&an[j]+tot>an[j+1]+tot+*t.begin();j++) {
int x=*t.begin();t.erase(t.begin());
tot+=x;s.insert(x);
}
printf("%lld ",an[j]+tot);
}
return 0;
}