题意
有一棵无根树,树上每条边都有一个时间区间[li,ri],表示这条边在这个时间段内连通.
问树上一共有多少条路径,至少在某一个时刻连通.路径上至少要有一条边
数据范围
1 ≤ n ≤ 2 e 5 , 1 ≤ l i , r i ≤ 1 e 9 1\le n \le 2e5,1\le li,ri \le 1e9 1≤n≤2e5,1≤li,ri≤1e9
解法
可以使用LCT但是这里不谈
考虑点分,然后就是对于每个分治中心,其不同的路径的合并问题.可以转化为对于一些线段,查询一条线段与其中几个有交.这个可以用全集减去不相交的部分,就是右端点比查询线段的左端点小的或者左端点比查询线段右端点大的.这个可以随便用什么数据结构维护(我用的是数组).
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=2e5+5;
inline int read(){
char c=getchar();int t=0,f=1;
while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
return t*f;
}
int n;
struct edge{
int v,p,l,r;
}e[maxn<<1];
int h[maxn],cnt;
inline void add(int a,int b,int l,int r){
e[++cnt].p=h[a];
e[cnt].v=b;
e[cnt].l=l;e[cnt].r=r;
h[a]=cnt;
e[++cnt].p=h[b];
e[cnt].v=a;
e[cnt].l=l;e[cnt].r=r;
h[b]=cnt;
}
int sz[maxn],rt,mx[maxn],tot,vis[maxn];
void dfs1(int u,int fa){
sz[u]=1;mx[u]=0;
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa||vis[v])continue;
dfs1(v,u);
sz[u]+=sz[v];
mx[u]=max(mx[u],sz[v]);
}
mx[u]=max(mx[u],tot-sz[u]);
if(rt==0||mx[rt]>mx[u])rt=u;
}
struct node{
int l,r;
}dis[maxn],alfa[maxn],beta[maxn],a2[maxn];
bool cmp1(node a,node b){
return a.l<b.l;
}
bool cmp2(node a,node b){
return a.r<b.r;
}
int tim;
void getdis(int u,int fa,int l,int r,int f){
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa||vis[v])continue;
int l1=max(l,e[i].l),r1=min(r,e[i].r);
if(l1<=r1){
//mp[f][v].l=l1;mp[f][v].r=r1;
dis[++tim].l=l1;dis[tim].r=r1;
getdis(v,u,l1,r1,f);
}
}
}
int ans,lst;
inline int finda(int x){
int l=1,r=lst,ans=0;
while(l<=r){
int mid=l+r>>1;
if(alfa[mid].l>x){ans=lst-mid+1;r=mid-1;}
else l=mid+1;
}
return ans;
}
inline int findb(int x){
int l=1,r=lst,ans=0;
while(l<=r){
int mid=l+r>>1;
if(beta[mid].r<x){ans=mid;l=mid+1;}
else r=mid-1;
}
return ans;
}
int ans2;
void calc(int u){
lst=0;//printf("%d\n",u);
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;tim=0;
if(vis[v])continue;
//mp[u][v].l=e[i].l;mp[u][v].r=e[i].r;
dis[++tim].l=e[i].l;dis[tim].r=e[i].r;
getdis(v,u,e[i].l,e[i].r,u);
//printf("%d\n",tim);
ans+=tim;
for(int j=1;j<=tim;j++){
int tmpl=findb(dis[j].l);
int tmpr=finda(dis[j].r);
ans2=ans2+lst-tmpl-tmpr;
//? printf("%lld %lld %lld %lld %lld\n",dis[j].l,dis[j].r,lst,tmpl,tmpr);
}
for(int j=lst+1;j<=lst+tim;j++){alfa[j]=dis[j-lst];beta[j]=dis[j-lst];}
sort(alfa+lst+1,alfa+1+lst+tim,cmp1);
sort(beta+lst+1,beta+1+lst+tim,cmp2);
int p1=1,p2=lst+1,tot=0;
while(p1<=lst&&p2<=lst+tim){
if(alfa[p1].l<alfa[p2].l){tot++;a2[tot]=alfa[p1];p1++;}
else{tot++;a2[tot]=alfa[p2];p2++;}
}
while(p1<=lst){tot++;a2[tot]=alfa[p1];p1++;}
while(p2<=lst+tim){tot++;a2[tot]=alfa[p2];p2++;}
for(int i=1;i<=tot;i++)alfa[i]=a2[i];
p1=1,p2=lst+1,tot=0;
while(p1<=lst&&p2<=lst+tim){
if(beta[p1].r<beta[p2].r){tot++;a2[tot]=beta[p1];p1++;}
else{tot++;a2[tot]=beta[p2];p2++;}
}
while(p1<=lst){tot++;a2[tot]=beta[p1];p1++;}
while(p2<=lst+tim){tot++;a2[tot]=beta[p2];p2++;}
for(int i=1;i<=tot;i++)beta[i]=a2[i];
lst=lst+tim;
/*sort(alfa+1,alfa+1+lst,cmp1);暴力排序的复杂度是假的,需要归并
sort(beta+1,beta+1+lst,cmp2);*/
}
}
void solve(int u){
calc(u);vis[u]=1;
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(vis[v])continue;
tot=sz[v];
rt=0;
dfs1(v,0);
solve(rt);
}
}
signed main(){
//freopen("3.in","r",stdin);
//freopen("3b.out","w",stdout);
n=read();
for(int i=1;i<n;i++){
int a=read(),b=read(),l=read(),r=read();
add(a,b,l,r);
}tot=n;
dfs1(1,0);
solve(rt);
printf("%lld\n",ans+ans2);
return 0;
}