题意
一眼最小生成树,没错吧。只不过这道题用考kruskal只能得40pts。所以考虑优化。又由于kruskal不好优化,所以我们考虑优化另一个耳熟目染的算法:prim。
然后呢,我们需要优化的是每次找未加入连通块的最短距离,我们可以考虑建立线段树
t
3
t3
t3,用来存两个点都在
l
−
r
l-r
l−r时最短的转移距离。
那么很明显有三种更新方式:
- x,y,都在 l , m i d l,mid l,mid区间。
- x,y都在 m i d + 1 , r mid+1,r mid+1,r区间。
- x,y跨过了
m
i
d
mid
mid。
读题发现,2个点之间边权只与每个点的 a [ i ] + p × i 和 a [ i ] − p × i a[i]+p \times i和a[i]-p \times i a[i]+p×i和a[i]−p×i决定。我们开一个2个线段树存2个集合点的4个值,就能来更新线段树 t 3 t3 t3了,怎么更新以及细节看代码.
code:
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=3e6+10;
const int inf=1e18;
int val1[4*N][2],val2[4*N][2],val3[4*N],id1[4*N][2],id2[4*N][2],id3[4*N];
int n,p,a[N],ans;
int Read()
{
int x=0,f=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-') f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=(x<<1)+(x<<3)+ch-'0';
ch=getchar();
}
return f*x;
}
void update1(int u,int l,int r,int id,int val,int f)
{
if(l==r)
{
id1[u][f]=id;
val1[u][f]=val;
return;
}
int mid=(l+r)>>1;
if(id<=mid) update1(u*2,l,mid,id,val,f);
else update1(u*2+1,mid+1,r,id,val,f);
if(val1[u*2][f]<val1[u*2+1][f])
{
val1[u][f]=val1[u*2][f];
id1[u][f]=id1[u*2][f];
}
else
{
val1[u][f]=val1[u*2+1][f];
id1[u][f]=id1[u*2+1][f];
}
}
void update2(int u,int l,int r,int id,int val,int f)
{
if(l==r)
{
id2[u][f]=id;
val2[u][f]=val;
return;
}
int mid=(l+r)>>1;
if(id<=mid) update2(u*2,l,mid,id,val,f);
else update2(u*2+1,mid+1,r,id,val,f);
if(val2[u*2][f]<val2[u*2+1][f])
{
val2[u][f]=val2[u*2][f];
id2[u][f]=id2[u*2][f];
}
else
{
val2[u][f]=val2[u*2+1][f];
id2[u][f]=id2[u*2+1][f];
}
}
void update3(int u,int l,int r,int id)
{
if(l==r) return;
int mid=(l+r)>>1;
if(id<=mid) update3(u*2,l,mid,id);
else update3(u*2+1,mid+1,r,id);
val3[u]=val3[u*2],id3[u]=id3[u*2];
if(val3[u*2+1]<val3[u])
{
val3[u]=val3[u*2+1];
id3[u]=id3[u*2+1];
}
if((val1[u*2][1]+val2[u*2+1][0])<val3[u])
{
val3[u]=val1[u*2][1]+val2[u*2+1][0];
id3[u]=id2[u*2+1][0];
}
if((val1[u*2+1][0]+val2[u*2][1])<val3[u])
{
val3[u]=val1[u*2+1][0]+val2[u*2][1];
id3[u]=id2[u*2][1];
}
}
signed main()
{
n=Read(),p=Read();
for(int i=1;i<=n;i++) a[i]=Read();
for(int i=1;i<=n*4;i++) val1[i][0]=val1[i][1]=val2[i][0]=val2[i][1]=val3[i]=inf;
update1(1,1,n,1,a[1]+p,0),update1(1,1,n,1,a[1]-p,1);
for(int i=2;i<=n;i++) update2(1,1,n,i,a[i]+p*i,0),update2(1,1,n,i,a[i]-p*i,1);
update3(1,1,n,1);
for(int i=1;i<n;i++)
{
ans+=val3[1];
int v=id3[1];
update1(1,1,n,v,a[v]+p*v,0),update1(1,1,n,v,a[v]-v*p,1);
update2(1,1,n,v,inf,0),update2(1,1,n,v,inf,1);
update3(1,1,n,v);
}
cout<<ans;
return 0;
}