考场上能想出一道树形DP,感觉很爽,本来OI时期树形DP只会点树形背包什么的,这里完全是临时脑补出一种模型,美滋滋。
一棵树上,从任意一个点到任意另外一个点利润为价格差-路费,即w[v]-w[u]-dis[u][v]最大。于是对于一个子树的根节点,假设它的子树们内部的情况已经跟新完ans了,于是不同子树之间的情况,吧w[v]-w[u]-dis[u][v]拆成,w[v]-dis[v][lca]和w[u]+dis[u][lca]的差值,要使差值最大,即前面部分最大,后面部分最小。
由于最大最小可能来自同一棵子树,于是设立second的最大和最小,并通过子树的编号判断是否是来自同一棵子树,不然就用second来处理。
具体看代码。
做了这道题后突然想起很多年前树的最长链也可以这样做,记录一个最大和不严格次大就行了。
#include<cstdio>
#include<cstring>
#include<vector>
#define maxl 100010
#define inf 2000000001
using namespace std;
int n,cnt,ans;
int ehead[maxl],w[maxl];
struct ed{int nxt,to,l;} e[maxl<<1];
struct node
{
int maxval,maxvalind,secmaxval,secmaxvalind;
int minval,minvalind,secminval,secminvalind;
int fal;
}a[maxl];
vector <int> e2[maxl];
bool vis[maxl];
inline int read()
{
int x=0;char ch=getchar();
while(ch<'0' || ch>'9') ch=getchar();
while(ch>='0'&& ch<='9') x=x*10+ch-'0',ch=getchar();
return x;
}
void dfs(int u)
{
int v;vis[u]=true;
for(int i=ehead[u];i;i=e[i].nxt)
{
v=e[i].to;
if(!vis[v])
{
e2[u].push_back(v);
a[v].fal=e[i].l;
dfs(v);
}
}
}
void prework()
{
memset(ehead,0,sizeof(ehead));
memset(vis,false,sizeof(vis));
n=read();
for(int i=1;i<=n;i++)
w[i]=read();
int u,v,l;
cnt=0;
for(int i=1;i<=n-1;i++)
{
u=read();v=read();l=read();
e[++cnt].to=v;e[cnt].nxt=ehead[u];e[cnt].l=l;ehead[u]=cnt;
e[++cnt].to=u;e[cnt].nxt=ehead[v];e[cnt].l=l;ehead[v]=cnt;
}
for(int i=1;i<=n;i++)
{
e2[i].clear();
a[i].maxval=-inf;a[i].secmaxval=-inf;
a[i].minval=inf;a[i].secminval=inf;
}
a[1].fal=0;
dfs(1);
}
void dp(int u)
{
int maval,mival,v,l=e2[u].size(),d;
if(!l)
{
a[u].maxval=w[u];a[u].maxvalind=u;
a[u].minval=w[u];a[u].minvalind=u;
return;
}
a[u].maxval=w[u];a[u].maxvalind=u;
a[u].minval=w[u];a[u].minvalind=u;
for(int i=0;i<l;i++)
{
v=e2[u][i];
dp(v);
maval=a[v].maxval-a[v].fal;
if(maval>a[u].maxval)
{
a[u].secmaxval=a[u].maxval;a[u].secmaxvalind=a[u].maxvalind;
a[u].maxval=maval;a[u].maxvalind=v;
}
else
if(maval>a[u].secmaxval)
a[u].secmaxval=maval,a[u].secmaxvalind=v;
mival=a[v].minval+a[v].fal;
if(mival<a[u].minval)
{
a[u].secminval=a[u].minval;a[u].secminvalind=a[u].minvalind;
a[u].minval=mival;a[u].minvalind=v;
}
else
if(mival<a[u].secminval)
a[u].secminval=mival,a[u].secminvalind=v;
}
for(int i=0;i<l;i++)
{
v=e2[u][i];
if(a[u].maxvalind==v)
d=a[u].secmaxval-(a[v].minval+a[v].fal);
else
d=a[u].maxval-(a[v].minval+a[v].fal);
if(d>ans)
ans=d;
if(a[u].minvalind==v)
d=a[v].maxval-a[v].fal-a[u].secminval;
else
d=a[v].maxval-a[v].fal-a[u].minval;
if(d>ans)
ans=d;
}
}
void mainwork()
{
ans=0;
dp(1);
}
void print()
{
printf("%d\n",ans);
}
int main()
{
int t;
t=read();
for(int i=1;i<=t;i++)
{
prework();
mainwork();
print();
}
return 0;
}