传送门
题解:
首先你还是要尝试各种做法,包括点分治,链剖等等。
反正考虑点对统计的做法多得一批,也有很多能够做到 O ( n ⋅ p o l y ( log n ) ) O(n\cdot poly(\log n)) O(n⋅poly(logn)),至于是 log \log log几次方就看具体做法了。
但是换个思路可以做到一个 log \log log。
对于一个点,考虑统计有多少个点能够到达它,答案实际上是所有覆盖它的链的并形成的连通子树的大小。
其实换个说法,就是所有覆盖了它的链的端点,形成的点集的虚树上的总边长(注意这里的虚树是严格虚树,只保留所有关键点和关键点之间的LCA,根节点是可以丢掉的)。
还是考虑求加上 1 1 1号点的虚树,最后把所有点的LCA的到 1 1 1的距离减掉就行了。
由于维护的边长之和,可以直接考虑DFS上LCA差分。
然后利用DFS序合并区间虚树的线段树做法就可以维护了,不过这道题线段树需要动态开点同时支持一下线段树合并。
这里简单讲一下合并的同时处理差分,将点按照DFS序作为下标扔到线段树里面,那么合并两个区间只需要考虑左区间的最后一个点和右区间的第一个点的LCA的影响就可以了。
然后把求LCA搞成 O ( 1 ) O(1) O(1)就能做到 O ( n log n ) O(n\log n) O(nlogn)了。
代码:
#include<bits/stdc++.h>
#define ll long long
#define re register
#define cs const
namespace IO{
inline char gc(){
static cs int Rlen=1<<22|1;
static char buf[Rlen],*p1,*p2;
return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;
}
template<typename T>
inline T get(){
char c;T num;
while(!isdigit(c=gc()));num=c^48;
while(isdigit(c=gc()))num=(num+(num<<2)<<1)+(c^48);
return num;
}
inline int gi(){return get<int>();}
}
using namespace IO;
using std::cerr;
using std::cout;
cs int N=1e5+7;
int n,m;
std::vector<int> G[N];
int fa[N],d[N],in[N],dfn;
namespace ST{
int st[20][N<<1],ps[N],dfc,Log[N<<1];
void dfs(int u,int p){
fa[u]=p,d[u]=d[p]+1;in[u]=++dfn,st[0][ps[u]=++dfc]=u;
for(int re v:G[u])if(v!=p)dfs(v,u),st[0][++dfc]=u;
}
inline void init(){
dfs(1,0);for(int re i=2;i<=dfc;++i)Log[i]=Log[i>>1]+1;
for(int re i=1;i<=Log[dfc];++i)
for(int re j=1;j+(1<<i)-1<=dfc;++j)
st[i][j]=d[st[i-1][j]]<d[st[i-1][j+(1<<i-1)]]?st[i-1][j]:st[i-1][j+(1<<i-1)];
}
inline int LCA(int u,int v){
if(!u||!v)return 0;int l=ps[u],r=ps[v];
if(l>r)std::swap(l,r);int t=Log[r-l+1];
return d[st[t][l]]<d[st[t][r-(1<<t)+1]]?st[t][l]:st[t][r-(1<<t)+1];
}
}
using ST::LCA;
int rt[N];
namespace SGT{
cs int N=::N*80;
int lc[N],rc[N],ct[N],s[N],L[N],R[N],tot;
inline void pushup(int u){
s[u]=s[lc[u]]+s[rc[u]]-d[LCA(R[lc[u]],L[rc[u]])];
L[u]=L[L[lc[u]]?lc[u]:rc[u]];
R[u]=R[R[rc[u]]?rc[u]:lc[u]];
}
inline void ins(int &u,int l,int r,int p,int v){
if(!u)u=++tot;if(l==r){ct[u]+=v,s[u]=ct[u]?d[p]:0,L[u]=R[u]=ct[u]?p:0;return ;}
int mid=l+r>>1;(in[p]<=mid)?ins(lc[u],l,mid,p,v):ins(rc[u],mid+1,r,p,v);pushup(u);
}
inline void merge(int &u,int v,int l=1,int r=n){
if(!u||!v){u|=v;return ;}int mid=l+r>>1;
if(l==r){if(!ct[u]&&ct[v])u=v;else ct[u]+=ct[v];return ;}
merge(lc[u],lc[v],l,mid);
merge(rc[u],rc[v],mid+1,r);pushup(u);
}
inline int query(int u){return s[u]-d[LCA(L[u],R[u])];}
inline void dfs(int u){
if(!u)return ;
if(ct[u]){cout<<L[u]<<" ";return ;}
dfs(lc[u]);dfs(rc[u]);
}
}
ll ans;
std::vector<int> del[N];
void dfs(int u,int p){
for(int re v:G[u])if(v!=p)
dfs(v,u),SGT::merge(rt[u],rt[v]);
for(int re v:del[u])
SGT::ins(rt[u],1,n,v,-1);
ans+=SGT::query(rt[u]);
}
signed main(){
#ifdef zxyoi
freopen("lang.in","r",stdin);
#endif
n=gi(),m=gi();
for(int re i=1;i<n;++i){
int u=gi(),v=gi();
G[u].push_back(v);
G[v].push_back(u);
}ST::init();assert(dfn==n);
for(int re i=1;i<=m;++i){
int u=gi(),v=gi(),p=LCA(u,v);
SGT::ins(rt[u],1,n,u,1);SGT::ins(rt[u],1,n,v,1);
SGT::ins(rt[v],1,n,u,1);SGT::ins(rt[v],1,n,v,1);
del[p].push_back(u),del[p].push_back(v);
if(fa[p])del[fa[p]].push_back(u),del[fa[p]].push_back(v);
}
dfs(1,0);cout<<ans/2<<"\n";
return 0;
}