题意
给出两棵树
T
1
,
T
2
T_1,T_2
T1,T2,定义新图中两个点的距离为
d
i
s
1
(
i
,
j
)
+
d
i
s
2
(
i
,
j
)
dis_1(i,j)+dis_2(i,j)
dis1(i,j)+dis2(i,j),其中
d
i
s
k
(
i
,
j
)
dis_k(i,j)
disk(i,j)表示在
T
k
T_k
Tk中
i
i
i到
j
j
j的距离。求新图的最小生成树。
n
≤
100000
n\le 100000
n≤100000
分析
先对第二棵树进行点分治,对每个分治中心考虑跨过该点的路径中哪些是有用的。设 d e p k dep_k depk表示当前分治子树中 k k k到分治中心的距离,并把分治子树中的点在第一棵树中的虚树建出来,同时对每个点 x x x新建一个点 x ′ x' x′,记新建的点为源点, m a [ x ′ ] = x ma[x']=x ma[x′]=x,并在两点间连边权为 d e p x dep_x depx的边,那么两点在新图中的边的边权就是 d e p x + d e p y + d i s 1 ( x , y ) dep_x+dep_y+dis_1(x,y) depx+depy+dis1(x,y)。
然后对虚树做两次树型dp,求出 p r e x pre_x prex表示距离 x x x最近的原点是谁, f x f_x fx表示距离。对于虚树上的每条边 ( x , y , w ) (x,y,w) (x,y,w),把 ( m a [ p r e [ x ] ] , m a [ p r e [ y ] ] , w + f [ x ] + f [ y ] ) (ma[pre[x]],ma[pre[y]],w+f[x]+f[y]) (ma[pre[x]],ma[pre[y]],w+f[x]+f[y])加入到有用的边的集合当中。最后对有用的边的集合作做小生成树即可。
下面证明没有被加入的边一定不会出现在最小生成树中。对于两个源点 u , v u,v u,v之间的路径,必然有连续的一段满足 p r e x = u pre_x=u prex=u和另一段满足 p r e x = v pre_x=v prex=v。若路径上的点 x x x均有 p r e x = u pre_x=u prex=u或 p r e x = v pre_x=v prex=v则这两点在新图中的路径必然会被加入。否则必存在一点 k k k使得 p r e k ≠ u , v pre_k\not=u,v prek=u,v,则 u u u和 v v v到 p r e k pre_k prek的路径均比 u u u到 v v v之间的路径要短,故结论成立。
时间复杂度 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
代码
#include<bits/stdc++.h>
#define pb push_back
typedef long long LL;
const int N=100005;
const int inf=2e8+5;
int n,cnt0,f[N];
struct data{int x,y,w;}edg[N*20];
class Tree
{
public:
int cnt,last[N];
struct edge{int to,next,w;}e[N*6];
void addedge(int u,int v,int w)
{
e[++cnt].to=v;e[cnt].w=w;e[cnt].next=last[u];last[u]=cnt;
e[++cnt].to=u;e[cnt].w=w;e[cnt].next=last[v];last[v]=cnt;
}
void init()
{
for (int i=1;i<n;i++)
{
int x,y,z;scanf("%d%d%d",&x,&y,&z);
addedge(x,y,z);
}
}
};
bool cmp(int x,int y);
class Tree1:public Tree
{
public:
int tim,dfn[N],dep[N],ls[N*2],dis[N],fa[N][18],stack[N],f[N*2],pre[N*2];
int get_lca(int x,int y)
{
if (dep[x]<dep[y]) std::swap(x,y);
for (int i=16;i>=0;i--)
if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
if (x==y) return x;
for (int i=16;i>=0;i--)
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void dfs(int x)
{
dfn[x]=++tim;
dep[x]=dep[fa[x][0]]+1;
for (int i=1;i<=16;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa[x][0]) dis[e[i].to]=dis[x]+e[i].w,fa[e[i].to][0]=x,dfs(e[i].to);
}
void add(int u,int v,int w)
{
e[++cnt].to=v;e[cnt].w=w;e[cnt].next=ls[u];ls[u]=cnt;
e[++cnt].to=u;e[cnt].w=w;e[cnt].next=ls[v];ls[v]=cnt;
}
void dp1(int x,int fa)
{
if (x>n) pre[x]=x,f[x]=0;
else f[x]=inf;
for (int i=ls[x];i;i=e[i].next) if (e[i].to!=fa)
{
dp1(e[i].to,x);
if (f[e[i].to]+e[i].w<f[x]) f[x]=f[e[i].to]+e[i].w,pre[x]=pre[e[i].to];
}
}
void dp2(int x,int fa)
{
for (int i=ls[x];i;i=e[i].next) if (e[i].to!=fa)
{
if (f[x]+e[i].w<f[e[i].to]) f[e[i].to]=f[x]+e[i].w,pre[e[i].to]=pre[x];
dp2(e[i].to,x);
}
}
void build(std::vector<int> pts,std::vector<int> dep2)
{
int tot=pts.size(),tmp=cnt;
std::vector<int> vec;
for (int i=0;i<tot;i++) add(pts[i],pts[i]+n,dep2[i]),vec.pb(pts[i]+n);
std::sort(pts.begin(),pts.end(),cmp);
int top=0;stack[++top]=1;vec.pb(1);
for (int i=0;i<tot;i++)
{
int x=pts[i];
if (x==stack[top]) continue;
int lca=get_lca(x,stack[top]);
if (lca==stack[top]) {stack[++top]=x;vec.pb(x);continue;}
while (dep[stack[top-1]]>=dep[lca])
add(stack[top],stack[top-1],dis[stack[top]]-dis[stack[top-1]]),top--;
if (dep[stack[top]]>dep[lca])
add(stack[top],lca,dis[stack[top]]-dis[lca]),top--;
if (stack[top]!=lca) stack[++top]=lca,vec.pb(lca);
stack[++top]=x;vec.pb(x);
}
while (top>1)
add(stack[top],stack[top-1],dis[stack[top]]-dis[stack[top-1]]),top--;
dp1(1,0);
dp2(1,0);
for (int i=tmp+1;i<=cnt;i+=2)
{
int u=e[i].to,v=e[i+1].to,w=e[i].w;
if (pre[u]!=pre[v]) edg[++cnt0]=(data){pre[u]-n,pre[v]-n,w+f[u]+f[v]};
}
cnt=tmp;
for (int x:vec) ls[x]=0;
}
}t1;
bool cmp(int x,int y)
{
return t1.dfn[x]<t1.dfn[y];
}
class Tree2:public Tree
{
public:
int size[N],mx[N],sum,rt;
bool vis[N];
std::vector<int> pts,dis;
void get_root(int x,int fa)
{
size[x]=1;mx[x]=0;
for (int i=last[x];i;i=e[i].next) if (e[i].to!=fa&&!vis[e[i].to])
{
get_root(e[i].to,x);
size[x]+=size[e[i].to];
mx[x]=std::max(mx[x],size[e[i].to]);
}
mx[x]=std::max(mx[x],sum-size[x]);
if (!rt||mx[x]<mx[rt]) rt=x;
}
void get(int x,int fa,int d)
{
pts.pb(x);dis.pb(d);
for (int i=last[x];i;i=e[i].next)
if (!vis[e[i].to]&&e[i].to!=fa) get(e[i].to,x,d+e[i].w);
}
void work(int x)
{
vis[x]=1;
pts.clear();dis.clear();
pts.pb(x);dis.pb(0);
for (int i=last[x];i;i=e[i].next)
if (!vis[e[i].to]) get(e[i].to,x,e[i].w);
t1.build(pts,dis);
for (int i=last[x];i;i=e[i].next) if (!vis[e[i].to])
{
rt=0;sum=size[e[i].to];get_root(e[i].to,x);
work(rt);
}
}
void solve()
{
rt=0;sum=n;get_root(1,0);
work(rt);
}
}t2;
int find(int x)
{
return f[x]==x?x:f[x]=find(f[x]);
}
bool cmp1(data x,data y)
{
return x.w<y.w;
}
LL kruskal()
{
std::sort(edg+1,edg+cnt0+1,cmp1);
for (int i=1;i<=n;i++) f[i]=i;
LL ans=0;
for (int i=1;i<=cnt0;i++)
{
int x=find(edg[i].x),y=find(edg[i].y);
if (x!=y) f[x]=y,ans+=edg[i].w;
}
return ans;
}
int main()
{
scanf("%d",&n);
t1.init();
t2.init();
t1.dfs(1);
t2.solve();
printf("%lld\n",kruskal());
return 0;
}