其实这件事情告诉我们:万物皆可矩阵。
同样的,树形dp做一次O(n),一共n次,所以为n^2的复杂度。
这样有很多分呢qwq!
但是还要更好。
一般来说,当我们找到一个复杂度接近正确(个p)的方法时,先考虑怎么优化。
我们不应该每次都重新做一遍,而是只管那些被影响了的部分。
设被强制的两个点为x,y。
我们发现,在树形dp的基础上,它们会影响的答案只包括他们以及他们俩到根节点路上的点。
而那些其它的子树都没变。所以我们要想想怎么记录下这些答案并合并。
正如上次说的:一个线性变换首先考虑矩阵。而本题这种弱智树形dp更应该如此。
设dp[v][0]表示v不放军队,dp[v][1]表示放军队。假设我们已知这两个东西。我们来考虑怎么得到u的答案。
强行套矩阵我们发现转移矩阵为
重载矩阵运算为取对应项和的最小值。并且对于-1特判因为这个代表不参与运算。
dp‘[u]表示u结点去除当前v结点后剩下的答案。这样就可以得到那些不被影响的子树。
然后我们可以简单地发现这个也满足结合律。合起来取min嘛,,我也不会证明。
所以我们可以先把这个转移矩阵预处理出一个倍增数组。到时候从最下面开始往上走就行了。
对于一个询问的x,y两个端点。按照LCA的方式去跳,路上乘转移矩阵,最后再从LCA跳到1号节点就能得到dp值。
关于重载add和min注意看代码。而且为了运算安全我随手加了一个ans结构体。
#include<bits/stdc++.h>
using namespace std;
#define in read()
#define int long long
int in{
int cnt=0,f=1;char ch=0;
while(!isdigit(ch)){
ch=getchar();if(ch=='-')f=-1;
}
while(isdigit(ch)){
cnt=cnt*10+ch-48;
ch=getchar();
}return cnt*f;
}
int add(int a,int b){return ((~b)&&(~a))?(a+b):(-1);}
int min(int a,int b){return (((!~b)||(a<b))&&(~a))?a:b;}
struct node{
int a[3][3];
node(int _x=0,int _y=0,int _xx=0,int _yy=0){
a[1][1]=_x;a[1][2]=_y;
a[2][1]=_xx;a[2][2]=_yy;
}
node operator *(const node &b){
node ans;
ans.a[1][1]=min(add(a[1][1],b.a[1][1]),add(a[1][2],b.a[2][1]));
ans.a[1][2]=min(add(a[1][1],b.a[1][2]),add(a[1][2],b.a[2][2]));
ans.a[2][1]=min(add(a[2][1],b.a[1][1]),add(a[2][2],b.a[2][1]));
ans.a[2][2]=min(add(a[2][1],b.a[1][2]),add(a[2][2],b.a[2][2]));
return ans;
}
};
struct bili{
int x,y;
bili(int _x=0,int _y=0){
x=_x;y=_y;
}
bili operator *(const node &b){
return bili(min(add(x,b.a[1][1]),add(y,b.a[2][1])),min(add(x,b.a[1][2]),add(y,b.a[2][2])));
}
};
int dp[100003][2],fa[100003][22],dep[100003];
node mat[100003][22];
int n,m,first[100003],nxt[200003],to[200003],tot;
int p[100003];char ch[5];
void ad(int a,int b){
nxt[++tot]=first[a];first[a]=tot;to[tot]=b;
}
void dfs1(int u,int faa){
dp[u][1]=p[u];dep[u]=dep[faa]+1;
for(int i=1;i<=19;i++)fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=first[u];i;i=nxt[i]){
int v=to[i];if(v==faa)continue;
fa[v][0]=u;dfs1(v,u);
dp[u][0]+=dp[v][1];
dp[u][1]+=min(dp[v][0],dp[v][1]);
}
}
void dfs2(int u,int faa){
for(int i=1;i<=19;i++)mat[u][i]=mat[u][i-1]*mat[fa[u][i-1]][i-1];
for(int i=first[u];i;i=nxt[i]){
int v=to[i];if(v==faa)continue;
mat[v][0]=node(-1,dp[u][1]-min(dp[v][0],dp[v][1]),dp[u][0]-dp[v][1],dp[u][1]-min(dp[v][0],dp[v][1]));
dfs2(v,u);
}
}
int solve(int x,int a,int y,int b){
if(dep[x]<dep[y])swap(x,y),swap(a,b);
bili L=bili(dp[x][0],dp[x][1]),R=bili(dp[y][0],dp[y][1]);
if(a)L.x=-1;else L.y=-1;
if(b)R.x=-1;else R.y=-1;
for(int i=19;i>=0;i--){
if(dep[fa[x][i]]>=dep[y]){
L=L*mat[x][i];x=fa[x][i];
}
}
int st;bili ans;
if(x!=y){
for(int i=19;i>=0;i--){
if(fa[x][i]!=fa[y][i]){
L=L*mat[x][i];R=R*mat[y][i];
x=fa[x][i];y=fa[y][i];
}
}
st=fa[x][0];
ans=bili(add(dp[st][0]-dp[x][1]-dp[y][1],add(L.y,R.y)),add(dp[st][1]-min(dp[x][1],dp[x][0])-min(dp[y][0],dp[y][1]),add(min(L.x,L.y),min(R.x,R.y))));
}else{
st=x;ans=L;if(b)ans.x=-1;else ans.y=-1;
}
for(int i=19;i>=0;i--){
if(dep[fa[st][i]]>=1){
ans=ans*mat[st][i];st=fa[st][i];
}
}return min(ans.x,ans.y);
}
signed main(){
n=in;m=in;scanf("%s",ch);
for(int i=1;i<=n;i++)p[i]=in;
for(int i=1;i<n;i++){int a=in;int b=in;ad(a,b);ad(b,a);}
dfs1(1,0);dfs2(1,0);int aa,xx,bb,yy;
//for(int i=1;i<=n;i++)cout<<dp[i][0]<<" "<<dp[i][1]<<endl;
while(m--){
aa=in;xx=in;bb=in;yy=in;
cout<<solve(aa,xx,bb,yy)<<'\n';
}
return 0;
}