树上差分+树剖
tmp[i]:i到它父亲的边被覆盖的的次数
主边树剖,对于每一条虚边起点tmp[s]++,终点tmp[t]++,tmp[lca(s,t)]-=2,最后从tmp[i]!=0的点开始向上更新到根
1.如果当前节点i和fa[i]之间的边没有被虚边覆盖,砍掉这条边+任意一条虚边,方案数为m
2.如果这条边被虚边覆盖,当覆盖数==1时,只有砍掉这条边和这条虚边才满足条件,方案数为1
3.覆盖虚边数>1,对答案无贡献
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 100005
using namespace std;
int n,m;
int tmp[N],num[N];
struct edge
{
int be,to,ne;
}b[N*2],s[N*2];
int head[N],k=0;
void add(int u,int v)
{
k++;
b[k].to=v;b[k].ne=head[u];head[u]=k;
}
int fa[N],size[N],son[N],d[N],tp[N];
void dfs1(int x,int father,int deep)
{
fa[x]=father;d[x]=deep;
size[x]=1;son[x]=0;
for(int i=head[x];i!=-1;i=b[i].ne)
if(b[i].to!=father)
{
dfs1(b[i].to,x,deep+1);
size[x]+=size[b[i].to];
if(size[son[x]]<size[b[i].to]) son[x]=b[i].to;
}
}
void dfs2(int x,int top)
{
tp[x]=top;
if(son[x]) dfs2(son[x],top);
for(int i=head[x];i!=-1;i=b[i].ne)
if(b[i].to!=fa[x]&&b[i].to!=son[x])
dfs2(b[i].to,b[i].to);
}
int lca(int x,int y)
{
int ans=0;
int fx=tp[x],fy=tp[y];
//<<x<<" "<<y<<endl;
while(fx!=fy)
{
if(d[fx]<d[fy]) {swap(fx,fy);swap(x,y);}
x=fa[fx];
fx=tp[x];
}
if(d[x]<d[y]) return x;
else return y;
}
void bl(int x)
{
int ad=tmp[x];
while(x!=1&&x)
{
num[x]+=ad;
x=fa[x];
}
}
void pre()
{
for(int i=1;i<=m;i++)
{
tmp[s[i].be]++;
tmp[s[i].to]++;
tmp[lca(s[i].be,s[i].to)]-=2;
}
for(int i=2;i<=n;i++)
if(tmp[i]!=0)
bl(i);
}
int main()
{
freopen("yam.in","r",stdin);
freopen("yam.out","w",stdout);
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
int q,w;
for(int i=1;i<n;i++)
{
scanf("%d%d",&q,&w);
add(q,w);add(w,q);
}
dfs1(1,0,0);
dfs2(1,1);
for(int i=1;i<=m;i++)
scanf("%d%d",&s[i].be,&s[i].to);
pre();
int ans=0;
for(int i=2;i<=n;i++)
{
if(num[i]==0) ans+=m;
if(num[i]==1) ans+=1;
}
printf("%d",ans);
//while(1);
return 0;
}