LOJ题目传送门
题目描述
在逃亡者的面前有一个迷宫,这个迷宫由 n个房间和 n−1 条双向走廊构成,每条走廊会链接不同的两个房间,所有的房间都可以通过走廊互相到达。换句话说,这是一棵树。
逃亡者会选择一个房间进入迷宫,走过若干条走廊并走出迷宫,但他永远不会走重复的走廊。
在第 i个房间里,有 Fi个铁球,每当一个人经过这个房间时,他就会受到铁球的阻挡。逃亡者手里有 V 个磁铁,当他到达一个房间时,他可以选择丢下一个磁铁(也可以不丢),将与这个房间相邻的所有房间里的铁球吸引到这个房间。这个过程如下:
逃亡者进入房间。
逃亡者丢下磁铁。
逃亡者走出房间。
铁球被吸引到这个房间。
注意逃亡者只会受到这个房间原有的铁球的阻拦,而不会受到被吸引的铁球的阻挡。
在逃亡者走出迷宫后,追逐者将会沿着逃亡者走过的路径穿过迷宫,他会碰到这条路径上所有的铁球。
请帮助逃亡者选择一条路径,使得追逐者遇到的铁球数量减去逃亡者遇到的铁球数量最大化。
输入输出格式
输入格式:
第一行两个空格隔开的整数整数 n 和 V 。 第二行 n个空格隔开的整数表示 Fi 。 之后的 n−1 行,每行两个空格隔开的整数 x 和 y ,表示有一条走廊连接编号为 x 和编号为 y 的房间。
输出格式:
输出一个整数表示最优情况下追逐者遇到的铁球数量减去逃亡者遇到的铁球数量。
输入输出样例
输入样例#1:
12 2
2 3 3 8 1 5 6 7 8 3 5 4
2 1
2 7
3 4
4 7
7 6
5 6
6 8
6 9
7 10
10 11
10 12
输出样例#1:
36
说明
样例解释
有一个最优方案如下:
从6号房间进入迷宫并丢下第一个磁铁,他遇到了5个铁球,这个时候6号房间会有27个铁球,而5号,7号,8号,9号房间都没有铁球。
走到7号房间丢下第二个磁铁并走出迷宫,他遇到了0个铁球,这个时候7号房间会有41个铁球,而2号,4号,6号,10号房间会没有铁球。
在这个过程中,逃亡者会遇到5个铁球而追逐者会遇到 41个铁球。
数据范围
对于
100
100%
100的数据,
有
1
≤
n
≤
1
0
5
;
0
≤
V
≤
100
;
0
≤
F
i
≤
1
0
9
,
1
≤
n
≤
1
0
5
;
。
1≤n≤10^5;0≤V≤100;0≤Fi≤10^9,1\le n\le 10^5;。
1≤n≤105;0≤V≤100;0≤Fi≤109,1≤n≤105;。
首先非常显然的就是,这些房间连成一棵树,而且是一棵无根树。假定点 x x x是根,那么逃亡者的路径就是从 x x x的子树里选一个点,往这个点走。然后重复上述过程,直到他选择一个点离开迷宫。
所以一个显然的树形DP就是枚举每一个点做根,然后以这个根向下DP,求出扔下 v v v个磁铁所能获得的最大铁球差异值。
当然我们需要明确,铁球的差异是如何产生的。当我们在树上 x x x点扔下一个磁铁,那么差异是这个点的所有儿子的权值和(为什么没有它自己和它父亲?因为这些点都是先走的,或者这些点被之前的点吸引了,并不可能增加当前点的贡献)。
所以这个树形DP很简单,然而复杂度是过不去的。只能拿70分。
#pragma GCC optimize(3)
#include<bits/stdc++.h>
#define MAXN 100005
#define ll long long
using namespace std;
char tc(){
static char fl[100000],*A=fl,*B=fl;
return A==B&&(B=(A=fl)+fread(fl,1,100000,stdin),A==B)?EOF:*A++;
}
ll read(){
char c;ll x=0,y=1;while(c=tc(),(c<'0'||c>'9')&&c!='-');
if(c=='-') y=-1;else x=c-'0';while(c=tc(),c>='0'&&c<='9')
x=x*10+c-'0';return x;
}
ll ans,f[MAXN][105];
int n,v,cnt,val[MAXN],head[MAXN<<1],nxt[MAXN<<1],go[MAXN<<1],vis[MAXN];
void add(ll x,ll y){
go[cnt]=y;nxt[cnt]=head[x];head[x]=cnt;cnt++;
go[cnt]=x;nxt[cnt]=head[y];head[y]=cnt;cnt++;
}
ll rise(int x,int v){
if(v==0) return 0;
if(f[x][v]) return f[x][v];
ll res=0,out=0;vis[x]=1;
register int i;
for(i=head[x];i!=-1;i=nxt[i]){
int to=go[i];
if(vis[to]) continue;
res+=val[to];
}
for(i=head[x];i!=-1;i=nxt[i]){
int to=go[i];
if(vis[to]) continue;
out=max(out,res+rise(to,v-1));
}
for(i=head[x];i!=-1;i=nxt[i]){
int to=go[i];
if(vis[to]) continue;
out=max(out,rise(to,v));
}
vis[x]=0;f[x][v]=out;
return f[x][v];
}
int main()
{
n=read();v=read();
register int i,j,k;
memset(head,-1,sizeof(head));
for(i=1;i<=n;i++) val[i]=read();
for(i=1;i<n;i++){
ll x=read(),y=read();
add(x,y);
}
if(n<=1000){ //这个是根据数据特判
for(i=1;i<=n;i++){
for(j=1;j<=n;j++) vis[j]=0;
for(j=1;j<=n;j++)
for(k=1;k<=v;k++) f[j][k]=0;
ans=max(ans,rise(i,v));
}
printf("%lld\n",ans);
return 0;
}
printf("%lld",rise(1,v)); //由于保证有30分从1开始,所以直接以1为根做
return 0;
}
那么正解是什么呢?
既然我们要以每一个点为根,那么我们可以考虑定1为假根,然后用一种类似点分治的思路,去计算它的儿子的答案。(实际上就是以儿子为根的答案),我们设up[x][i]表示从某点出发,扔出i个磁铁,最终到达x的最大差异值,而down[x][i]表示从x出发,扔下i个磁铁所能获得的最大差异值。
那么我们显然可以写出转移:
for(int j=1;j<=v;j++) up[x][j]=max(up[x][j],up[to][j]),down[x][j]=max(down[x][j],down[to][j]);
for(int j=1;j<=v;j++){
up[x][j]=max(up[x][j],up[to][j-1]+deg[x]-val[to]); //up从下往上,所以应该减去下面的点。
down[x][j]=max(down[x][j],down[to][j-1]+deg[x]-val[fa[x]]); //down从上往下,所以减去父亲
}
但是我们要去重,因为如果不去重,可能up和down走的是同一个儿子,那么这样就违反了题目每条边只能走一次的规定。于是我们在循环的时候先更新答案,再更新up和down,代码如下:
for(int i=head[x];i!=-1;i=nxt[i]){
int to=go[i];
if(fa[x]==to) continue;sta[++top]=to; //sta这句话和目前无关,可以无视
for(int j=1;j<v;j++) ans=max(ans,up[x][j]+down[to][v-j]); //先更新,就能保证不会重复
for(int j=1;j<=v;j++) up[x][j]=max(up[x][j],up[to][j]),down[x][j]=max(down[x][j],down[to][j]);
for(int j=1;j<=v;j++){
up[x][j]=max(up[x][j],up[to][j-1]+deg[x]-val[to]);
down[x][j]=max(down[x][j],down[to][j-1]+deg[x]-val[fa[x]]);
}
}
ans=max(ans,max(down[x][v],up[x][v]));
但是我们还得考虑一个问题,就是比如我们从 u u u走上 x x x,再从 x x x走下 v v v,这样的答案和从 v v v走上 x x x,再从 x x x走下 u u u的大答案是不一样的。而我们上面的统计由于为了去重,就会导致我们匹配up和down时up走的子树肯定在down左边。但如果在右边怎么办?我们把 x x x的子树顺序前后翻转一下,就可以使匹配时up走的子树在down右边。这样操作就可以统计所有的答案。
我们在代码里用一个栈来操作。上文的sta就是将所有儿子放在栈里。最后统计右边的时候就倒着枚举栈即可。
while(top){
int to=sta[top--];
for(int j=1;j<v;j++) ans=max(ans,up[x][j]+down[to][v-j]);
for(int j=1;j<=v;j++) up[x][j]=max(up[x][j],up[to][j]),down[x][j]=max(down[x][j],down[to][j]);
for(int j=1;j<=v;j++){
up[x][j]=max(up[x][j],up[to][j-1]+deg[x]-val[to]);
down[x][j]=max(down[x][j],down[to][j-1]+deg[x]-val[fa[x]]);
}
}
ans=max(ans,max(down[x][v],up[x][v]));
下面是完整的代码:
#include<bits/stdc++.h>
#define MAXN 100005
#define ll long long
using namespace std;
int read(){
char c;int x;while(c=getchar(),c<'0'||c>'9');x=c-'0';
while(c=getchar(),c>='0'&&c<='9') x=x*10+c-'0';return x;
}
int n,v,cnt,val[MAXN],head[MAXN<<1],nxt[MAXN<<1],go[MAXN<<1],fa[MAXN],sta[MAXN],top;
ll up[MAXN][105],down[MAXN][105],ans,deg[MAXN];
void add(int x,int y){
go[cnt]=y;nxt[cnt]=head[x];head[x]=cnt;cnt++;
go[cnt]=x;nxt[cnt]=head[y];head[y]=cnt;cnt++;
}
void dfs(int x){
for(int i=1;i<=v;i++) up[x][i]=deg[x],down[x][i]=deg[x]-val[fa[x]];
for(int i=head[x];i!=-1;i=nxt[i]){
int to=go[i];
if(fa[x]==to) continue;
fa[to]=x;dfs(to);
}
for(int i=head[x];i!=-1;i=nxt[i]){
int to=go[i];
if(fa[x]==to) continue;sta[++top]=to;
for(int j=1;j<v;j++) ans=max(ans,up[x][j]+down[to][v-j]);
for(int j=1;j<=v;j++) up[x][j]=max(up[x][j],up[to][j]),down[x][j]=max(down[x][j],down[to][j]);
for(int j=1;j<=v;j++){
up[x][j]=max(up[x][j],up[to][j-1]+deg[x]-val[to]);
down[x][j]=max(down[x][j],down[to][j-1]+deg[x]-val[fa[x]]);
}
}
ans=max(ans,max(down[x][v],up[x][v]));
for(int i=1;i<=v;i++) up[x][i]=deg[x],down[x][i]=deg[x]-val[fa[x]];
while(top){
int to=sta[top--];
for(int j=1;j<v;j++) ans=max(ans,up[x][j]+down[to][v-j]);
for(int j=1;j<=v;j++) up[x][j]=max(up[x][j],up[to][j]),down[x][j]=max(down[x][j],down[to][j]);
for(int j=1;j<=v;j++){
up[x][j]=max(up[x][j],up[to][j-1]+deg[x]-val[to]);
down[x][j]=max(down[x][j],down[to][j-1]+deg[x]-val[fa[x]]);
}
}
ans=max(ans,max(down[x][v],up[x][v]));
}
int main()
{
n=read();v=read();
memset(head,-1,sizeof(head));
for(int i=1;i<=n;i++) val[i]=read();
for(int i=1;i<n;i++){
int x=read(),y=read();
add(x,y);deg[x]+=val[y];deg[y]+=val[x];
}
dfs(1);
printf("%lld",ans);
return 0;
}