具体解法就是bc首页的解法了,排个颜色序和dfs序,然后按照解法一路解下去就行了。注意一定要pai排dfs序,不然会更新错误的公共祖先节点。
代码:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
#include <stack>
#include <map>
#include <string>
#include <algorithm>
#define ll long long int
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define delf int m=(l+r)>>1
using namespace std;
int f[50005]; //当前集合的父节点为当前访问的子树公共祖先
int mark[50005];
vector <int> m[50005];
vector <int> q[50005];
int ans[50005];
int id[50005]; //dfs序
int n,m1,cnt;
struct color
{
int v;
int c;
} c[500005];
bool cmp(color a,color b)
{
if (a.c!=b.c)
return a.c<b.c;
return id[a.v]<id[b.v];
}
void init()
{
for (int i=1;i<=n;i++)
{
m[i].clear();
q[i].clear();
}
for (int i=1;i<=n;i++)
ans[i]=0;
for (int i=1;i<=n;i++)
f[i]=i;
for (int i=1;i<=n;i++)
mark[i]=0;
}
void dfs(int u,int t)
{
id[u]=cnt++;
for (int i=0;i<m[u].size();i++)
{
if (m[u][i]==t)
continue ;
dfs(m[u][i],u);
}
}
int find(int x)
{
if (f[x]==x)
return x;
return f[x]=find(f[x]);
}
int LCA(int u,int t)
{
for (int i=0;i<m[u].size();i++)
{
if (m[u][i]==t)
continue ;
int s=LCA(m[u][i],u);
ans[u]+=s;
f[m[u][i]]=u;
}
for (int i=0;i<q[u].size();i++)
{
if (mark[q[u][i]]==1)
ans[find(q[u][i])]--;
}
mark[u]=1;
return ans[u];
}
int scan() { //输入外挂
int res = 0, flag = 0;
char ch;
if((ch = getchar()) == '-') flag = 1;
else if(ch >= '0' && ch <= '9') res = ch - '0';
while((ch = getchar()) >= '0' && ch <= '9')
res = res * 10 + (ch - '0');
return flag ? -res : res;
}
void out(int a) { //输出外挂
if(a < 0) { putchar('-'); a = -a; }
if(a >= 10) out(a / 10);
putchar(a % 10 + '0');
}
int main()
{
while (~scanf("%d%d",&n,&m1))
{
init();
for (int i=1;i<n;i++)
{
int a,b;
a=scan();
b=scan();
//scanf("%d%d",&a,&b);
m[a].push_back(b);
m[b].push_back(a);
}
for (int i=1;i<=m1;i++)
{
int a,b;
a=scan();
b=scan();
//scanf("%d%d",&a,&b);
c[i].v=a;
c[i].c=b;
}
cnt=1;
dfs(1,0); //获得每个节点dfs序位置
sort(c+1,c+m1+1,cmp); //按照颜色第一关键字,dfs序第二关键字排序,这样能保证颜色相同的节
//点在访问时是顺序的,否则会因为减1的节点位置不对而出错。
ans[c[1].v]++;
for (int i=2;i<=m1;i++)
{
if (c[i].c==c[i-1].c&&c[i].v==c[i-1].v)
continue ;
ans[c[i].v]++;
if (c[i].c==c[i-1].c)
{
q[c[i].v].push_back(c[i-1].v);
q[c[i-1].v].push_back(c[i].v);
}
}
int t=LCA(1,0);
for (int i=1;i<n;i++)
{
out(ans[i]);
putchar(' ');
}
//printf("%d ",ans[i]);
printf("%d\n",ans[n]);
}
}