NOIP2016 天天爱跑步 线段树合并_桶_思维题
竟然独自想出来了,好开心
Code:
#include<bits/stdc++.h>
#define setIO(s) freopen(s".in","r",stdin)
#define maxn 400000
#define M 1000000
#define plus pl
#define minus mi
using namespace std;
vector<int>plus[M],minus[M];
int n,m;
int hd[maxn<<1],to[maxn<<1],nx[maxn<<1],edges;
int dep[maxn],F[22][maxn],tim[maxn],st[maxn],ed[maxn];
int answer[maxn];
void add(int u,int v){
nx[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
void dfs1(int u,int fa){
dep[u]=dep[fa]+1;
F[0][u]=fa;
for(int i=1;i<22;++i) F[i][u]=F[i-1][F[i-1][u]];
for(int v=hd[u];v;v=nx[v])
if(to[v]!=fa) dfs1(to[v],u);
}
int lca(int a,int b){
if(dep[a]>dep[b]) swap(a,b);
if(dep[a]!=dep[b]) for(int i=21;i>=0;--i) if(dep[F[i][b]]>=dep[a]) b=F[i][b];
if(a==b) return a;
for(int i=21;i>=0;--i) if(F[i][a]!=F[i][b]) a=F[i][a],b=F[i][b];
return F[0][a];
}
int tot,root[maxn];
struct Node{ int l,r,w; }node[maxn*11];
#define mid ((l+r)>>1)
void ins(int &o,int k,int delta,int l,int r){
if(!o) o=++tot;
node[o].w+=delta;
if(l==r) return;
else{
if(k<=mid) ins(node[o].l,k,delta,l,mid);
else ins(node[o].r,k,delta,mid+1,r);
}
}
int merge(int x,int y){
if(!x||!y) return x+y;
node[x].w+=node[y].w;
node[x].l=merge(node[x].l,node[y].l);
node[x].r=merge(node[x].r,node[y].r);
return x;
}
int query(int x,int l,int r,int k){
if(!x) return 0;
if(l==r) return node[x].w;
if(k<=mid) return query(node[x].l,l,mid,k);
else return query(node[x].r,mid+1,r,k);
}
void dfs2(int u){
for(int sz=plus[u].size(),i=0;i<sz;++i) ins(root[u],plus[u][i],1,1,M);
for(int sz=minus[u].size(),i=0;i<sz;++i) ins(root[u],minus[u][i],-1,1,M);
plus[u].clear(),minus[u].clear();
for(int v=hd[u];v;v=nx[v])
if(to[v]!=F[0][u]) dfs2(to[v]),root[u]=merge(root[u],root[to[v]]);
answer[u]+=query(root[u],1,M,tim[u]+dep[u]);
}
void dfs3(int u){
for(int sz=plus[u].size(),i=0;i<sz;++i) ins(root[u],plus[u][i],1,1,M);
for(int sz=minus[u].size(),i=0;i<sz;++i) ins(root[u],minus[u][i],-1,1,M);
for(int v=hd[u];v;v=nx[v])
if(to[v]!=F[0][u]) dfs3(to[v]),root[u]=merge(root[u],root[to[v]]);
answer[u]+=query(root[u],1,M,dep[u]-tim[u]+maxn);
}
void up(){
for(int i=1,c;i<=m;++i){
c=lca(st[i],ed[i]);
if(dep[c]<dep[st[i]]) {
plus[st[i]].push_back(dep[st[i]]);
minus[F[0][c]].push_back(dep[st[i]]);
}
}
dfs2(1);
memset(node,0,sizeof(node));
memset(root,0,sizeof(root));
tot=0;
}
void down(){
for(int i=1,c,pre;i<=m;++i) {
c=lca(st[i],ed[i]);
if(dep[c]<dep[ed[i]]) {
pre=dep[st[i]]-dep[c];
if(st[i]==c){
plus[ed[i]].push_back(dep[c]-pre+maxn);
minus[F[0][c]].push_back(dep[c]-pre+maxn);
}else {
plus[ed[i]].push_back(dep[c]-pre+maxn);
minus[c].push_back(dep[c]-pre+maxn);
}
}
}
dfs3(1);
}
int main(){
//setIO("input");
scanf("%d%d",&n,&m);
for(int i=1,a,b;i<n;++i) scanf("%d%d",&a,&b),add(a,b),add(b,a);
dfs1(1,0);
for(int i=1;i<=n;++i) scanf("%d",&tim[i]);
for(int i=1;i<=m;++i) {
scanf("%d%d",&st[i],&ed[i]);
if(st[i]==ed[i] && tim[st[i]]==0) ++answer[st[i]];
}
up(),down();
for(int i=1;i<=n;++i) printf("%d ",answer[i]);
return 0;
}