题目链接
题意:给定一个
n
n
n(
n
≤
1
e
5
n\leq1e5
n≤1e5)节点,以1为节点的树。每个节点都有价值为w[i]的蝴蝶。从点1开始抓蝴蝶,每1秒可以移动相邻的1个点。
每次到达1个点后,与其相邻的点上的蝴蝶就会手到惊扰,在 t [ u ] ≤ 3 t[u]\leq3 t[u]≤3时间内会飞走,问能抓到的蝴蝶的最大价值是多少
题解:
我们分析一下时间,
t
[
u
]
≤
3
t[u]\leq3
t[u]≤3,也就是我可以从一个点到达一个点,然后不往下走,到达另一个
t
[
u
]
≤
3
t[u]\leq3
t[u]≤3的点,如下所示,假设
t
[
4
]
=
3
t[4]=3
t[4]=3,我们可以选择下面的路线,对于时间小于3的点,不能这样。
这样我们发现,我们有两种转移方式:
1:一直往下走,没有上述的转弯。
2:其中某几层有上述的转弯。
这样,我们假设 f [ u ] f[u] f[u]表示到达某个点之前蝴蝶已经飞走,但是其他的蝴蝶未被惊动,sum[u]表示节点u的子节点 f f f的集合。
这样我们就可以写出转移方程:
1:
f
[
u
]
=
m
a
x
(
s
u
m
[
u
]
+
a
[
v
]
)
f[u]=max(sum[u]+a[v])
f[u]=max(sum[u]+a[v])(v是u的子节点),当我们到达u之后,其子节点都被惊扰了,所以都会飞走,这样他们飞走之后就变成了一个一个的
f
[
v
]
f[v]
f[v],我们肯定不会让他们都飞走,所以我们选择一个最大的,加上就行。
2:我们先到一个普通节点,然后再回到另一个
t
[
u
]
=
3
t[u]=3
t[u]=3的节点,这样才有意义。因为我最多只能选择两个节点,这是不确定的,所以这暴力即可。
可以将所有的 t [ u ] = 3 t[u]=3 t[u]=3的节点放到一个set里面,然后暴力u的所有节点,然后返回到价值最大的节点即可,但是要注意的是我们第一次到达的那个节点因为我们还要回去,所以不能用f[v]来表示,应该是 s u m [ u ] − f [ v ] + w [ v ] + s u m [ v ] + ∗ s t . r b e g i n ( ) sum[u]-f[v]+w[v]+sum[v]+*st.rbegin() sum[u]−f[v]+w[v]+sum[v]+∗st.rbegin()。
这里可能会疑惑为什么我们选择set里最大的点,这里是因为我们不管选择哪个点,所有的点都会变成f[v]的情况,只不过这里有一个先走,后走的问题。我选谁,肯定是要先走谁。所以这个题其实每个子树就像是一个独立的,走完一个再走一个。
下面是AC代码:
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
#include<set>
#include<vector>
using namespace std;
#define int long long
const int N=1e6+10;
int sum[N],w[N],t[N],f[N];
vector<int> vec[N];
void dfs(int u,int fa)
{
int mx=0;
multiset<int> se;
for(int v:vec[u])
{
if(v==fa) continue;
dfs(v,u);
sum[u]+=f[v];//求sum
mx=max(mx,w[v]);
if(t[v]==3) se.insert(w[v]);//将t=3的点全部放进来
}
f[u]=sum[u]+mx;
se.insert(-0x3f3f3f3f);//因为下面要用到se.rbegin(),如果里面没有元素就会rte,所以我们要给他设置一个值,假设没有这种情况的话是最小值。
for(int v:vec[u])
{
if(v==fa) continue;
if(t[v]==3) se.erase(se.find(w[v]));//删去那个点
f[u]=max(f[u],sum[u]-f[v]+w[v]+sum[v]+*se.rbegin());//计算最大值
if(t[v]==3) se.insert(w[v]);//记得填回去
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
multiset<int> se;
int Case;
cin>>Case;
while(Case--)
{
int n;
cin>>n;
for(int i=1;i<=n;i++) f[i]=sum[i]=0,vec[i].clear();
for(int i=1;i<=n;i++) cin>>w[i];
for(int i=1;i<=n;i++) cin>>t[i];
for(int i=1;i<=n-1;++i)
{
int u,v;
cin>>u>>v;
vec[u].push_back(v);
vec[v].push_back(u);
}
dfs(1,0);
cout<<f[1]+w[1]<<endl;
}
return 0;
}