luogu P4719 【模板】动态 DP 矩阵乘法_LCT
Code:
// luogu-judger-enable-o2
//Dynamic DP with LCT
#include<bits/stdc++.h>
#define ll long long
#define setIO(s) freopen(s".in","r",stdin)
#define maxn 100002
#define inf 100000000
using namespace std;
//Link cut tree
void de()
{
printf("ok\n");
}
namespace LCT
{
struct Matrix
{
ll a[2][2];
ll*operator[](int x){ return a[x];}
}t[maxn],tmp[maxn];
Matrix operator*(Matrix a,Matrix b)
{
Matrix c;
c[0][0]=max(a[0][0]+b[0][0],a[0][1]+b[1][0]);
c[0][1]=max(a[0][0]+b[0][1],a[0][1]+b[1][1]);
c[1][0]=max(a[1][0]+b[0][0],a[1][1]+b[1][0]);
c[1][1]=max(a[1][0]+b[0][1],a[1][1]+b[1][1]);
return c;
}
//tmp :: 虚儿子信息
//t :: 树剖实际转移矩阵
#define lson ch[x][0]
#define rson ch[x][1]
int ch[maxn][2],f[maxn];
int isRoot(int x)
{
return !(ch[f[x]][0]==x || ch[f[x]][1]==x);
}
int get(int x)
{
return ch[f[x]][1]==x;
}
void pushup(int x)
{
t[x]=tmp[x];
if(lson) t[x]=t[lson]*t[x];
if(rson) t[x]=t[x]*t[rson];
}
void rotate(int x)
{
int old=f[x],fold=f[old],which=get(x);
if(!isRoot(old)) ch[fold][ch[fold][1]==old]=x;
ch[old][which]=ch[x][which^1],f[ch[old][which]]=old;
ch[x][which^1]=old,f[old]=x,f[x]=fold;
pushup(old),pushup(x);
}
void splay(int x)
{
int u=x;
while(!isRoot(u)) u=f[u];
u=f[u];
for(int fa;(fa=f[x])!=u;rotate(x))
if(f[fa]!=u) rotate(get(fa)==get(x)?fa:x);
}
void Access(int x)
{
for(int y=0;x;y=x,x=f[x])
{
splay(x);
if(rson)
{
tmp[x][0][0]+=max(t[rson][0][0],t[rson][1][0]);
tmp[x][1][0]+=t[rson][0][0];
}
if(y)
{
tmp[x][0][0]-=max(t[y][0][0],t[y][1][0]);
tmp[x][1][0]-=t[y][0][0];
}
tmp[x][0][1]=tmp[x][0][0];
rson=y,pushup(x);
}
}
};
//variables
int DP[maxn][2];
int V[maxn],hd[maxn],to[maxn<<1],nex[maxn<<1];
int n,Q,edges;
void add(int u,int v){ nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; }
//build graph
void dfs(int u,int ff)
{
LCT::f[u]=ff;
DP[u][0]=0;
DP[u][1]=V[u];
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==ff) continue;
dfs(v,u);
DP[u][0]+=max(DP[v][1],DP[v][0]);
DP[u][1]+=DP[v][0];
}
LCT::tmp[u]=(LCT::Matrix){ DP[u][0], DP[u][0], DP[u][1], -inf};
LCT::t[u]=LCT::tmp[u];
}
//主程序~
int main()
{
// setIO("input");
scanf("%d%d",&n,&Q);
for(int i=1;i<=n;++i) scanf("%d",&V[i]);
for(int i=1,u,v;i<n;++i)
{
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs(1,0);
while(Q--)
{
int x,y;
scanf("%d%d",&x,&y);
LCT::Access(x);
LCT::splay(x);
LCT::tmp[x][1][0]+=(ll)y-V[x];
V[x]=y;
LCT::pushup(x);
LCT::splay(1);
printf("%lld\n",max(LCT::t[1][0][0], LCT::t[1][1][0]));
}
return 0;
}