简化后的题目描述
给出两棵树
T
1
,
T
2
T1,T2
T1,T2
定义
f
(
T
,
x
,
y
)
f(T,x,y)
f(T,x,y) 定义为
T
T
T 中点
x
x
x 和点
y
y
y 路径上各边权的最大值
求
∑
i
=
1
N
−
1
∑
j
=
i
+
1
N
f
(
T
1
,
i
,
j
)
∗
f
(
T
2
,
i
,
j
)
\sum_{i=1}^{N-1}\sum_{j=i+1}^{N}f(T1,i,j)*f(T2,i,j)
∑i=1N−1∑j=i+1Nf(T1,i,j)∗f(T2,i,j)
题解
考虑从小到大枚举
T
1
T1
T1中的边,然后就是计算,两个点集的在
T
2
T2
T2上的
f
f
f和
对于
T
2
T2
T2建出其对应的
k
r
u
s
k
a
l
kruskal
kruskal重构树(设
v
a
l
(
i
)
val(i)
val(i)为
k
r
u
s
k
a
l
kruskal
kruskal重构树中点
i
i
i的点权),
那么
f
(
T
2
,
i
,
j
)
=
v
a
l
(
l
c
a
(
i
,
j
)
)
f(T2,i,j)=val(lca(i,j))
f(T2,i,j)=val(lca(i,j))
然后将
i
i
i与
f
a
t
h
e
r
i
father_i
fatheri的边的边权设为
v
a
l
(
i
)
−
v
a
l
(
f
a
t
h
e
r
i
)
val(i)-val(father_i)
val(i)−val(fatheri)
根据
d
i
s
(
a
,
b
)
=
d
e
p
(
a
)
+
d
e
p
(
b
)
−
2
d
e
p
(
l
c
a
(
a
,
b
)
)
dis(a,b)=dep(a)+dep(b)-2dep(lca(a,b))
dis(a,b)=dep(a)+dep(b)−2dep(lca(a,b))可以得到
d
e
p
(
l
c
a
(
a
,
b
)
)
=
1
2
(
d
i
s
(
a
,
b
)
−
d
e
p
(
a
)
−
d
e
p
(
b
)
)
dep(lca(a,b))=\frac{1}{2}(dis(a,b)-dep(a)-dep(b))
dep(lca(a,b))=21(dis(a,b)−dep(a)−dep(b))
那么我们就将一条路径的
m
a
x
max
max转化为一条路径的和
如果考虑用动态点分治维护,启发式合并就可以做到
O
(
n
l
o
g
2
n
)
O(nlog^2n)
O(nlog2n)
继续优化
容易发现点分树是一种很好的分治结构,如果我们知道了两个点集所形成的两颗点分树,那么是否可以通过类似线段树合并的方法合并
然后就做到了
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
另外一种方法
对于
T
1
T1
T1我们做点分治,考虑算两个点集之间的答案。
然后只用建出两个点集的虚树,在虚树上
d
p
dp
dp。
当建虚树做到
O
(
n
)
O(n)
O(n)时,时间复杂度为
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
可以先做下去,然后回溯回来时,使用归并排序,这样排序就做到了
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
不过这样的方法常数会有点大。
另外第二种方法
可以用全局平衡二叉树来代替点分树求两点之间距离,并且全局平衡二叉树也是一种很好的分治结构,也能够类似线段树合并的方法合并
时间复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
code
#include<bits/stdc++.h>
#define fo(i,a,b)for(int i=a,_e=b;i<=_e;++i)
#define fd(i,a,b)for(int i=b,_e=a;i>=_e;--i)
#define link(x,y,w)(e[x].push_back((edge){y,w}),e[y].push_back((edge){x,w}))
#define ll long long
using namespace std;
const int N=2e5+5,mo=998244353;
int n,m,x,y,w,z,all;
int si[N],sz[N],D,de[N],d[20][N],zfa[N];
int f[N],fa[N],nn,v[N],th[N],s[N];
int rt[N],ts,ans,su;
bool bz[N];
struct nod{int x,y,w;}a[N],b[N];
struct edge{int y,w;};
struct tree{int su,ct,s[3];}t[N*19];
vector<edge>e[N];
bool cmp(nod a,nod b){return a.w<b.w;}
int F(int x){return x==f[x]?x:f[x]=F(f[x]);}
void get(int x){
bz[x]=1;si[x]=1;sz[x]=0;
for(auto i:e[x])if(!bz[i.y])
get(i.y),si[x]+=si[i.y],sz[x]=max(sz[x],si[i.y]);
sz[x]=max(sz[x],all-si[x]);bz[x]=0;
if(sz[x]<sz[z])z=x;
}
void dfs(int x,int dp){
bz[x]=1;d[D][x]+=dp;d[D+1][x]-=dp;++all;
for(auto i:e[x])if(!bz[i.y])
dfs(i.y,dp+i.w);
bz[x]=0;
}
void fz(int x){
bz[x]=1;de[x]=D;
for(auto i:e[x])if(!bz[i.y]){
all=0;dfs(i.y,i.w);si[i.y]=all;
}
++D;
for(auto i:e[x])if(!bz[i.y]){
all=si[i.y];z=0;get(i.y);
zfa[z]=x;th[z]=s[x]++;fz(z);
}
--D;
}
void mer(int &v,int v2){
if(!v2||!v){v|=v2;return;}
su=((ll)t[v].ct*t[v2].su+(ll)t[v].su*t[v2].ct+su)%mo;
t[v].su=(t[v].su+t[v2].su)%mo;
t[v].ct=(t[v].ct+t[v2].ct)%mo;
fo(i,0,2)mer(t[v].s[i],t[v2].s[i]);
}
int main(){
scanf("%d%d",&n,&m);
fo(i,1,m)scanf("%d%d%d",&x,&y,&w),a[i]=(nod){x,y,w};
fo(i,1,m)scanf("%d%d%d",&x,&y,&w),b[i]=(nod){x,y,w};
sort(a+1,a+m+1,cmp);
sort(b+1,b+m+1,cmp);
fo(i,1,n*2)f[i]=i;nn=n;
fo(i,1,m){
x=F(b[i].x);y=F(b[i].y);
if(x!=y){
v[++nn]=b[i].w;
f[x]=f[y]=fa[x]=fa[y]=nn;
}
}
fo(i,1,nn)v[i]-=v[fa[i]];
fa[nn]=++nn;
fo(i,1,nn-1)link(i,fa[i],v[i]);
sz[0]=nn+1;all=nn;get(1);fz(z);
fo(i,1,n){
f[i]=i;
++ts;t[ts].ct=1;t[ts].su=d[de[x]][i];
for(x=i,D=de[x];D;x=zfa[x],--D)
++ts,t[ts].su=d[D-1][i],t[ts].ct=1,t[ts].s[th[x]]=ts-1;
rt[i]=ts;
}
fo(i,1,m){
x=F(a[i].x);y=F(a[i].y);
if(x!=y){
f[y]=x;
su=0;
mer(rt[x],rt[y]);
ans=((ll)a[i].w*su+ans)%mo;
}
}
cout<<(ll)(mo-ans)*(mo+1>>1)%mo;
}