说明
本篇题解内容未涉及到树链剖分、矩阵乘法(可能有其思想),请放心食用。
题意描述
给定一棵节点数量为n树,每个节点被标记的代价
,m次询问,对于每个询问给出s,t,求满足s,t被标记,且每个被标记的点距离最近的被标记的点的距离小于k
,标记点的最小代价。
基本思路
为了方便叙述,我们首先定义一下一个状态:表示节点u距离子节点中已经标记的节点最小距离小于等于x。
原问题即求同时有的最小代价。
用到倍增思想,表示从满足
,不计u子树的最小代价,其中
表示u的
级祖先。
边界条件:对于v是u的父节点
- 当
,
;
- 当
,
;
- 当
,
,
,由于当距离u,v到所有被标记节点的距离都大于等于2时,必须标记周围的一个节点,所以
v及与v相连的节点的最小权值。
状态转移:对于,枚举中间节点即
状态,把上下两个结果相加,即
然后对于每个询问,先分别求出满足、
的最小代价、 满足
、
的最小代价。
对于满足的最小代价,设
表示满足
的最小代价,从大到小枚举i,如果
的深度大于等于
的深度,有算出
的意义,将其设为上一个有意义的点与
合并,即
,同时把u赋值为
。实现时把
的第一维优化掉。
然后把两个求出的结果进行合并。
对于与
的合并
- 首先加上
- 当
,被多计算,一次,减去一个
。
- 当
,
不满足要求,应再标记一个周围的节点,则加上
及与
相连的节点的最小权值。
结果就是所有情况合并的最小值。
时间复杂度,空间复杂度
示例代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=2e5+5,M=N*2,K=20;
int n,k,q;
int la[N],ne[M],en[M],idx;
LL w[N],mw[N];
int fa[N][K],dep[N];
LL dis[N][K][3][3];
LL d1[3],d2[3];
void gm(LL &x,LL y)
{
if(x>y)x=y;
}
void add(int a,int b)
{
idx++;
ne[idx]=la[a];
la[a]=idx;
en[idx]=b;
}
void dfs(int u)
{
for(int j=la[u];j;j=ne[j])
{
int v=en[j];
if(!dep[v])
{
dep[v]=dep[u]+1;
fa[v][0]=u;
for(int i=1;i<K;i++)fa[v][i]=fa[fa[v][i-1]][i-1];
if(k==1)
{
dis[v][0][0][0]=w[u];
}
if(k==2)
{
dis[v][0][0][0]=dis[v][0][1][0]=w[u];
dis[v][0][0][1]=0;
}
if(k==3)
{
dis[v][0][0][0]=dis[v][0][1][0]=dis[v][0][2][0]=w[u];
dis[v][0][0][1]=dis[v][0][1][2]=0;
dis[v][0][2][2]=mw[v];
}
for(int i=1;i<K;i++)
for(int x=0;x<k;x++)
for(int y=0;y<k;y++)
for(int z=0;z<k;z++)
gm(dis[v][i][x][z],dis[v][i-1][x][y]+dis[fa[v][i-1]][i-1][y][z]);
dfs(v);
}
}
}
int lca(int u,int v)
{
if(dep[u]>dep[v])swap(u,v);
for(int i=K-1;i>=0;i--)
if(dep[fa[v][i]]>=dep[u])
v=fa[v][i];
if(u==v)return u;
for(int i=K-1;i>=0;i--)
if(fa[u][i]!=fa[v][i])
{
u=fa[u][i];
v=fa[v][i];
}
return fa[u][0];
}
void get(int u,int x,LL d[])
{
LL p[3];
for(int i=0;i<k;i++)d[i]=w[u];
for(int i=K-1;i>=0;i--)
if(dep[fa[u][i]]>=dep[x])
{
for(int i=0;i<k;i++)
{
p[i]=d[i];
d[i]=1e18;
}
for(int y=0;y<k;y++)
for(int z=0;z<k;z++)
gm(d[z],p[y]+dis[u][i][y][z]);
u=fa[u][i];
}
}
LL query(int u,int v)
{
int p=lca(u,v);
get(u,p,d1),get(v,p,d2);
LL res=d1[0]+d2[0]-w[p];
for(int x=0;x<k;x++)
for(int y=0;y<k;y++)
gm(res,d1[x]+d2[y]+(x+y>k)*mw[p]);
return res;
}
int main()
{
scanf("%d%d%d",&n,&q,&k);
for(int i=1;i<=n;i++)
{
scanf("%lld",&w[i]);
mw[i]=w[i];
}
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
gm(mw[a],w[b]);
gm(mw[b],w[a]);
}
dep[1]=1;
memset(dis,0x3f,sizeof dis);
dfs(1);
while(q--)
{
int a,b;
scanf("%d%d",&a,&b);
printf("%lld\n",query(a,b));
}
return 0;
}