Description
给定一棵有n个节点的树,有m个宝箱和对应的钥匙,它们可能在不同的节点上,也可能在相同的节点上,每个宝箱都有对应的权值(可为负数),现要求在树上选一条简单路径,每到一个节点时,必须先拿走该节点所有的钥匙,然后开启该节点所有能开启的宝箱,求能得到的最大权值和
Input
第一行一整数T表示用例组数,每组用例首先输入两个整数n和m分别表示点数和宝箱数,之后n-1行每行两个整数u和v表示u和v之间有一条边,最后m行每行三个整数a,b,v,表示该宝箱钥匙在a点,宝箱在b点,宝箱权值为v
(T<=30,n,m<=10^5,-1000<=v<=1000)
Output
对于每组用例,输出能够得到的最大权值和
Sample Input
2
4 2
1 2
2 3
3 4
1 1 100
2 4 -5
4 3
1 2
1 3
1 4
2 1 1
1 3 2
1 3 5
Sample Output
Case #1: 100
Case #2: 8
Solution
令subtree(i)表示以i为根节点的子树,对于某个宝箱a,b,v,令c=lca(a,b)
1.若a!=c且b!=c,那么要想得到这个宝箱的权值,起点必须在subtree(a)中,终点必须在subtree(b)中
2.若a=c,说明b是a的祖先,要想拿到此宝箱,起点不能在subtree(a)中,终点必须在subtree(b)中
3.若b=c,与a=c类似,起点不能在subtree(b)中,终点必须在subtree(b)中
4.若a=b,说明宝箱和其对应的钥匙在同一点,那么起点和终点只要不同的subtree(son)中即可(son为a的儿子节点)
考虑到每种情况起点终点都在某棵子树中任意选取,故首先求出每个节点的dfs序,用dfs序将一棵子树subtree(i)映射成一个连续的区间[l[i],r[i]],进而上面四种情况中起点和终点的选则就对应二维平面的一个矩阵,平面上点(x,y)的实际意义是走以x为起点y为终点的简单路径能够得到的权值,通过对每个宝箱的讨论可以在平面上扔若干个矩阵,最后从整个平面上找一个点权最大的点即为答案,后者用扫描线加线段树即可(即将一个矩阵看作两条横线,扫到下方横线就权值加v,扫到上方横线就权值减v表示消除影响)
注意前三种情况每种情况最多加两个矩阵,但是第四种情况由于起点终点可以在a的若干儿子子树中选取,故最多会出现n^2个矩阵,这显然不行,故可以采取总体累加,局部消除的方法,即先给加一个全平面矩阵,权值为v,然后加所有的不合法矩阵,权值为-v,不非法矩阵有四种情况:
1.起点终点都在a的某个subtree(son)中,son为a的儿子节点(累加至多n个)
2.起点终点都在suntree(left_brother)中,left_brother为a的左兄弟节点(即dfs序比l[a]小的点)(至多一个)
2.起点终点都在suntree(right_brother)中,right_brother为a的右兄弟节点(即dfs序比r[a]大的点)(至多一个)
3.起点和终点,一个在suntree(left_brother)中一个在suntree(right_brother)中(至多两个)
不合法矩阵累加不超过4*n个,这样就可以保证复杂度是O(nlogn)的了
对于上面第一种情况所加矩阵至多n个的解释:虽然每个节点的儿子节点数可以达到n-1个,但是如果将位于同一个节点处的宝箱看作一个来考虑的话,那么每个节点作为其父亲节点的儿子节点在这种情况下被加进去至多一次,故累加不超过n个,这里因为没有这么强的数据所以没有合并宝箱
Code
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
#define INF 0x3f3f3f3f
#define maxn 111111
struct node
{
int y,l,r,v;
node(){};
node(int _y,int _l,int _r,int _v)
{
y=_y,l=_l,r=_r,v=_v;
}
bool operator <(const node &b)const
{
return y<b.y;
}
}a[10*maxn];
vector<int>g[maxn];
int p[maxn][20],deep[maxn],vis[maxn];
int index,l[maxn],r[maxn];
int T,n,m,res,Case=1;
void dfs(int u)
{
l[u]=++index;
vis[u]=1;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(vis[v])continue;
deep[v]=deep[u]+1;
p[v][0]=u;
dfs(v);
}
r[u]=index;
}
int lca(int a,int b)
{
int i,j;
if(deep[a]<deep[b])swap(a,b);
for(i=0;(1<<i)<=deep[a];i++);
i--;
for(j=i;j>=0;j--)
if(deep[a]-(1<<j)>=deep[b])
a=p[a][j];
if(a==b) return a;
for(j=i;j>=0;j--)
{
if(p[a][j]!=-1&&p[a][j]!=p[b][j])
{
a=p[a][j];
b=p[b][j];
}
}
return p[a][0];
}
int find(int x,int step)
{
for(int i=0;i<20;i++)
if(step&(1<<i))
x=p[x][i];
return x;
}
void init(int n)
{
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i<=n;i++)
if(~p[i][j-1])
p[i][j]=p[p[i][j-1]][j-1];
}
void deal(int y1,int y2,int x1,int x2,int v)
{
a[res++]=node(y1,x1,x2,v);
if(y2<n)a[res++]=node(y2+1,x1,x2,-v);
}
#define ls (t<<1)
#define rs (t<<1|1)
struct Tree
{
int Max,lazy;
}tree[4*maxn];
void push_up(int t)
{
tree[t].Max=max(tree[ls].Max,tree[rs].Max);
}
void push_down(int t)
{
int lazy=tree[t].lazy;
if(!lazy)return ;
tree[ls].Max+=lazy,tree[ls].lazy+=lazy;
tree[rs].Max+=lazy,tree[rs].lazy+=lazy;
tree[t].lazy=0;
}
void build(int l,int r,int t)
{
tree[t].Max=tree[t].lazy=0;
if(l==r)return ;
int mid=(l+r)>>1;
build(l,mid,ls),build(mid+1,r,rs);
}
void update(int l,int r,int L,int R,int t,int v)
{
if(l==L&&r==R)
{
tree[t].Max+=v,tree[t].lazy+=v;
return ;
}
push_down(t);
int mid=(l+r)>>1;
if(R<=mid)update(l,mid,L,R,ls,v);
else if(L>mid)update(mid+1,r,L,R,rs,v);
else
{
update(l,mid,L,mid,ls,v);
update(mid+1,r,mid+1,R,rs,v);
}
push_up(t);
}
int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)g[i].clear();
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v),g[v].push_back(u);
}
memset(p,-1,sizeof(p));
memset(vis,0,sizeof(vis));
deep[1]=1;index=0;
dfs(1);
init(n);
res=0;
while(m--)
{
int a,b,v;
scanf("%d%d%d",&a,&b,&v);
int c=lca(a,b);
if(a==b)
{
deal(1,n,1,n,v);
for(int i=0;i<g[a].size();i++)
{
int d=g[a][i];
if(deep[d]<deep[a])continue;
deal(l[d],r[d],l[d],r[d],-v);
}
if(l[a]>1)deal(1,l[a]-1,1,l[a]-1,-v);
if(r[a]<n)deal(r[a]+1,n,r[a]+1,n,-v);
if(l[a]>1&&r[a]<n)
{
deal(1,l[a]-1,r[a]+1,n,-v);
deal(r[a]+1,n,1,l[a]-1,-v);
}
}
else if(a!=c&&b!=c)
{
deal(l[a],r[a],l[b],r[b],v);
}
else if(a==c)
{
int d=find(b,deep[b]-deep[a]-1);
if(l[d]>1)deal(1,l[d]-1,l[b],r[b],v);
if(r[d]<n)deal(r[d]+1,n,l[b],r[b],v);
}
else if(b==c)
{
int d=find(a,deep[a]-deep[b]-1);
if(l[d]>1)deal(l[a],r[a],1,l[d]-1,v);
if(r[d]<n)deal(l[a],r[a],r[d]+1,n,v);
}
}
sort(a,a+res);
int ans=-INF;
build(1,n,1);
for(int i=1,j=0;i<=n;i++)
{
for(;j<res&&a[j].y==i;j++)
update(1,n,a[j].l,a[j].r,1,a[j].v);
ans=max(ans,tree[1].Max);
}
printf("Case #%d: %d\n",Case++,ans);
}
return 0;
}