传送门
解析:
考场上写完 T 1 T1 T1后,本来说睡一会起来写 T 2 T2 T2,结果刚趴下五分钟就想出了点分治写法,一发过了大样例。结果最后卡常只剩80分。。。
结果标程是 K r u s k a l Kruskal Kruskal重构树上离线树状数组。。。 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n),我的点分治明明也是 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)。。。
可怕的是标程在OJ上直接RE,我的点分治在OJ上被卡了更多的常数QWQ。
结果是NOI Linux下跑的比标算慢的 O ( n l o g n ) O(nlogn) O(nlogn)线段树合并在 O J OJ OJ上AC了。。。%%%ldxoi
思路:
首先,一看到求的东西,显然是瓶颈路,那么 K r u s k a l Kruskal Kruskal跑不了了,考虑怎么统计答案。
解法1:完爆标算的线段树合并
考虑将原来的颜色权值的限制修改一下,不妨设 c u > c v c_u>c_v cu>cv,那么限制就是 c u − c v ≥ L c_u-c_v\geq L cu−cv≥L ,就是 c u > c v + L − 1 c_u > c_v+L-1 cu>cv+L−1,所以我们直接来最暴力的做法,每个点维护两个权值 c u c_u cu和 c u + L − 1 c_u+L-1 cu+L−1,离散化(节约空间,方便比较大小),然后每个点建一颗动态开点式的线段树,方便之后合并。
然后 K r u s k a l Kruskal Kruskal求一下最小生成树,显然我们连接两个连通块的时候所用的边就是这两个连通块中的点相互到达需要经过的最小的最大边。那么我们只需要查询两个连通块中满足条件的点对有多少就行了。这个直接线段树上按照线段树结构跑一遍就能统计答案了。
合并连通块的时候将两个线段树按照结构合并就行了,这样的合并复杂度只和深度有关,所以总的复杂度是 O ( n l o g n ) O(nlogn) O(nlogn)。
解法2:被卡常的点分治
心态真的炸了啊。。。。
为什么要卡点分治那么一点点常数啊,所有超时的点在NOI Linux 下都只差不到0.几s,真的卡的我心态爆炸啊啊啊!!!!
算了还是讲一下怎么做吧。
首先求出来最小生成树,在最小生成树上点分治。
点分治的时候每个点记录它到当前分治中心路径上的最大值。
然后,大常数预警!
维护两个树状数组,一个记录距离小于等于 d i s t u dist_u distu的满足条件的节点的个数,一个记录距离大于 d i s t u dist_u distu的后缀和。
那么直接树状数组查询+修改就可以被卡常了
两份代码都放在下面了,有需要自取
代码(线段树合并):
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define re register
#define gc getchar
#define pc putchar
#define cs const
inline int getint(){
re int num;
re char c;
while(!isdigit(c=gc()));num=c^48;
while(isdigit(c=gc()))num=(num<<1)+(num<<3)+(c^48);
return num;
}
cs int N=200005,M=500005;
struct edge{
int u,v,w;
friend bool operator<(cs edge &a,cs edge &b){
return a.w<b.w;
}
}E[M];
int fa[N];
inline int getfa(re int x){
while(x^fa[x])x=fa[x]=fa[fa[x]];
return x;
}
int n,m,L;
int root[N][2],lc[N*60],rc[N*60],tot,siz[N*60];
int val[N][2];
int a[N<<1],cnt;
ll ans;
inline void update(int &k,int l,int r,cs int &val){
if(!k)k=++tot;
++siz[k];
if(l==r)return ;
int mid=(l+r)>>1;
if(val<=mid)update(lc[k],l,mid,val);
else update(rc[k],mid+1,r,val);
}
inline ll query(int lt,int rt,int l,int r){
if(!lt||!rt||l==r)return 0;
int mid=(l+r)>>1;
return query(lc[lt],lc[rt],l,mid)+query(rc[lt],rc[rt],mid+1,r)+1ll*siz[lc[lt]]*siz[rc[rt]];
}
inline int merge(int lt,int rt,int l,int r){
if(!lt||!rt)return lt+rt;
siz[lt]+=siz[rt];
if(l==r)return lt;
int mid=(l+r)>>1;
lc[lt]=merge(lc[lt],lc[rt],l,mid);
rc[lt]=merge(rc[lt],rc[rt],mid+1,r);
return lt;
}
signed main(){
n=getint();
m=getint();
L=getint();
for(int re i=1;i<=n;++i){
a[++cnt]=val[i][0]=getint();
a[++cnt]=val[i][1]=val[i][0]+L-1;
}
sort(a+1,a+cnt+1);
cnt=unique(a+1,a+cnt+1)-a-1;
for(int re i=1;i<=n;++i){
val[i][0]=lower_bound(a+1,a+cnt+1,val[i][0])-a;
val[i][1]=lower_bound(a+1,a+cnt+1,val[i][1])-a;
update(root[i][0],1,cnt,val[i][0]);
update(root[i][1],1,cnt,val[i][1]);
fa[i]=i;
}
for(int re i=1;i<=m;++i){
E[i].u=getint();
E[i].v=getint();
E[i].w=getint();
}
sort(E+1,E+m+1);
for(int re i=1;i<=m;++i){
int u=getfa(E[i].u),v=getfa(E[i].v);
if(u==v)continue;
fa[v]=u;
if(L)ans+=1ll*E[i].w*(query(root[u][1],root[v][0],1,cnt)+query(root[v][1],root[u][0],1,cnt));
else ans+=1ll*E[i].w*siz[root[u][1]]*siz[root[v][1]];
root[u][1]=merge(root[u][1],root[v][1],1,cnt);
root[u][0]=merge(root[u][0],root[v][0],1,cnt);
}
cout<<ans;
return 0;
}
代码(点分治):
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define re register
#define gc get_char
#define pc putchar
#define cs const
inline char get_char(){
static cs int Rlen=1<<18|1;
static char buf[Rlen],*p1=buf,*p2=buf;
return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;
}
inline int getint(){
re int num;
re char c;
while(!isdigit(c=gc()));num=c^48;
while(isdigit(c=gc()))num=(num<<1)+(num<<3)+(c^48);
return num;
}
cs int N=200005,M=500005;
int last[N],nxt[N<<1],to[N<<1],ecnt;
int w[N<<1];
inline void addedge(cs int &u,cs int &v,cs int &val){
nxt[++ecnt]=last[u],last[u]=ecnt,to[ecnt]=v,w[ecnt]=val;
nxt[++ecnt]=last[v],last[v]=ecnt,to[ecnt]=u,w[ecnt]=val;
}
struct edge{
int u,v,w;
friend bool operator<(cs edge &a,cs edge &b){
return a.w<b.w;
}
}E[M];
int fa[N],col[N];
inline int getfa(re int x){
while(x^fa[x])x=fa[x]=fa[fa[x]];
return x;
}
int n,m,L;
ll ans;
bool ban[N];
int siz[N],dist[N];
int maxn,total,G;
void find_G(cs int &u,cs int &fa){
siz[u]=1;
re int mx=0;
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
if(v==fa||ban[v])continue;
find_G(v,u);
siz[u]+=siz[v];
mx=max(mx,siz[v]);
}
mx=max(mx,total-siz[u]);
if(mx<maxn)maxn=mx,G=u;
}
int all[N],DIS[N],tot1,tot2;
int cnt[N];
ll bit[N];
void getdis(cs int &u,cs int &fa){
DIS[++tot2]=dist[u];
all[++tot1]=u;
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
if(ban[v]||v==fa)continue;
dist[v]=max(dist[u],w[e]);
getdis(v,u);
}
}
inline bool sort_col(cs int &a,cs int &b){
return col[a]<col[b];
}
inline void init(){
sort(all,all+tot1,sort_col);
sort(DIS,DIS+tot2);
tot2=unique(DIS,DIS+tot2)-DIS;
memset(bit,0,8*(tot2+2));
memset(cnt,0,4*(tot2+2));
}
#define lowbit(x) (x&(-x))
inline void addcnt(re int pos,cs int &val){
pos=upper_bound(DIS,DIS+tot2,pos)-DIS+1;
for(;pos<tot2+2;pos+=lowbit(pos))cnt[pos]+=val;
}
inline int querycnt(re int pos,re int res=0){//询问前缀和
pos=upper_bound(DIS,DIS+tot2,pos)-DIS+1;
for(;pos;pos-=lowbit(pos))res+=cnt[pos];
return res;
}
inline void addbit(re int pos,cs int &val){
pos=lower_bound(DIS,DIS+tot2,pos)-DIS+1;
for(;pos;pos-=lowbit(pos))bit[pos]+=val;
}
inline ll querybit(re int pos,re ll res=0){//询问后缀和
pos=upper_bound(DIS,DIS+tot2,pos)-DIS+1;
for(;pos<tot2+2;pos+=lowbit(pos))res+=bit[pos];
return res;
}
inline void calc(cs int &u,cs int &dis,cs int &f){
dist[u]=dis;
tot1=tot2=-1;
getdis(u,0);
++tot1,++tot2;
init();
for(int re i=0,j=0;i<tot1;++i){
int u=all[i],v=all[j];
while(col[v]+L<=col[u]&&j<i){
addcnt(dist[v],1);
addbit(dist[v],dist[v]);
v=all[++j];
}
ans+=1ll*querycnt(dist[u])*dist[u]*f;//相同的统计在cnt 里面
ans+=querybit(dist[u])*f;
}
}
inline void solve(int u){
calc(u,0,1);
ban[u]=true;
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
if(ban[v])continue;
calc(v,w[e],-1);
maxn=total=siz[v];
find_G(v,u);
solve(G);
}
}
signed main(){
n=getint();
m=getint();
L=getint();
for(int re i=1;i<=n;++i){
fa[i]=i;
col[i]=getint();
}
for(int re i=1;i<=m;++i){
E[i].u=getint();
E[i].v=getint();
E[i].w=getint();
}
sort(E+1,E+m+1);
for(int re i=1;i<=m;++i){
int u=getfa(E[i].u),v=getfa(E[i].v);
if(u==v)continue;
addedge(u,v,E[i].w);
fa[v]=u;
++total;
if(total==n-1)break;
}
maxn=total=n;
find_G(1,0);
solve(G);
cout<<ans;
return 0;
}