题目描述
给定一棵
n
n
n个点的树,点带点权。
有
m
m
m次操作,每次操作给定
x
,
y
x,y
x,y,表示修改点
x
x
x的权值为
y
y
y。
你需要在每次操作之后求出这棵树的最大权独立集的权值大小。
n , m ≤ 1 0 5 n,m≤10^5 n,m≤105
分析:
动态dp模板题。
有点像lct求子树询问的那种题目。
我们设
l
d
p
[
x
]
[
0
/
1
]
ldp[x][0/1]
ldp[x][0/1]表示虚(轻)儿子全部都不选和随意选的dp值。
即
l
d
p
[
x
]
[
0
]
=
∑
y
∈
l
i
g
h
t
s
o
n
[
x
]
d
p
[
x
]
[
0
]
ldp[x][0]=\sum_{y\in lightson[x]} dp[x][0]
ldp[x][0]=∑y∈lightson[x]dp[x][0]
l
d
p
[
x
]
[
1
]
=
∑
y
∈
l
i
g
h
t
s
o
n
[
x
]
m
a
x
(
d
p
[
x
]
[
0
]
,
d
p
[
x
]
[
1
]
)
ldp[x][1]=\sum_{y\in lightson[x]} max(dp[x][0],dp[x][1])
ldp[x][1]=∑y∈lightson[x]max(dp[x][0],dp[x][1])
那么当前节点的dp值等于重儿子的dp值和
l
d
p
ldp
ldp合并,设
h
h
h为
x
x
x的重儿子。
d
p
[
x
]
[
0
]
=
m
a
x
(
d
p
[
h
]
[
0
]
,
d
p
[
h
]
[
1
]
)
+
l
d
p
[
x
]
[
1
]
dp[x][0]=max(dp[h][0],dp[h][1])+ldp[x][1]
dp[x][0]=max(dp[h][0],dp[h][1])+ldp[x][1]
d
p
[
x
]
[
1
]
=
d
p
[
h
]
[
0
]
+
l
d
p
[
x
]
[
0
]
+
a
[
x
]
dp[x][1]=dp[h][0]+ldp[x][0]+a[x]
dp[x][1]=dp[h][0]+ldp[x][0]+a[x]
定义一种矩阵乘为以下形式,
C
i
,
j
=
max
k
=
1
n
A
i
,
k
+
B
k
,
j
C_{i,j}=\max_{k=1}^{n}A_{i,k}+B_{k,j}
Ci,j=maxk=1nAi,k+Bk,j
那么上面的转移式可以写成
[
d
p
[
h
]
[
0
]
d
p
[
h
]
[
1
]
]
\begin{bmatrix} dp[h][0] & dp[h][1] \end{bmatrix}
[dp[h][0]dp[h][1]] ×
[
l
d
p
[
x
]
[
1
]
l
d
p
[
x
]
[
0
]
+
a
[
x
]
l
d
p
[
x
]
[
1
]
−
∞
]
\begin{bmatrix} ldp[x][1] & ldp[x][0]+a[x] \\ ldp[x][1] & -\infty \end{bmatrix}
[ldp[x][1]ldp[x][1]ldp[x][0]+a[x]−∞]
因为这种矩阵是符合结合律的,而lct的一棵splay相当于一条链,每个点自己对应一个矩阵,只要按深度从大到小的顺序结合起来就是可以的,也就是splay合并按右儿子,自己,左儿子的顺序合并。正因为合并有顺序,所以不能makeroot(我不知道可不可以)。
一开始先把树连好,并求出所有点的 l d p ldp ldp。修改直接access这个点,并把它旋到根,修改后进行一次更新操作。再用 [ 0 0 ] \begin{bmatrix} 0 & 0 \end{bmatrix} [00]去乘这棵splay根的矩阵就得到解了。
我们不需要记录每个节点的dp值,因为dp值就是用
[
0
0
]
\begin{bmatrix} 0 & 0 \end{bmatrix}
[00]去乘splay根的矩阵。
复杂度是
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)的。
代码:
// luogu-judger-enable-o2
#include <iostream>
#include <cstdio>
#include <cmath>
const int maxn=1e5+7;
const int inf=0x3f3f3f3f;
using namespace std;
int n,m,x,y,cnt;
int ls[maxn],d[2],f[maxn][2];
struct edge{
int y,next;
}g[maxn*2];
struct rec{
int a[2][2];
};
struct node{
int l,r,fa;
int dp[2],x;
rec data;
}t[maxn];
rec operator *(rec a,rec b) //定义的矩阵乘
{
rec c=(rec){{{-inf,-inf},{-inf,-inf}}};
for (int k=0;k<=1;k++)
{
for (int i=0;i<=1;i++)
{
for (int j=0;j<=1;j++)
{
c.a[i][j]=max(c.a[i][j],a.a[i][k]+b.a[k][j]);
}
}
}
return c;
}
void add(int x,int y)
{
g[++cnt]=(edge){y,ls[x]};
ls[x]=cnt;
}
void dfs(int x,int fa)
{
t[x].fa=fa;
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if (y==fa) continue;
dfs(y,x);
f[x][0]+=max(f[y][0],f[y][1]);
f[x][1]+=f[y][0];
t[x].dp[0]+=f[y][0];
t[x].dp[1]+=max(f[y][0],f[y][1]);
}
f[x][1]+=t[x].x;
}
void updata(int x)
{
int l=t[x].l,r=t[x].r;
t[x].data=(rec){{{t[x].dp[1],t[x].dp[0]+t[x].x},{t[x].dp[1],-inf}}};
if (t[x].r) t[x].data=t[r].data*t[x].data; //按右儿子,自己,左儿子的顺序乘
if (t[x].l) t[x].data=t[x].data*t[l].data;
}
bool isroot(int x)
{
return (x!=t[t[x].fa].l) && (x!=t[t[x].fa].r);
}
void rttr(int x)
{
int y=t[x].l;
t[x].l=t[y].r;
if (t[y].r) t[t[y].r].fa=x;
if (x==t[t[x].fa].l) t[t[x].fa].l=y;
else if (x==t[t[x].fa].r) t[t[x].fa].r=y;
t[y].fa=t[x].fa;
t[x].fa=y;
t[y].r=x;
updata(x); updata(y);
}
void rttl(int x)
{
int y=t[x].r;
t[x].r=t[y].l;
if (t[y].l) t[t[y].l].fa=x;
if (x==t[t[x].fa].l) t[t[x].fa].l=y;
else if (x==t[t[x].fa].r) t[t[x].fa].r=y;
t[y].fa=t[x].fa;
t[x].fa=y;
t[y].l=x;
updata(x); updata(y);
}
void splay(int x)
{
while (!isroot(x))
{
int p=t[x].fa,g=t[p].fa;
if (isroot(p))
{
if (x==t[p].l) rttr(p);
else rttl(p);
}
else
{
if (x==t[p].l)
{
if (p==t[g].l) rttr(p),rttr(g);
else rttr(p),rttl(g);
}
else
{
if (p==t[g].l) rttl(p),rttr(g);
else rttl(p),rttl(g);
}
}
}
}
void access(int x)
{
int y=0;
while (x)
{
splay(x);
d[0]=max(t[t[x].r].data.a[0][0],t[t[x].r].data.a[1][0]); //原来的实儿子变成了虚儿子,修改当前节点的ldp值
d[1]=max(t[t[x].r].data.a[0][1],t[t[x].r].data.a[1][1]);
t[x].dp[0]+=d[0];
t[x].dp[1]+=max(d[0],d[1]);
d[0]=max(t[y].data.a[0][0],t[y].data.a[1][0]); //一个虚儿子变成了实儿子,修改当前ldp值
d[1]=max(t[y].data.a[0][1],t[y].data.a[1][1]);
t[x].dp[0]-=d[0];
t[x].dp[1]-=max(d[0],d[1]);
t[x].r=y;
updata(x);
y=x,x=t[x].fa;
}
}
int change(int x,int k)
{
access(x);
splay(x);
t[x].x=k;
updata(x);
int tmp=max(t[x].data.a[0][0],t[x].data.a[1][0]);
tmp=max(tmp,t[x].data.a[0][1]);
tmp=max(tmp,t[x].data.a[1][1]);
return tmp; //答案其实就是矩阵4个数中的最大值
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
{
scanf("%d",&t[i].x);
updata(i);
}
for (int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(1,0);
for (int i=1;i<=m;i++)
{
scanf("%d%d",&x,&y);
printf("%d\n",change(x,y));
}
}