Description
给出一棵大小为n的树,以及m条祖先后代链,选择第i条边会付出代价ci,求选择代价最小的边使得覆盖整棵树。
n<=3*1e5
Solution
“愿春死樱花下,释迦入灭日。後人悼我,当奉佛樱花。”
额原谅我中二了,不过只是喜欢上一只忘却了过去的亡灵而已。。。。
栋栋搬的好题。。。原题CF671D
栋栋有一种将原问题对偶之后的贪心做法,然而我不会(其实是懒得看)。。。
考虑最简单的Dp,设F[i]表示i还能再往上伸,i的子树中的答案。
那么每一条祖先后代链的链底都会对链上除了链顶的每个节点贡献,贡献还要加上这条路径上其他点的其他儿子的f值。
考虑用线段树维护这个东西,设sum[x]表示x所有儿子的f值和,那么贡献可以写成∑sum[x]-f[x]的形式
用线段树维护dfs序,在每个链底维护这个值,每次从x回溯时对整个子树中可行的链底进行贡献
这个可以用区间加,然后把不可行的赋值成+∞来实现
然后用一个东西来维护每条可能贡献的链就好了。。。
由于比较懒写了vector。。。
Code
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define rep(i,a) for(int i=last[a];i;i=next[i])
using namespace std;
typedef long long ll;
int read() {
char ch;
for(ch=getchar();ch<'0'||ch>'9';ch=getchar());
int x=ch-'0';
for(ch=getchar();ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x;
}
const int N=3*1e5+5;
const ll inf=1e15;
int n,m,x,y,tot,dfn[N],size[N],w[N],c[N],d[N];
ll f[N],sum[N],tr[N*4],lazy[N*4];
void build(int v,int l,int r) {
tr[v]=inf;
if (l==r) return;
int m=(l+r)/2;
build(v*2,l,m);
build(v*2+1,m+1,r);
}
void back(int v,ll z) {
tr[v]+=z;lazy[v]+=z;
}
void down(int v) {
if (lazy[v]) {
back(v*2,lazy[v]);
back(v*2+1,lazy[v]);
lazy[v]=0;
}
}
ll query(int v,int l,int r,int x,int y) {
if (l==x&&r==y) return tr[v];
int m=(l+r)/2;down(v);
if (y<=m) return query(v*2,l,m,x,y);
else if (x>m) return query(v*2+1,m+1,r,x,y);
else return min(query(v*2,l,m,x,m),query(v*2+1,m+1,r,m+1,y));
}
void modify(int v,int l,int r,int x,int y,ll z) {
if (l==x&&r==y) {back(v,z);return;}
int m=(l+r)/2;down(v);
if (y<=m) modify(v*2,l,m,x,y,z);
else if (x>m) modify(v*2+1,m+1,r,x,y,z);
else modify(v*2,l,m,x,m,z),modify(v*2+1,m+1,r,m+1,y,z);
tr[v]=min(tr[v*2],tr[v*2+1]);
}
void change(int v,int l,int r,int x,ll z) {
if (l==r) {tr[v]=z;return;}
int m=(l+r)/2;down(v);
if (x<=m) change(v*2,l,m,x,z);
else if (x>m) change(v*2+1,m+1,r,x,z);
tr[v]=min(tr[v*2],tr[v*2+1]);
}
typedef vector<int> vec;
#define pb(a) push_back(a)
vec in[N],out[N];
int point[N];
bool cmp(int x,int y) {return dfn[d[x]]>dfn[d[y]]||dfn[d[x]]==dfn[d[y]]&&c[x]<c[y];}
int last[N],next[N*2],t[N*2],l;
void add(int x,int y) {
t[++l]=y;next[l]=last[x];last[x]=l;
}
void dfs(int x,int y) {
dfn[x]=++tot;size[x]=1;w[tot]=x;
rep(i,x) if (t[i]!=y) dfs(t[i],x),size[x]+=size[t[i]];
}
void dp(int x,int y) {
f[x]=inf;
rep(i,x) if (t[i]!=y) dp(t[i],x),sum[x]+=f[t[i]];
if (f[1]==-1) return;
if (!in[x].empty()) {
int now=in[x][0];
change(1,1,n,dfn[d[now]],c[now]);
}
if (!out[x].empty()&&x!=1)
fo(i,0,out[x].size()-1) {
int now=out[x][i];
if (point[d[now]]==in[d[now]].size()-1) change(1,1,n,dfn[d[now]],inf);
else {
int suf=in[d[now]][++point[d[now]]];
modify(1,1,n,dfn[d[now]],dfn[d[now]],-c[now]);
modify(1,1,n,dfn[d[suf]],dfn[d[suf]],c[suf]);
}
}
f[x]=query(1,1,n,dfn[x],dfn[x]+size[x]-1)+sum[x];
modify(1,1,n,dfn[x],dfn[x]+size[x]-1,sum[x]-f[x]);
if (f[x]>inf) {f[1]=-1;return;}
}
int main() {
freopen("youmu.in","r",stdin);
freopen("youmu.out","w",stdout);
n=read();m=read();
fo(i,1,n-1) {
x=read();y=read();
add(x,y);add(y,x);
}
fo(i,1,m) {
x=read();y=read();
in[x].pb(i);out[y].pb(i);
c[i]=read();d[i]=x;
}
dfs(1,0);
fo(i,1,n) sort(in[i].begin(),in[i].end(),cmp);
build(1,1,n);dp(1,0);
printf("%lld\n",f[1]);
}