锟题x2
以下用$a\rightarrow b$表示端点为$a,b$的链
把式子写成$(h_1(x)+h_1(y)-h_1(lca))-h_2(lca')$,第一部分就是$x\rightarrow rt$和$y\rightarrow rt$的并的总长
考虑对第一棵树边分治,假设分治到$(u,v)$,我们想要统计所有跨过$(u,v)$的$x\rightarrow y$
设在树$1$上$fa_v=u$,对于$u$这边的点$x$,令$f_x=-\infty,g_x=dis(x,u\rightarrow rt)$,对$v$这边的点$y$,令$f_y=h_1(y),g_y=-\infty$,那么$h_1(x)+h_1(y)-h_1(lca)=g_x+f_y$(将另外的$f,g$设为$-\infty$是为了防止统计到不跨过$(u,v)$的情况)
所以我们可以将当前分治范围内的点拿出来在树$2$上建虚树,在虚树上统计答案即可
最后不要忘了统计$x=y$的答案...
#include<stdio.h>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
const int inf=2147483647;
const ll linf=922337203685477580ll;
int n;
struct pr{
int to,v;
pr(int a=0,int b=0){to=a;v=b;}
};
struct tree1{
int h[733340],nex[1466670],to[1466670],v[1466670],M;
vector<pr>g[366670];
void ins(int a,int b,int c){
M++;
to[M]=b;
v[M]=c;
nex[M]=h[a];
h[a]=M;
}
void add(int a,int b,int c){
ins(a,b,c);
ins(b,a,c);
}
int N;
void dfs(int fa,int x){
vector<pr>::iterator it;
int p=0;
for(it=g[x].begin();it!=g[x].end();it++){
if(it->to!=fa){
if(p){
N++;
add(p,N,0);
add(N,it->to,it->v);
p=N;
}else{
add(x,it->to,it->v);
p=x;
}
}
}
for(it=g[x].begin();it!=g[x].end();it++){
if(it->to!=fa)dfs(x,it->to);
}
}
ll dis[733340];
int fa[733340],dep[733340];
void dfs(int x){
dep[x]=dep[fa[x]]+1;
for(int i=h[x];i;i=nex[i]){
if(to[i]!=fa[x]){
fa[to[i]]=x;
dis[to[i]]=dis[x]+v[i];
dfs(to[i]);
}
}
}
void gao(){
int i,x,y,z;
for(i=1;i<n;i++){
scanf("%d%d%d",&x,&y,&z);
g[x].push_back(pr(y,z));
g[y].push_back(pr(x,z));
}
M=1;
N=n;
dfs(0,1);
dfs(1);
}
}t1;
struct tree2{
int h[366670],nex[733340],to[733340],v[733340],M;
void add(int a,int b,int c){
M++;
to[M]=b;
v[M]=c;
nex[M]=h[a];
h[a]=M;
}
int dfn[366670],mn[733340][20],dep[366670],lg[733340];
ll dis[366670];
void dfs(int fa,int x){
dfn[x]=++M;
mn[M][0]=x;
dep[x]=dep[fa]+1;
for(int i=h[x];i;i=nex[i]){
if(to[i]!=fa){
dis[to[i]]=dis[x]+v[i];
dfs(x,to[i]);
mn[++M][0]=x;
}
}
}
int qmin(int x,int y){return dep[x]<dep[y]?x:y;}
int query(int l,int r){
int k=lg[r-l+1];
return qmin(mn[l][k],mn[r-(1<<k)+1][k]);
}
int lca(int x,int y){
if(dfn[x]>dfn[y])swap(x,y);
return query(dfn[x],dfn[y]);
}
void gao(){
int i,j,x,y,z;
for(i=1;i<n;i++){
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
M=0;
dfs(0,1);
for(j=1;j<20;j++){
for(i=1;i+(1<<j)-1<=M;i++)mn[i][j]=qmin(mn[i][j-1],mn[i+(1<<(j-1))][j-1]);
}
for(i=2;i<=M;i++)lg[i]=lg[i>>1]+1;
}
}t2;
bool cmp(int x,int y){return t2.dfn[x]<t2.dfn[y];}
bool vis[1466670];
int siz[733340],p[366670],M;
void dfs1(int fa,int x){
if(x<=n)p[++M]=x;
siz[x]=1;
for(int i=t1.h[x];i;i=t1.nex[i]){
if(!vis[i]&&t1.to[i]!=fa){
dfs1(x,t1.to[i]);
siz[x]+=siz[t1.to[i]];
}
}
}
int al,mn,cn;
void dfs2(int fa,int x){
for(int i=t1.h[x];i;i=t1.nex[i]){
if(!vis[i]&&t1.to[i]!=fa){
dfs2(x,t1.to[i]);
if(abs(al-2*siz[t1.to[i]])<mn){
mn=abs(al-2*siz[t1.to[i]]);
cn=i;
}
}
}
}
ll f[733340],g[733340];
void dfs3(int fa,int x,ll d){
g[x]=d;
f[x]=-linf;
for(int i=t1.h[x];i;i=t1.nex[i]){
if(!vis[i]&&t1.to[i]!=fa)dfs3(x,t1.to[i],t1.to[i]==t1.fa[x]?0:d+t1.v[i]);
}
}
void dfs4(int fa,int x){
f[x]=t1.dis[x];
g[x]=-linf;
for(int i=t1.h[x];i;i=t1.nex[i]){
if(!vis[i]&&t1.to[i]!=fa)dfs4(x,t1.to[i]);
}
}
ll ans;
struct vtree{
int h[366670],nex[366670],to[366670],M;
void add(int a,int b){
M++;
to[M]=b;
nex[M]=h[a];
h[a]=M;
}
void dfs(int x){
for(int i=h[x];i;i=nex[i]){
dfs(to[i]);
ans=max(ans,max(f[x]+g[to[i]],g[x]+f[to[i]])-t2.dis[x]);
f[x]=max(f[x],f[to[i]]);
g[x]=max(g[x],g[to[i]]);
}
h[x]=0;
}
void clear(int x){
f[x]=g[x]=-linf;
for(int i=h[x];i;i=nex[i])clear(to[i]);
}
}vt;
int st[366670],tp;
void insert(int x){
if(!tp){
st[++tp]=x;
return;
}
int l=t2.lca(x,st[tp]);
while(tp>1&&t2.dep[st[tp-1]]>t2.dep[l]){
vt.add(st[tp-1],st[tp]);
tp--;
}
if(t2.dep[st[tp]]>t2.dep[l]){
vt.add(l,st[tp]);
tp--;
}
if(t2.dep[st[tp]]<t2.dep[l])st[++tp]=l;
st[++tp]=x;
}
void build(){
int i;
sort(p+1,p+M+1,cmp);
tp=0;
vt.M=0;
for(i=1;i<=M;i++)insert(p[i]);
for(i=1;i<tp;i++)vt.add(st[i],st[i+1]);
}
void solve(int x){
int y;
M=0;
dfs1(0,x);
al=siz[x];
mn=inf;
cn=0;
dfs2(0,x);
if(cn==0)return;
vis[cn]=vis[cn^1]=1;
x=t1.to[cn];
y=t1.to[cn^1];
if(t1.dep[x]>t1.dep[y])swap(x,y);
build();
vt.clear(st[1]);
dfs3(0,x,0);
dfs4(0,y);
vt.dfs(st[1]);
solve(x);
solve(y);
}
int main(){
scanf("%d",&n);
t1.gao();
t2.gao();
ans=-linf;
solve(1);
for(int i=1;i<=n;i++)ans=max(ans,t1.dis[i]-t2.dis[i]);
printf("%lld",ans);
}