题目大意:给一棵边带权的树,求有多少条路径满足路径上边权的最大值减去最小值不超过k,n≤1e5。
题解:点分治,用BIT(树状数组)做二维数点,统计经过分治中心的合法路径;或者把边按权值排序,用双指针配合LCT维护连通块大小即可。
点分治
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<utility>
#define lint long long
#define gc getchar()
#define mp make_pair
#define fir first
#define sec second
#define pb push_back
#define lb(x) (x&-x)
#define N 100010
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
// Reads the next run of decimal digits from stdin and returns it as an int.
// Every non-digit character before the run is skipped; the single character
// that terminates the run is consumed. No sign handling is performed.
// Returns 0 if end of input is reached before any digit is seen (previously
// the skip loop never terminated, because EOF compares below '0').
inline int inn()
{
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')) ch=getchar();
    if(ch==EOF) return 0;
    int x=ch-'0';
    for(ch=getchar();ch>='0'&&ch<='9';ch=getchar())
        x=x*10+(ch-'0');
    return x;
}
// One directed edge stored in a linked adjacency list.
struct edges{
    int to;   // endpoint the edge leads to
    int pre;  // index of the previous edge in the same list (0 ends the list)
    int wgt;  // edge weight
};
edges e[N<<1];        // edge pool; indices start at 1 (see add_edge)
int k;                // allowed difference between max and min edge weight
int h[N];             // h[u]: index of the most recently added edge leaving u
int etop;             // number of edges stored in e so far
int sz[N];            // subtree sizes written by getsz
int vis[N];           // vis[x] != 0 once x has been used as a decomposition root
int lst[N];           // nodes recorded by getsz, in positions 1..lt
int lt;               // number of entries currently in lst
lint ans;             // running answer
vector<pii> v[N];     // per-node lists of (min weight, max weight) pairs
vector<pii> tmp;      // scratch list used inside solve
vector<int> s;        // sorted distinct values used by getid for compression
// Stores a directed edge u -> v with weight w at the head of u's adjacency
// list and returns the index assigned to it in e.
inline int add_edge(int u,int v,int w)
{
    ++etop;
    e[etop].to=v;
    e[etop].wgt=w;
    e[etop].pre=h[u];
    h[u]=etop;
    return etop;
}
inline int getid(int x) { return lower_bound(s.begin(),s.end(),x)-s.begin()+1; }
// Computes sz[x] for the part of the tree reachable from x without passing
// through `fa` or through any node already marked in vis, and appends every
// node it finishes to lst (children are appended before their parent).
// Returns sz[x].
int getsz(int x,int fa=0)
{
    sz[x]=1;
    for(int idx=h[x];idx!=0;idx=e[idx].pre)
    {
        const int child=e[idx].to;
        if(child==fa||vis[child]) continue;
        sz[x]+=getsz(child,x);
    }
    lst[++lt]=x;
    return sz[x];
}
inline int getrt(int x)
{
for(int i=1,fsz=sz[x],t=fsz;i<=lt;i++)
{
int y=lst[i],ysz=fsz-sz[y];
for(int j=h[y],z;j;j=e[j].pre)
if(!vis[z=e[j].to]&&sz[e[j].to]<sz[y]) ysz=max(ysz,sz[z]);
if(ysz<t) t=ysz,x=y;
}
return lt=0,x;
}
int getmsg(int x,int fa,int z,int mn,int mx)
{
if(mx-mn>k) return 0;v[z].pb(mp(mn,mx)),ans++;
for(int i=h[x],y;i;i=e[i].pre)
if(!vis[y=e[i].to]&&e[i].to!=fa)
getmsg(y,x,z,min(mn,e[i].wgt),max(mx,e[i].wgt));
return 0;
}
// Fenwick tree (binary indexed tree) over positions 1..n with an undo journal.
// Every update is recorded so that clear() can revert exactly the touched
// cells instead of wiping the whole array.
//
// Changes from the previous version:
//  * the constructor now initialises c; before, it was only correct because
//    the sole instance `b` has static storage and is zero-initialised;
//  * clear() reverts the cells directly instead of calling update(), which
//    re-journaled every undo into slot c+1 and then decremented c again;
//  * update() ignores positions <= 0, for which the index walk cannot make
//    progress (x += lb(x) never advances from 0).
//
// The journal holds at most N*3-1 entries between two clear() calls; that
// capacity is not checked.
struct BIT{
    int n;        // logical size set by init()
    int c;        // number of journal entries currently stored
    int p[N*3];   // journal: position of each update
    int y[N*3];   // journal: delta of each update
    int v[N*3];   // Fenwick array, 1-based
    BIT() { memset(v,0,sizeof v);n=0;c=0; }
    // Sets the logical size and returns it. Does not touch stored values.
    inline int init(int _n) { return n=_n; }
    // Adds val at position x and journals the change. Returns 0.
    inline int update(int x,int val)
    {
        if(x<=0) return 0;
        ++c;
        p[c]=x;
        y[c]=val;
        for(;x<=n;x+=lb(x)) v[x]+=val;
        return 0;
    }
    // Prefix sum over positions 1..x.
    inline int query(int x)
    {
        int total=0;
        for(;x>0;x-=lb(x)) total+=v[x];
        return total;
    }
    // Sum over positions x..y (inclusive).
    inline int query(int x,int y) { return query(y)-query(x-1); }
    // Reverts every journaled update using the current n, then empties the
    // journal. Returns 0. Must be called before n is changed by init().
    inline int clear()
    {
        for(int i=1;i<=c;i++)
            for(int x=p[i];x<=n;x+=lb(x)) v[x]-=y[i];
        return c=0;
    }
}b;
// Counts unordered pairs {i, j}, i != j, of entries in v whose combined range
// fits within k. Each entry is (min weight, max weight), and a pair qualifies
// when max_j <= min_i + k while min_j lies in [max_i - k, max_i + k].
//
// The entries are swept in increasing max. A sliding window over the list
// sorted by min keeps exactly the entries whose min is in the allowed band,
// and the global Fenwick tree b counts how many of them have a compressed max
// inside [getid(min_i - k), getid(min_i + k)].
//
// The sweep counts ordered pairs including each entry paired with itself, so
// the result is (count - size) / 2. v is taken by value because its second
// components are overwritten with compressed ids. Uses the global scratch
// vector tmp and leaves b cleared.
inline lint solve(vector<pii> v)
{
    const int cnt=(int)v.size();

    // tmp holds the same entries as (max, min) so they can be ordered by max.
    tmp.clear();
    for(int i=0;i<cnt;i++) tmp.push_back(make_pair(v[i].second,v[i].first));
    sort(v.begin(),v.end());
    sort(tmp.begin(),tmp.end());

    // Replace every max in v by its compressed id and size the tree to fit.
    int top=0;
    for(int i=0;i<cnt;i++)
    {
        v[i].second=getid(v[i].second);
        top=max(top,v[i].second);
    }
    b.init(top);

    lint ordered=0ll;
    int added=-1,dropped=-1;
    for(int i=0;i<cnt;i++)
    {
        const int curMax=tmp[i].first;
        const int curMin=tmp[i].second;
        while(added+1<cnt&&v[added+1].first<=curMax+k)
        {
            ++added;
            b.update(v[added].second,1);
        }
        while(dropped+1<cnt&&v[dropped+1].first<curMax-k)
        {
            ++dropped;
            b.update(v[dropped].second,-1);
        }
        // The upper id is clamped to the tree size; b is only maintained up to top.
        const int from=getid(curMin-k);
        const int upto=min(top,getid(curMin+k));
        ordered+=b.query(from,upto);
    }
    b.clear();
    return (ordered-cnt)/2;
}
// Processes the still-unvisited component containing x:
//  1. sizes it with getsz, moves x to the node chosen by getrt and marks it;
//  2. for every unvisited neighbour, collects the (min, max) pairs of walks
//     starting at x that enter that neighbour (getmsg also adds one to ans per
//     collected walk) and subtracts the pairs formed inside that single branch;
//  3. merges all branch lists into v[x] and adds the pairs formed over the
//     merged list, so only pairs from different branches remain counted;
//  4. recurses into every neighbour that is still unvisited.
// Always returns 0.
int getans(int x)
{
    lt=0;
    getsz(x);
    x=getrt(x);
    vis[x]=1;

    for(int idx=h[x];idx!=0;idx=e[idx].pre)
    {
        const int child=e[idx].to;
        if(vis[child]) continue;
        v[child].clear();
        getmsg(child,x,child,e[idx].wgt,e[idx].wgt);
        ans-=solve(v[child]);
    }

    v[x].clear();
    for(int idx=h[x];idx!=0;idx=e[idx].pre)
    {
        const int child=e[idx].to;
        if(vis[child]) continue;
        v[x].insert(v[x].end(),v[child].begin(),v[child].end());
    }
    ans+=solve(v[x]);

    for(int idx=h[x];idx!=0;idx=e[idx].pre)
    {
        const int child=e[idx].to;
        if(!vis[child]) getans(child);
    }
    return 0;
}
// Reads n and k, then n-1 weighted edges; builds the adjacency lists and the
// compression table s (every weight c together with c-k and c+k), runs the
// decomposition from node 1 and prints the answer.
int main()
{
    const int n=inn();
    k=inn();
    for(int i=1;i<n;i++)
    {
        const int from=inn();
        const int to=inn();
        const int c=inn();
        add_edge(from,to,c);
        add_edge(to,from,c);
        s.push_back(c);
        s.push_back(c-k);
        s.push_back(c+k);
    }
    sort(s.begin(),s.end());
    s.erase(unique(s.begin(),s.end()),s.end());
    getans(1);
    printf("%lld\n",ans);
    return 0;
}
LCT
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define gc getchar()
#define lint long long
#define N 100010
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
// Per-node state of the link-cut tree; index 0 serves as the "no node" value.
int ch[N][2];   // children inside the splay tree (0 = left, 1 = right)
int fa[N];      // parent inside the splay tree, 0 for a splay root
int pf[N];      // for a splay root: the node its splay tree hangs under
int sz[N];      // sz[x] = sz[left] + sz[right] + val[x]
int val[N];     // own contribution of x plus the sz of trees hanging under it
int rev[N];     // pending "swap my children" flag
// Reads the next run of decimal digits from stdin and returns it as an int.
// Every non-digit character before the run is skipped; the single character
// that terminates the run is consumed. No sign handling is performed.
// Returns 0 if end of input is reached before any digit is seen (previously
// the skip loop never terminated, because EOF compares below '0').
inline int inn()
{
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')) ch=getchar();
    if(ch==EOF) return 0;
    int x=ch-'0';
    for(ch=getchar();ch>='0'&&ch<='9';ch=getchar())
        x=x*10+(ch-'0');
    return x;
}
// Returns 1 when x is stored as the right child of fa[x], otherwise 0.
inline int gw(int x)
{
    const int parent=fa[x];
    return ch[parent][1]==x?1:0;
}
// Recomputes sz[x] from both splay children and val[x]; returns the new sz[x].
inline int push_up(int x)
{
    const int left=ch[x][0];
    const int right=ch[x][1];
    sz[x]=sz[left]+sz[right]+val[x];
    return sz[x];
}
// Makes y the child of x on side z (0 = left, 1 = right) and refreshes sz[x].
// With x == 0 it only clears y's splay parent and returns 0; otherwise it
// returns the refreshed sz[x]. y == 0 empties the slot.
inline int setc(int x,int y,int z)
{
    if(x==0)
    {
        fa[y]=0;
        return 0;
    }
    ch[x][z]=y;
    if(y!=0) fa[y]=x;
    return push_up(x);
}
inline int rotate(int x)
{ int y=fa[x],z=fa[y],a=gw(x),b=gw(y),c=ch[x][a^1];return swap(pf[x],pf[y]),setc(y,c,a),setc(x,y,a^1),setc(z,x,b); }
// Applies a pending flip at x: toggles the flag of each existing child, swaps
// x's two children and clears rev[x]. Does nothing when rev[x] is 0.
// Always returns 0.
inline int push_down(int x)
{
    if(rev[x])
    {
        const int left=ch[x][0];
        const int right=ch[x][1];
        if(left) rev[left]^=1;
        if(right) rev[right]^=1;
        ch[x][0]=right;
        ch[x][1]=left;
        rev[x]=0;
    }
    return 0;
}
// Applies every pending flip on the way from the splay root down to x, x
// included. The flip must be applied to the node itself: the earlier code
// passed rev[x] (always 1 at that point) as the node index, so only node 1
// ever had its flag applied.
inline int all_down(int x)
{
    if(fa[x]) all_down(fa[x]);
    return push_down(x);
}
// Rotates x to the root of its splay tree. Always returns 0.
inline int splay(int x)
{
    all_down(x);
    while(fa[x])
    {
        const int y=fa[x];
        // Double rotation: same-side chains rotate the parent first.
        if(fa[y]) rotate(gw(x)==gw(y)?y:x);
        rotate(x);
    }
    return 0;
}
// Splays x and detaches its right subtree, if any, into a separate splay tree
// that hangs under x (pf) and is accounted for in val[x]. Returns 0 when
// there was nothing to detach, otherwise the refreshed sz[x].
inline int expose(int x)
{
    splay(x);
    const int y=ch[x][1];
    if(!y) return 0;
    ch[x][1]=0;
    fa[y]=0;
    pf[y]=x;
    val[x]+=sz[y];
    return push_up(x);
}
// Splays x; if its splay tree hangs under some node y, exposes y and makes
// x's tree the right child of y, moving its size out of val[y]. Returns 1
// when a join happened, 0 when x's tree hangs under nothing.
inline int splice(int x)
{
    splay(x);
    const int y=pf[x];
    if(!y) return 0;
    expose(y);
    splay(y);
    pf[x]=0;
    val[y]-=sz[x];
    setc(y,x,1);
    return 1;
}
// Joins splay trees upwards until the one containing x hangs under nothing,
// with x as its last node. Always returns 0.
inline int access(int x)
{
    expose(x);
    while(splice(x)) {}
    return 0;
}
// Makes x the first node of its topmost splay tree by flagging that tree for
// reversal. Returns the new value of rev[x].
inline int evert(int x)
{
    access(x);
    splay(x);
    return rev[x]^=1;
}
// Joins the tree containing x to the tree containing y by making x a child of
// y. x and y must be in different trees; this is not checked.
//
// x is everted and splayed, which applies the flip, so x has no left child and
// everything else in its tree is to its right or hangs under it. y is only
// accessed, so it has no right child, and x's splay tree is put into that slot.
// The earlier version everted both ends and did setc(x, y, 1), which
// overwrites x's right child once flips are applied.
// Returns the refreshed sz[y].
inline int link(int x,int y)
{
    evert(x);
    splay(x);
    access(y);
    splay(y);
    return setc(y,x,1);
}
// Separates y from x: everts x, accesses y, splays x, empties x's right child
// slot and clears y's splay parent. Returns 0.
// NOTE(review): this relies on y being x's direct right child after the three
// preparatory calls, which presumably requires x and y to be joined by an
// edge -- confirm against callers.
inline int cut(int x,int y)
{
    evert(x);
    access(y);
    splay(x);
    setc(x,0,1);
    return setc(0,y,0);
}
// Everts and splays x, then returns sz[x].
inline int gsz(int x)
{
    evert(x);
    splay(x);
    return sz[x];
}
// One input edge: endpoints u, v and weight w. Ordered by weight only.
struct P
{
    int u,v,w;
    inline bool operator<(const P &other)const
    {
        return w<other.w;
    }
}p[N];
int main()
{
int n=inn(),k=inn();lint ans=0ll;for(int i=1;i<=n;i++) sz[i]=val[i]=1;
for(int i=1;i<n;i++) p[i].u=inn(),p[i].v=inn(),p[i].w=inn();
sort(p+1,p+n);
for(int i=1,j=1;i<n;cut(p[i].u,p[i].v),i++)
{
while(j<n&&p[j].w<=p[i].w+k)
ans+=(lint)gsz(p[j].u)*gsz(p[j].v),link(p[j].u,p[j].v),j++;
}
return !printf("%lld\n",ans);
}