题意:
给定一个树形的游戏网络,可以从根节点出发K次,每次沿着一条路径走下去,不能回头,出口在各个叶子节点,在路过一个节点时可以 获得该点的权值,每个点的权值只能被获得一次(获得一次之后该点的权值变为0),问K次怎样走最后可以获得的权值总和最大,求最大值。
解析:
首先预处理出dfs序,然后求出每个dfs序的权值,以及每个节点在这个dfs序中的子树的区间。
然后用线段树维护最大的链上,每个节点的dfs序区间。
区间更新,减去当前子节点的权值。
AC代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define ls (o<<1)
#define rs (o<<1|1)
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const int MAXN = 100005;
//用于记录每个节点dfs序的区间
struct Node {
int L, R;
}node[MAXN];
//------------------------------------------
int pos[MAXN*2]; //记录欧拉序
ll sumv[MAXN*2]; //路径权值之和
ll power[MAXN]; //权值
int fa[MAXN];
int total; //记录位置
vector<int> edge[MAXN];
void init(int n) {
for(int i = 0; i <= n; i++) {
edge[i].clear();
fa[i] = 0;
}
}
void addEdge(int u, int v) {
edge[u].push_back(v);
}
void dfs(int u, int pre, ll sum) {
node[u].L = ++total;
pos[total] = u;
sumv[total] = sum;
for(int i = 0; i < edge[u].size(); i++) {
int v = edge[u][i];
if(v == pre) continue;
dfs(v, u, sum+power[v]);
}
node[u].R = total;
}
//------------------------------------------
//------------------------------------------
ll maxv[MAXN<<2]; //线段树维护最大权值
ll addv[MAXN<<2]; //区间更新的懒惰标记
ll _maxp, _maxv;
void pushDown(int o) {
if(addv[o]) {
addv[ls] += addv[o]; maxv[ls] += addv[o];
addv[rs] += addv[o]; maxv[rs] += addv[o];
addv[o] = 0;
}
}
void pushUp(int o) {
maxv[o] = max(maxv[ls], maxv[rs]);
}
void build(int o, int L, int R) {
addv[o] = 0;
if(L == R) {
maxv[o] = sumv[L];
return ;
}
int M = (L+R)/2;
build(ls, L, M);
build(rs, M+1, R);
pushUp(o);
}
int ql, qr;
ll val;
void modify(int o, int L, int R) {
if(ql <= L && R <= qr) {
addv[o] += val;
maxv[o] += val;
return ;
}
int M = (L+R)/2;
pushDown(o);
if(ql <= M) modify(ls, L, M);
if(qr > M) modify(rs, M+1, R);
pushUp(o);
}
void query(int o, int L, int R) {
if(L == R) {
_maxv = maxv[o];
_maxp = pos[L];
return ;
}
pushDown(o);
int M = (L+R)/2;
if(maxv[ls] > maxv[rs]) query(ls, L, M);
else query(rs, M+1, R);
}
//------------------------------------------
int n, m;
int main() {
int T, cas = 1;
scanf("%d", &T);
while(T--) {
scanf("%d%d", &n,&m);
init(n);
for(int i = 1; i <= n; i++)
scanf("%lld", &power[i]);
int u, v;
for(int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
addEdge(u, v);
fa[v] = u;
}
total = 0;
dfs(1, -1, power[1]);
build(1, 1, n);
ll ans = 0, cur;
while(m--) {
_maxv = -INF;
ql = 1, qr = n;
query(1, 1, n);
ans += _maxv; cur = _maxp;
while(cur != 0) {
if(sumv[cur] <= 0) break;
ql = node[cur].L, qr = node[cur].R;
val = -power[cur];
modify(1, 1, n);
sumv[cur] = 0;
cur = fa[cur];
}
}
printf("Case #%d: %lld\n", cas++, ans);
}
return 0;
}