题目
思路来源
乱搞ac
题解
对于一个u为根的子树,如果出现了这种路径,这个路径要么是(u,v),要么是(v1,v2)
如果是(u,v)的话,得去v子树里刨一个类型相同的点出来
所以,mp[v][x]表示v为子树根,存在一个点p的类型是x,p到v的路径是没用过的最大代价和,用于辅助转移
因为留某个子树不用,肯定是留一个点为了以后用,然后往上启发式合并
合并转移的话,如果用的话是用哪个,再把v往u上挂, 然后考虑u还给不给以后留
而dp[u]表示只考虑u的子树时的最优解,即最大代价和,用于更新答案
更新答案的话,考虑u下有子树v1、v2、v3,
如果v2和v3各存在一条类型是x的路径,就是用dp[v1]+mp[v2][x]+mp[v3][x]更新答案,
用数字表示dp值,用带圈的数字表示mp值,发现
(1+2+③)+(1+②+3)-(1+2+3)=1+②+③,
也就是分成三部分,前两部分是相同结构的,
v3往u上的mp上挂的时候,维护的是v3为子树根,存在一个点p的类型是x,p到v的路径是没用过的最大代价和,也就是③,再加上u下其他子树的dp值之和,也就是1+2,得到1+2+③
而第三部分为u下所有v子树的dp值之和,记为sum
稍稍化简一下,可以得到1+②+③=(1+②+3)-3+③,也就是代码中mp[u][x.fi]+x.se-dp[v]
此外,注意到u的重儿子向u转移的时候,需要全局加上所有轻儿子的dp值之和,
这是一个全局加,所以可以对子树打标记,再魔改下map上维护的值,使之带上子树标记,
于是,now[u]表示对mp[u]上所有点打的全局子树加标记,now[v]表示对mp[v]的
1. 当启发式合并,由于子树大小,需要交换u和v的时候,
此时应该给v上所有点加上sum-dp[v],
但是由于不操作v,所以要操作u上所有点减去这个值,
除此之外,v本身有标记now[v],u本身有标记now[u]
本着不操作v的原则,操作u使得带上now[v]的子树标,
所以应该先令u上所有点加上now[u],
再令u上所有点减去now[v]+sum-dp[v],再把u上点挂到v上
2. 如果不交换的话,就还是把v往u上挂,
把v往u上挂的时候,v需要加上sum-dp[v],加上now[v],再减去now[u]
用map上的值更新答案的时候,也要加上u的子树标和v的子树标,
所以是代码中的now[u]+now[v]+mp[u][x.fi]+x.se-dp[v]
这个全局加,如果用数据结构维护的话会好理解很多,
但是这个做法尚且跑了1s多(map1.7s、umap1.3s),ds感觉已经t飞了
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<map>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
#define fi first
#define se second
#define pb push_back
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
#define SZ(a) (int)(a.size())
#define sci(a) scanf("%d",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d\n",a)
#define ptlle(a) printf("%lld\n",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
namespace fastIO
{
static char buf[100000],*h=buf,*d=buf;//缓存开大可减少读入时间,看题目给的空间
#define gc h==d&&(d=(h=buf)+fread(buf,1,100000,stdin),h==d)?EOF:*h++//不能用fread则换成getchar
template<typename T>
inline void read(T&x)
{
int f = 1;x = 0;
register char c(gc);
while(c>'9'||c<'0'){
if(c == '-') f = -1;
c=gc;
}
while(c<='9'&&c>='0')x=(x<<1)+(x<<3)+(c^48),c=gc;
x *= f;
}
template<typename T>
void output(T x)
{
if(x<0){putchar('-');x=~(x-1);}
static int s[20],top=0;
while(x){s[++top]=x%10;x/=10;}
if(!top)s[++top]=0;
while(top)putchar(s[top--]+'0');
}
}
using namespace fastIO;
const int N=2e5+10;
int t,n,c[N],w[N],u,v;
vector<int>e[N];
map<int,ll>mp[N];
ll ans,dp[N],now[N];
void dfs(int u,int fa){
ll sum=0;
for(auto &v:e[u]){
if(v==fa)continue;
dfs(v,u);
sum+=dp[v];
}
ll pre=0;//pre表示对上一次对mp[u]全局加标记,now表示这一次对mp[u]的全局加标记
for(auto &v:e[u]){
if(v==fa)continue;
if(mp[u].size()<mp[v].size()){
mp[u].swap(mp[v]);
for(auto &x:mp[v]){
if(mp[u].count(x.fi))dp[u]=max(dp[u],now[u]+mp[u][x.fi]+now[v]+x.se-dp[v]);
}
pre=now[u];
now[u]=now[v]+sum-dp[v];
for(auto &x:mp[v]){
ll z=x.se+pre-now[u];
if(mp[u].count(x.fi))mp[u][x.fi]=max(mp[u][x.fi],z);
else mp[u][x.fi]=z;
}
}
else{
for(auto &x:mp[v]){
if(mp[u].count(x.fi))dp[u]=max(dp[u],now[u]+mp[u][x.fi]+now[v]+x.se-dp[v]);
}
for(auto &x:mp[v]){
ll z=x.se+sum-dp[v]-now[u]+now[v];
if(mp[u].count(x.fi))mp[u][x.fi]=max(mp[u][x.fi],z);
else mp[u][x.fi]=z;
}
}
mp[v].clear();
}
if(mp[u].count(c[u])){
dp[u]=max(dp[u],now[u]+mp[u][c[u]]+w[u]);
}
dp[u]=max(dp[u],sum);
//printf("u:%d dp:%lld\n",u,dp[u]);
ll z=sum+w[u]-now[u];
if(mp[u].count(c[u]))mp[u][c[u]]=max(mp[u][c[u]],z);
else mp[u][c[u]]=z;
}
int main(){
read(t);
while(t--){
read(n);
for(int i=1;i<=n;++i){
e[i].clear();
dp[i]=now[i]=0;
mp[i].clear();
read(c[i]);
}
for(int i=1;i<=n;++i)read(w[i]);
for(int i=2;i<=n;++i){
read(u);read(v);
e[u].pb(v);
e[v].pb(u);
}
dfs(1,0);
ptlle(dp[1]);
}
return 0;
}
/*
1
7
3 1 1 2 2 2 3
2 4 1 5 4 6 2
1 2
1 3
2 4
2 5
3 6
3 7
*/