题目传送门
题意: 有一颗树,你在根节点1,从1出发经过每个点恰好一次,经过每条边的时间为1,每个点有一个权值a[i],每个点在第一次经过的时候就开始计时,a[i]秒之后结束。特殊的:点1的计时是最后开始。问:你最短需要多长时间让所有点都计时结束。
思路: f [ x ] f[x] f[x]表示以x为根节点的子树需要的最短时间。假设y是x的一颗子树,那么就有:
- f [ x ] = a [ x ] f[x]=a[x] f[x]=a[x]
- f [ x ] = m a x ( f [ x ] , f [ y ] + s i z [ x ] y + 1 ) f[x]=max(f[x],f[y]+siz[x]_y+1) f[x]=max(f[x],f[y]+siz[x]y+1)(1表示从x到y)
s i z [ x ] y siz[x]_y siz[x]y表示走到y之前,在x的子树内,走路花费了的时间。
但是这样的话,就需要考虑一个子树顺序问题,因为顺序不同可能导致更新出来的 f [ x ] f[x] f[x]不同。
我们假设有 y y y, z z z两颗子树,先遍历 y y y子树时:
- f [ x ] = m a x ( f [ x ] , s i z [ x ] y + m a x ( f [ y ] , f [ z ] + s i z [ y ] + 2 ) + 1 ) f[x]=max(f[x],siz[x]_y+ max(f[y],f[z]+siz[y]+2)+1) f[x]=max(f[x],siz[x]y+max(f[y],f[z]+siz[y]+2)+1)
如果先遍历 z z z子树:
- f [ x ] = m a x ( f [ x ] , s i z [ x ] y + m a x ( f [ z ] , f [ y ] + s i z [ z ] + 2 ) + 1 ) f[x]=max(f[x],siz[x]_y+max(f[z],f[y]+siz[z]+2)+1) f[x]=max(f[x],siz[x]y+max(f[z],f[y]+siz[z]+2)+1)
那么如果交换比不交换更优,则有:
- m a x ( f [ y ] , f [ z ] + s i z [ y ] ) > m a x ( f [ z ] , f [ y ] + s i z [ z ] ) max(f[y],f[z]+siz[y])>max(f[z],f[y]+siz[z]) max(f[y],f[z]+siz[y])>max(f[z],f[y]+siz[z])
显然,
f
[
y
]
<
f
[
y
]
+
s
i
z
[
z
]
,
f
[
z
]
<
f
[
z
]
+
s
i
z
[
y
]
f[y]<f[y]+siz[z],f[z]<f[z]+siz[y]
f[y]<f[y]+siz[z],f[z]<f[z]+siz[y]
即有,
f
[
z
]
+
s
i
z
[
y
]
>
f
[
y
]
+
s
i
z
[
z
]
f[z]+siz[y]>f[y]+siz[z]
f[z]+siz[y]>f[y]+siz[z]
即,
s
i
z
[
z
]
−
f
[
z
]
<
s
i
z
[
y
]
−
f
[
y
]
siz[z]-f[z]<siz[y]-f[y]
siz[z]−f[z]<siz[y]−f[y]
将这个作为依据对子树排序,再对根节点进行更新即可。
(注意处理子树根节点是1的情况)
注意: 考虑状态转移方程中,为何是 f [ x ] = m a x ( f [ x ] , f [ y ] + s i z [ x ] y + 1 ) f[x]=max(f[x],f[y]+siz[x]_y+1) f[x]=max(f[x],f[y]+siz[x]y+1)而不是 f [ x ] = m a x ( f [ x ] , f [ y ] + s i z [ x ] y + 2 ) f[x]=max(f[x],f[y]+siz[x]_y+2) f[x]=max(f[x],f[y]+siz[x]y+2) ?我们走完之前的子树,现在要走y子树,按道理应该是 s i z [ x ] y + f [ y ] + 2 siz[x]_y+f[y]+2 siz[x]y+f[y]+2,(2表示从x到y再从y到x)。其实,我们知道,f[y]>siz[y](无特殊情况),即在y子树内的走路时间一定是小于在y子树计时总时间的,那么从y走到x的时间其实是已经被包含在 f [ y ] f[y] f[y]里面了。
代码:
#include<bits/stdc++.h>
#define endl '\n'
#define mp make_pair
#define pb push_back
#define ll long long
#define int long long
#define pii pair<int,int>
#define sz(x) (int)(x).size()
#define all(x) (x).begin(),(x).end()
#define mem(a,b) memset(a,b,sizeof(a))
char *fs,*ft,buf[1<<20];
#define gc() (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<20,stdin),fs==ft))?0:*fs++;
inline int read()
{
int x=0,f=1;
char ch=gc();
while(ch<'0'||ch>'9')
{
if(ch=='-')
f=-1;
ch=gc();
}
while(ch>='0'&&ch<='9')
{
x=x*10+ch-'0';
ch=gc();
}
return x*f;
}
using namespace std;
const int N=5e5+10;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
const double eps=1e-7;
vector<int>e[N];
int a[N],f[N],siz[N],temp[N];
bool cmp(int x,int y)
{
return siz[x]-f[x]<siz[y]-f[y];
}
void dfs(int fa,int x)
{
if(x!=1)
f[x] = a[x];
for(auto i:e[x])
if(i!=fa)
dfs(x,i);
int tot=0;
for(auto i:e[x])
if(i!=fa)
temp[++tot] = i;
sort(temp+1,temp+tot+1,cmp);
for(int i=1;i<=tot;i++)
f[x]=max(f[x],f[temp[i]]+siz[x]+1),siz[x]+=siz[temp[i]]+2;
if(x==1)
f[x] = max(f[x],siz[x]+a[x]);
}
void solve()
{
int n;
cin>>n;
for(int i=1;i<=n;i++)
cin>>a[i];
for(int i=1;i<=n-1;i++)
{
int u,v;
cin>>u>>v;
e[u].pb(v);
e[v].pb(u);
}
dfs(1,1);
cout<<f[1]<<endl;
}
signed main()
{
solve();
return 0;
}