线段树合并 + 树上差分
单次线段树合并的时间复杂度：
两棵线段树的节点数之和减去合并后线段树的节点数（即两棵树中重合节点的个数）。因此所有合并的总时间不超过建树时创建的节点总数。
#include <bits/stdc++.h>
using namespace std;
// Capacity of the node pool shared by every dynamically created segment tree.
const int N = 20000005;
// Array bound for per-vertex data.
const int maxn = 100005;
// Array bound for adjacency-list edge slots (every undirected edge is stored twice).
const int maxm = 200005;

// Segment-tree node pool. Index 0 serves as the empty / absent node.
//   tree[x][0] : count at a leaf, or the largest leaf count in x's range
//   tree[x][1] : the position (type) that attains tree[x][0]
int tree[N][2];
int ls[N];  // left child of pool node x (0 = absent)
int rs[N];  // right child of pool node x (0 = absent)
int tot;    // number of pool nodes handed out so far
// Recomputes internal node x from its two children.
// tree[x][0] becomes the larger of the children's maximum counts, and
// tree[x][1] the position attaining it. On a tie the smaller position wins.
// An absent child (index 0) contributes count 0 and position 0.
// When the resulting maximum is not positive, no position actually holds
// anything, so the position is reset to 0. Without this, a tie of zeros
// between two fully populated children (every leaf present but cancelled to 0)
// would report a spurious nonzero position.
void push_up(int x){
    const int a = ls[x];
    const int b = rs[x];
    if (tree[a][0] != tree[b][0]) {
        const int win = tree[a][0] > tree[b][0] ? a : b;
        tree[x][0] = tree[win][0];
        tree[x][1] = tree[win][1];
    } else {
        tree[x][0] = tree[a][0];
        tree[x][1] = std::min(tree[a][1], tree[b][1]);
    }
    if (tree[x][0] <= 0) tree[x][1] = 0;
}
// Adds v over positions [left, right] in the subtree of pool node x.
// Node x covers [l, r]. Missing children on the way down are allocated from
// the pool (++tot), and every visited internal node is refreshed with push_up
// on the way back.
// NOTE(review): a fully covered node only gets tree[x][0] += v and
// tree[x][1] = l, with no lazy tag pushed to its children. A later push_up on
// it would drop the addition, so this is exact only for point updates
// (left == right), which is how main() calls it.
void update(int left, int right, int l, int r, int x, int v) {
    if (l >= left && r <= right) {
        tree[x][0] += v;
        tree[x][1] = l;
        return;
    }
    const int mid = (l + r) >> 1;
    if (mid >= left) {
        int &child = ls[x];
        if (child == 0) child = ++tot;
        update(left, right, l, mid, child, v);
    }
    if (mid < right) {
        int &child = rs[x];
        if (child == 0) child = ++tot;
        update(left, right, mid + 1, r, child, v);
    }
    push_up(x);
}
// Merges the segment tree rooted at y into the one rooted at x.
// Both trees cover [l, r]. Returns the root of the combined tree.
// Where both trees have a node, x's node is updated in place: leaf counts are
// added, and internal nodes are recomputed by push_up.
// Subtrees present only in y are linked in as-is (shared, not copied), so
// later merges into the result may modify nodes that used to belong to y.
int merge(int l, int r, int x, int y) {
    if (x == 0) return y;
    if (y == 0) return x;
    if (l != r) {
        const int mid = (l + r) >> 1;
        ls[x] = merge(l, mid, ls[x], ls[y]);
        rs[x] = merge(mid + 1, r, rs[x], rs[y]);
        push_up(x);
    } else {
        tree[x][0] += tree[y][0];
        tree[x][1] = l;
    }
    return x;
}
int he[maxn],ver[maxm],ne[maxm],tot2;
void add( int x,int y ){
ver[++tot2] = y;
ne[tot2] = he[x];
he[x] = tot2;
}
// Binary-lifting table: fa[u][i] is the 2^i-th ancestor of u.
// Entries above the root stay 0.
int fa[maxn][21];
int vis[maxn];  // visited flag used by build()'s BFS
int d[maxn];    // depth, with d[1] == 1; d[0] stays 0
// Work queue, used by build() and reused by main().
queue<int> que;
void build(){
que.push(1);
vis[1] = 1;d[1] = 1;
while(que.size()){
int x = que.front();
que.pop();
for( int cure = he[x];cure;cure= ne[cure] ){
int y = ver[cure];
if(vis[y]) continue;
vis[y] = 1;
fa[y][0] = x;
d[y] = d[x]+1;
for( int i = 1;i <= 20;i++ ){
fa[y][i] = fa[fa[y][i-1]][i-1];
}
que.push(y);
}
}
}
// Lowest common ancestor of x and y, using binary lifting over fa[][] and d[].
int lca(int x, int y) {
    if (d[x] < d[y]) swap(x, y);
    // Lift the deeper vertex x up to y's depth. Jumps past the root land on
    // vertex 0 with d[0] == 0, which is shallower than any real vertex, so
    // those jumps are never taken.
    for (int k = 20; k >= 0; --k) {
        if (d[y] <= d[fa[x][k]]) x = fa[x][k];
    }
    if (x == y) return x;
    // Lift both while their ancestors differ; they stop just below the LCA.
    for (int k = 20; k >= 0; --k) {
        const int px = fa[x][k];
        const int py = fa[y][k];
        if (px != py) {
            x = px;
            y = py;
        }
    }
    return fa[x][0];
}
// Every segment tree indexes positions 1..mx.
const int mx = 100000;
int root[maxn];  // root[u]: pool index of vertex u's segment tree (0 = empty)
int du[maxn];    // vertex degree; decremented while main() peels leaves
int ans[maxn];   // ans[u]: value printed for vertex u
int main(){
int n,m;
scanf("%d%d",&n,&m);
for( int x,y,i = 1;i < n;i++ ){
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
du[x]++;du[y]++;
}
build();
for( int x,y,z,i = 1;i <= m;i++ ){
scanf("%d%d%d",&x,&y,&z);
int w = lca(x,y);
if(!root[x]) root[x] = ++tot;
update( z,z,1,mx,root[x],1 );
if(!root[y]) root[y] =++tot;
update(z,z,1,mx,root[y],1);
if(!root[w])root[w]=++tot;
update(z,z,1,mx,root[w],-1);
if(!root[fa[w][0]])root[fa[w][0]]=++tot;
update(z,z,1,mx,root[fa[w][0]],-1);
}
for( int i = 2;i <= n;i++ ){
if(du[i]==1 ) {
que.push(i);
ans[i] = tree[root[i]][1];
}
}
if( !du[1] ) {
que.push(1);
ans[1] = tree[root[1]][1];
}
while(que.size()){
int x = que.front();
que.pop();
root[fa[x][0]] = merge( 1,mx,root[fa[x][0]],root[x] );
du[fa[x][0]]--;
if( ((du[fa[x][0]] == 1 && fa[x][0] != 1) || (fa[x][0] == 1 && du[1] == 0) ) ) {
que.push(fa[x][0]);
ans[ fa[x][0] ] = tree[root[fa[x][0]]][1];
}
}
for( int i = 1;i <= n;i++ ){
printf("%d\n",ans[i]);
}
return 0;
}