前言:矩阵的结合律
我们先以通常的矩阵乘法举例:
如果把乘法和加法换成其他运算,用op1和op2代替,如果op1对op2满足分配律,那么它也满足矩阵的结合律。
典型的例子是图论中的邻接矩阵,因为加法对min满足分配律:
min
(
a
,
b
)
+
c
=
min
(
a
+
c
,
b
+
c
)
\min(a,b)+c = \min(a+c,b+c)
min(a,b)+c=min(a+c,b+c)
所以
D
i
j
=
min
k
=
1
n
(
min
p
=
1
n
(
A
i
p
+
B
p
k
)
+
C
k
j
)
=
min
k
=
1
n
(
min
p
=
1
n
A
i
p
+
B
p
k
+
C
k
j
)
=
min
p
=
1
n
(
min
k
=
1
n
A
i
p
+
B
p
k
+
C
k
j
)
=
min
p
=
1
n
(
A
i
p
+
min
k
=
1
n
(
B
p
k
+
C
k
j
)
)
\begin{aligned} D_{ij}&=\min_{k=1}^n\left(\min_{p=1}^n(A_{ip}+B_{pk})+C_{kj}\right)\\ &=\min_{k=1}^n\left(\min_{p=1}^nA_{ip}+B_{pk}+C_{kj}\right)\\ &=\min_{p=1}^n\left(\min_{k=1}^nA_{ip}+B_{pk}+C_{kj}\right)\\ &=\min_{p=1}^n\left(A_{ip}+\min_{k=1}^n(B_{pk}+C_{kj})\right) \end{aligned}
Dij=k=1minn(p=1minn(Aip+Bpk)+Ckj)=k=1minn(p=1minnAip+Bpk+Ckj)=p=1minn(k=1minnAip+Bpk+Ckj)=p=1minn(Aip+k=1minn(Bpk+Ckj))
所以也可以使用矩阵加速。
同理,+对max也满足分配律,矩阵的结合律同样成立。
正题:动态DP
查询一棵子树 u u u的最大权独立集,只需要查询线段树上对应 [ d f n [ u ] , d f n [ b o t t o m [ u ] ] ] [dfn[u],dfn[bottom[u]]] [dfn[u],dfn[bottom[u]]]的区间的矩阵,虽然我们乘的始终是轻儿子转移矩阵,但是因为重链末端对应的转移矩阵 [ 0 0 w [ b o t t o m ] − I N F ] \begin{bmatrix}0 & 0\\w[bottom] & -INF\end{bmatrix} [0w[bottom]0−INF]的左半部分其实就是 [ f [ b o t t o m ] [ 0 ] f [ b o t t o m ] [ 1 ] ] \begin{bmatrix}f[bottom][0]\\f[bottom][1]\end{bmatrix} [f[bottom][0]f[bottom][1]],所以查询得到的矩阵的左半部分就是我们想要的 [ f [ u ] [ 0 ] f [ u ] [ 1 ] ] \begin{bmatrix}f[u][0]\\f[u][1]\end{bmatrix} [f[u][0]f[u][1]]
这是树链剖分+线段树的做法。
Code:
#include<bits/stdc++.h>
#define maxn 100005
using namespace std;
const int inf = 0x3f3f3f3f;
//f表示轻儿子得到的f`[0],f`[1],g表示对应的转移矩阵
int n,m,a[maxn],dp[maxn][2],f[maxn][2],fa[maxn],dfn[maxn],ln[maxn],tim,son[maxn],siz[maxn],top[maxn],bot[maxn];
int fir[maxn],nxt[maxn<<1],to[maxn<<1],tot;
struct Mat{
int s[2][2];
Mat(){memset(s,-0x3f,sizeof s);}
void init(int *f){s[0][0]=s[0][1]=f[0],s[1][0]=f[1],s[1][1]=-inf;}
int Max(){return max(s[0][0],s[1][0]);}
Mat operator * (const Mat &B)const{
Mat ret;
ret.s[0][0]=max(s[0][0]+B.s[0][0],s[0][1]+B.s[1][0]);
ret.s[0][1]=max(s[0][0]+B.s[0][1],s[0][1]+B.s[1][1]);
ret.s[1][0]=max(s[1][0]+B.s[0][0],s[1][1]+B.s[1][0]);
ret.s[1][1]=max(s[1][0]+B.s[0][1],s[1][1]+B.s[1][1]);
return ret;
}
}g[maxn<<2];
inline void line(int x,int y){nxt[++tot]=fir[x],fir[x]=tot,to[tot]=y;}
void dfs1(int u,int ff){
fa[u]=ff,siz[u]=1;
for(int i=fir[u],v;i;i=nxt[i]) if((v=to[i])!=ff){
dfs1(v,u),siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
}
}
void dfs2(int u,int tp){
top[u]=tp,ln[dfn[u]=++tim]=u;
if(son[u]) dfs2(son[u],tp),bot[u]=bot[son[u]];
else bot[u]=u;
for(int i=fir[u],v;i;i=nxt[i]) if(!dfn[v=to[i]]) dfs2(v,v);
}
void dfs(int u,int ff){
dp[u][1]=a[u];
for(int i=fir[u],v;i;i=nxt[i]) if((v=to[i])!=ff)
dfs(v,u),dp[u][0]+=max(dp[v][0],dp[v][1]),dp[u][1]+=dp[v][0];
f[u][0]=dp[u][0]-max(dp[son[u]][0],dp[son[u]][1]),f[u][1]=dp[u][1]-dp[son[u]][0];
//注意这里的f不能算入重儿子,当然也可以省去dp数组改为用f加上重儿子求完后再循环/dfs去掉重儿子
}
void build(int i,int l,int r){
if(l==r) {g[i].init(f[ln[l]]);return;}
int mid=(l+r)>>1;
build(i<<1,l,mid),build(i<<1|1,mid+1,r);
g[i]=g[i<<1]*g[i<<1|1];
}
void insert(int i,int l,int r,int x){
if(l==r) {g[i].init(f[ln[l]]);return;}
int mid=(l+r)>>1;
if(x<=mid) insert(i<<1,l,mid,x);
else insert(i<<1|1,mid+1,r,x);
g[i]=g[i<<1]*g[i<<1|1];
}
Mat query(int i,int l,int r,int x,int y){
if(x<=l&&r<=y) return g[i];
int mid=(l+r)>>1;
if(y<=mid) return query(i<<1,l,mid,x,y);
if(x>mid) return query(i<<1|1,mid+1,r,x,y);
return query(i<<1,l,mid,x,y)*query(i<<1|1,mid+1,r,x,y);
}
void modify(int u,int w){
f[u][1]+=w-a[u],a[u]=w;
Mat pre,now;
while(1){
pre=query(1,1,n,dfn[top[u]],dfn[bot[u]]);
insert(1,1,n,dfn[u]);
now=query(1,1,n,dfn[top[u]],dfn[bot[u]]);
if(!(u=fa[top[u]])) break;
f[u][0]+=now.Max()-pre.Max(),f[u][1]+=now.s[0][0]-pre.s[0][0];
}
}
int main()
{
int x,y;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<n;i++) scanf("%d%d",&x,&y),line(x,y),line(y,x);
dfs1(1,0),dfs2(1,1),dfs(1,0);
build(1,1,n);
while(m--){
scanf("%d%d",&x,&y);
modify(x,y);
printf("%d\n",query(1,1,n,dfn[1],dfn[bot[1]]).Max());
}
}
说完树链剖分的做法,当然还有对应的LCT做法,思路大同小异,不过是实现方式稍有区别,变为access的时候维护轻儿子的转移矩阵。虽然LCT常数大,但是毕竟复杂度是一个log,实际测试的时候较线段树稍快一点,实现也较为简洁。
Code:
#include<bits/stdc++.h>
#define maxn 100005
using namespace std;
const int inf = 1<<30;
int n,m,val[maxn];
vector<int>G[maxn];
struct Mat{
int s[2][2];
Mat(){s[0][0]=s[0][1]=s[1][0]=s[1][1]=-inf;}
inline int Max(){return max(s[0][0],s[1][0]);}
inline void init(int *f){s[0][0]=s[0][1]=f[0],s[1][0]=f[1],s[1][1]=-inf;}
Mat operator * (const Mat &B){
Mat ret;
for(int k=0;k<2;k++)
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
ret.s[i][j]=max(ret.s[i][j],s[i][k]+B.s[k][j]);
return ret;
}
};
namespace LCT{
int fa[maxn],ch[maxn][2],f[maxn][2];
Mat g[maxn];
#define il inline
#define pa fa[x]
il bool isr(int x){return ch[pa][0]!=x&&ch[pa][1]!=x;}
il bool isc(int x){return ch[pa][1]==x;}
il void upd(int x){
g[x].init(f[x]),
g[x]=g[ch[x][0]]*g[x]*g[ch[x][1]];
}
il void rot(int x){
int y=fa[x],z=fa[y],c=isc(x);
!isr(y)&&(ch[z][isc(y)]=x);
(ch[y][c]=ch[x][!c])&&(fa[ch[y][c]]=y);
fa[ch[x][!c]=y]=x,fa[x]=z;
upd(y),upd(x);
}
il void splay(int x){
for(;!isr(x);rot(x))
if(!isr(pa)) rot(isc(pa)==isc(x)?pa:x);
}
il void access(int x){
for(int y=0;x;x=fa[y=x]){
splay(x);
f[x][0]+=g[ch[x][1]].Max()-g[y].Max();
f[x][1]+=g[ch[x][1]].s[0][0]-g[y].s[0][0];
ch[x][1]=y,upd(x);
}
}
}
using namespace LCT;
void dfs(int u,int ff){
fa[u]=ff,f[u][1]=val[u];//一开始所有儿子都是轻儿子
for(int v: G[u]) if(v!=ff){
dfs(v,u);
f[u][0]+=max(f[v][0],f[v][1]);
f[u][1]+=f[v][0];
}
g[u].init(f[u]);
}
int main(){
g[0].s[0][0]=g[0].s[1][1]=0;
int x,y;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&val[i]);
for(int i=1;i<n;i++) scanf("%d%d",&x,&y),G[x].push_back(y),G[y].push_back(x);
dfs(1,0);
while(m--){
scanf("%d%d",&x,&y);
access(x);
splay(x);
f[x][1]+=y-val[x],val[x]=y,upd(x);
splay(1),printf("%d\n",g[1].Max());
}
}