题意:
给出一个无向图,分别给出n-1条树边(主要边)和m条非树边(附加边),这个无向图可以看做一棵树外加m条附加边,你可以切
断一条主要边和一条附加边,求切割后,能够使这个无向图不再连通的切割方案数(即使只切断一条主要边就可以使图不连通,你
也需要再切断一条附加边)
分析:
我们先考虑只有一条附加边(x,y)时,这时这张图就是一棵基环树
我们发现如果x,y之间有一条附加边,则这条边和x到y的路径组成了一个环,如果说我们要切割x到y的路径上的一条主要边,我们
必须要再切断这条附加边,才能使图不再连通
那么如果x,y之间有两条或以上附加边,若我们已切割了x到y路径上的一条主要边,那么是无法通过仅再切割一条附加边来使图不再连通
因而我们每次读入一条附加边,就给x到y的路径上的所有主要边记录上“被覆盖一次”,这样再去遍历所有主要边
对于我们想要切割的一条主要边,有以下3种情况
若这条边被覆盖0次,则可以任意再切断一条附加边
若这条边被覆盖1次,那么只能再切断唯一的一条附加边
若这条边被覆盖2次及以上,没有可行的方案
现在的问题是如何快速求出每条边被覆盖了多少次,对于这类问题,可以类比序列差分,有树上差分算法
设差分数组dif初值为0,若x,y有一条附加边,则dif[x]++,dif[y]++,dif[lca(x,y)]-=2
设f(x)为以x为根的子树中所有节点dif之和,则f(x)就是x到其父节点的边被覆盖的次数
参考博客:https://blog.csdn.net/Fantasy_World/article/details/80544982
/*
*
* LCA 在线算法
*/
#include<cstdio>
#include<queue>
#include<cstring>
using namespace std;
const int maxn = 1e5+10;
const int DEG = 20;
struct Edge{
int to,next;
}edge[maxn*2];
int head[maxn],tot;
void addedge(int u,int v){
edge[tot].to = v;
edge[tot].next = head[u];
head[u] = tot++;
}
void init(){
tot = 0;
memset(head,-1,sizeof(head));
}
int fa[maxn][DEG+5];//fa[i][j]表示结点i的第2^j个祖先
int deg[maxn];//深度数组
void bfs(int root){
queue<int>que;
deg[root] = 0;
fa[root][0] = root;
que.push(root);
while(!que.empty()){
int tmp = que.front();
que.pop();
for(int i = 1;i < DEG;i++)
fa[tmp][i] = fa[fa[tmp][i-1]][i-1];
for(int i = head[tmp]; i != -1;i = edge[i].next){
int v = edge[i].to;
if(v == fa[tmp][0])continue;
deg[v] = deg[tmp] + 1;
fa[v][0] = tmp;
que.push(v);
}
}
}
int LCA(int u,int v){
if(deg[u] > deg[v])swap(u,v);
int hu = deg[u], hv = deg[v];
int tu = u, tv = v;
for(int det = hv-hu, i = 0; det ;det>>=1, i++)
if(det&1)
tv = fa[tv][i];
if(tu == tv)return tu;
for(int i = DEG-1; i >= 0; i--){
if(fa[tu][i] == fa[tv][i])
continue;
tu = fa[tu][i];
tv = fa[tv][i];
}
return fa[tu][0];
}
int dif[maxn];
int f[maxn];
bool vis[maxn];
int dfs(int u){
vis[u]=1;
f[u]=dif[u];
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].to;
if(vis[v]) continue;
f[u]+=dfs(v);
}
return f[u];
}
int main(){
int n,m;
while(~scanf("%d%d",&n,&m)){
init();
int u,v;
for(int i=1;i<n;i++){
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
memset(dif,0,sizeof dif);
memset(f,0,sizeof f);
memset(vis,0,sizeof vis);
bfs(1);
for(int i=0;i<m;i++){
scanf("%d%d",&u,&v);
dif[u]++;dif[v]++;dif[LCA(u,v)]-=2;
}
// for(int i=1;i<=n;i++){
// printf("%d ",dif[i]);
// }
// printf("\n");
dfs(1);
// printf("asdasd\n");
int cnt1=0,cnt0=0;
for(int i=2;i<=n;i++){
if(f[i]==1) cnt1++;
if(f[i]==0) cnt0++;
}
printf("%d\n",cnt1+cnt0*m);
}
return 0;
}
/*
7 1
1 2
2 3
3 4
3 5
4 6
5 7
4 5
*/