题意:给两棵基于同一点集的带边权树,记 lca ( x , y ) , depth ( x ) \operatorname{lca}(x,y),\operatorname{depth}(x) lca(x,y),depth(x) 为第一棵树上的 lca、到根的边长度之和, lca ′ ( x , y ) , depth ′ ( x ) \operatorname{lca}'(x,y),\operatorname{depth}'(x) lca′(x,y),depth′(x) 为第二棵树的,最大化
depth ( x ) + depth ( y ) − ( d e p t h ( lca ( x , y ) ) + depth ′ ( lca ′ ( x , y ) ) ) \operatorname{depth}(x)+\operatorname{depth}(y)-(\operatorname{depth(\operatorname{lca}(x,y))}+\operatorname{depth}'(\operatorname{lca}'(x,y))) depth(x)+depth(y)−(depth(lca(x,y))+depth′(lca′(x,y)))
n ≤ 366666 n\leq 366666 n≤366666
这个式子非常诡异,先推一下发现等于这个
1 2 ( dist ( x , y ) + depth ( x ) + depth ( y ) − depth ′ ( lca ′ ( x , y ) ) ) \frac 12(\operatorname{dist}(x,y)+\operatorname{depth}(x)+\operatorname{depth}(y)-\operatorname{depth}'(\operatorname{lca}'(x,y))) 21(dist(x,y)+depth(x)+depth(y)−depth′(lca′(x,y)))
左边是个距离,而右边只有一个二元函数,考虑对第一棵树分治
我们用点分治或边分治可以把 dist ( x , y ) \operatorname{dist}(x,y) dist(x,y) 拆成两项分别只与 x x x 和 y y y 有关的东西,就可以和 depth \operatorname{depth} depth 合并。现在的问题时怎么处理右边的东西。
不管是点分治还是边分治,每次计算时都有两个点集 S , T S,T S,T,要统计所有 x ∈ S , y ∈ T x\in S,y\in T x∈S,y∈T 的贡献。
考虑虚树。在第二棵树上用之前的代价标记 S , T S,T S,T 中的点,然后建出虚树,维护子树内两种集合中的权值最大值,在 lca \operatorname{lca} lca 处统计贡献。
这样复杂度是 O ( S log S ) \Omicron(S\log S) O(SlogS),其中 S S S 为两个集合的点集大小。所以只能边分治。
用欧拉序做 O ( n log n ) − O ( 1 ) \Omicron(n\log n)-\Omicron(1) O(nlogn)−O(1) lca \operatorname{lca} lca,总复杂度可以做到严格 O ( n log n ) \Omicron(n\log n) O(nlogn)
码量虽大但没什么细节,还是比较好写的。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <algorithm>
#define MAXN 1000005
#define MAXM 2000005
using namespace std;
typedef long long ll;
const ll INF=1e18;
inline int read()
{
int ans=0,f=1;
char c=getchar();
while (!isdigit(c)) (c=='-')&&(f=-1),c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return f*ans;
}
struct edge{int u,v,w;}e[MAXM];
int head[MAXN],nxt[MAXM],cnt=1;
inline void addnode(int u,int v,int w)
{
e[++cnt]=(edge){u,v,w};
nxt[cnt]=head[u];
head[u]=cnt;
}
vector<edge> E[MAXN];
ll dis[MAXN];
int vis[MAXM],n,tot;
void dfs(int u)
{
vis[u]=1;
if ((int)E[u].size()<=3)
{
for (int i=0;i<(int)E[u].size();i++)
{
int v=E[u][i].v,w=E[u][i].w;
if (vis[v]) continue;
dfs(v),addnode(u,v,w),addnode(v,u,w);
}
return;
}
int cur[2]={++tot,++tot},pos=0;
addnode(u,cur[0],0),addnode(cur[0],u,0);
addnode(u,cur[1],0),addnode(cur[1],u,0);
for (int i=0;i<(int)E[u].size();i++)
{
int v=E[u][i].v,w=E[u][i].w;
if (vis[v]) continue;
E[cur[pos]].push_back((edge){cur[pos],v,w}),pos^=1;
}
dfs(cur[0]),dfs(cur[1]);
}
void dfs(int u,int f)
{
for (int i=head[u];i;i=nxt[i])
if (e[i].v!=f)
dis[e[i].v]=dis[u]+e[i].w,dfs(e[i].v,u);
}
int rt,siz[MAXN];
ll mn;
void findrt(int u,int f,int sum)
{
siz[u]=1;
for (int i=head[u];i;i=nxt[i])
if (!vis[i>>1]&&e[i].v!=f)
{
findrt(e[i].v,u,sum);
if (max(siz[e[i].v],sum-siz[e[i].v])<mn)
mn=max(siz[e[i].v],sum-siz[e[i].v]),rt=i;
siz[u]+=siz[e[i].v];
}
}
namespace VT
{
edge e[MAXM];
int head[MAXN],nxt[MAXM],cnt;
inline void addnode(int u,int v,int w)
{
e[++cnt]=(edge){u,v,w};
nxt[cnt]=head[u];
head[u]=cnt;
}
int dfn[MAXN],lis[MAXM],LOG[MAXM],st[MAXM][21],tim;
ll dis[MAXN];
inline bool cmp(const int& x,const int& y){return dfn[x]<dfn[y];}
void dfs(int u,int f)
{
lis[dfn[u]=++tim]=u;
for (int i=head[u];i;i=nxt[i])
if (e[i].v!=f)
{
dis[e[i].v]=dis[u]+e[i].w;
dfs(e[i].v,u);
lis[++tim]=u;
}
}
inline int lca(int x,int y)
{
x=dfn[x],y=dfn[y];
if (x>y) swap(x,y);
int t=LOG[y-x+1];
return min(st[x][t],st[y-(1<<t)+1][t],cmp);
}
void input()
{
LOG[0]=-1;
for (int i=1;i<MAXM;i++) LOG[i]=LOG[i>>1]+1;
for (int i=1;i<n;i++)
{
int u,v,w;
u=read(),v=read(),w=read();
addnode(u,v,w),addnode(v,u,w);
}
dfs(1,0);
for (int i=1;i<=tim;i++) st[i][0]=lis[i];
for (int j=1;j<21;j++)
for (int i=1;i+(1<<(j-1))<=tim;i++)
st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1],cmp);
}
vector<int> s,son[MAXN];
ll val[MAXN],x[MAXN],y[MAXN],ans;
int type[MAXN];
inline void insert(int u,ll v,int t){val[u]=v,type[u]=t,s.push_back(u);}
void dfs(int u)
{
x[u]=y[u]=-INF;
if (type[u]==1) x[u]=val[u];
if (type[u]==2) y[u]=val[u];
for (int i=0;i<(int)son[u].size();i++)
{
dfs(son[u][i]);
ans=max(ans,max(x[u]+y[son[u][i]],x[son[u][i]]+y[u])-2*dis[u]);
x[u]=max(x[u],x[son[u][i]]),y[u]=max(y[u],y[son[u][i]]);
}
}
int stk[MAXN],tp;
ll solve()
{
sort(s.begin(),s.end(),cmp);
int siz=s.size();
for (int i=0;i<siz-1;i++)
s.push_back(lca(s[i],s[i+1]));
sort(s.begin(),s.end(),cmp);
s.erase(unique(s.begin(),s.end()),s.end());
tp=0;
for (int i=0;i<(int)s.size();i++)
{
while (tp&&lca(stk[tp],s[i])!=stk[tp]) --tp;
if (tp) son[stk[tp]].push_back(s[i]);
stk[++tp]=s[i];
}
ans=-INF;
dfs(stk[1]);
for (int i=0;i<(int)s.size();i++) son[s[i]].clear();
s.clear();
return ans;
}
}
void dfs(int u,int f,ll d,int type)
{
if (u<=n) VT::insert(u,dis[u]+d,type);
for (int i=head[u];i;i=nxt[i])
if (!vis[i>>1]&&e[i].v!=f)
dfs(e[i].v,u,d+e[i].w,type);
}
ll ans=-INF;
void solve(int sum)
{
if (mn==INF) return;
vis[rt>>1]=1;
dfs(e[rt].v,0,0,1);
dfs(e[rt].u,0,0,2);
ans=max(ans,VT::solve()+e[rt].w);
int sz=siz[e[rt].v],cur=rt;
mn=INF,findrt(e[cur].v,0,sz),solve(sz);
mn=INF,findrt(e[cur].u,0,sum-sz),solve(sum-sz);
}
int main()
{
tot=n=read();
for (int i=1;i<n;i++)
{
int u,v,w;
u=read(),v=read(),w=read();
E[u].push_back((edge){u,v,w}),E[v].push_back((edge){v,u,w});
}
VT::input();
dfs(1);
dfs(1,0);
memset(vis,0,sizeof(vis));
mn=INF,findrt(1,0,tot),solve(tot);
for (int i=1;i<=n;i++) ans=max(ans,2*(dis[i]-VT::dis[i]));
cerr<<ans<<'\n';
cout<<(ans>>1);
return 0;
}