题目描述
原题来自:POJ 3417
Dark 是一张无向图,图中有 NNN 个节点和两类边,一类边被称为主要边,而另一类被称为附加边。Dark 有 N–1N–1N–1 条主要边,并且 Dark 的任意两个节点之间都存在一条只由主要边构成的路径。另外,Dark 还有 MMM 条附加边。
你的任务是把 Dark 斩为不连通的两部分。一开始 Dark 的附加边都处于无敌状态,你只能选择一条主要边切断。一旦你切断了一条主要边,Dark 就会进入防御模式,主要边会变为无敌的而附加边可以被切断。但是你的能力只能再切断 Dark 的一条附加边。
现在你想要知道,一共有多少种方案可以击败 Dark。注意,就算你第一步切断主要边之后就已经把 Dark 斩为两截,你也需要切断一条附加边才算击败了 Dark。
输入格式
第一行包含两个整数 NNN 和 MMM;
之后 N–1N – 1N–1 行,每行包括两个整数 AAA 和 BBB,表示 AAA 和 BBB 之间有一条主要边;
之后 MMM 行以同样的格式给出附加边。
输出格式
输出一个整数表示答案。
样例
样例输入
4 1
1 2
2 3
1 4
3 4
样例输出
3
数据范围与提示
对于 20%20\%20% 的数据,1≤N,M≤1001\le N,M\le 1001≤N,M≤100;
对于 100%100\%100% 的数据,1≤N≤105,1≤M≤2×1051\le N\le 10^5,1\le M\le 2\times 10^51≤N≤105,1≤M≤2×105。数据保证答案不超过 231−12^{31}-1231−1。
第一步:枚举一条去掉的主要边
第二步;将这条边去掉后,看以它深度较深的点为根的子树中有几个点与不属于它的点集相连
第三步:如果有0个,答案加上m,如果有1个答案加上1
重点在于第二步
可以预处理,用树上差分
GG啦
#include<cstdio>
#include<iostream>
#define ll long long
using namespace std;
int read()
{
int ret=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9')
ret=(ret<<1)+(ret<<3)+ch-'0',
ch=getchar();
return ret;
}
int n,m,cnt;
const int N=2e5+5;
int a[N],to[N],nxt[N],he[N],f[N][21],dep[N],lg[N];
ll ans;
inline void add(int u,int v)
{
to[++cnt]=v;
nxt[cnt]=he[u];
he[u]=cnt;
}
void dfs(int fa,int u)
{
dep[u]=(!fa)?0:dep[fa]+1;
f[u][0]=fa;
for(int i=1;i<=lg[dep[u]];i++)
f[u][i]=f[f[u][i-1]][i-1];
for(int e=he[u];e;e=nxt[e])
{
int v=to[e];
if(v!=fa) dfs(u,v);
}
}
int LCA(int u,int v)
{
if(dep[u]>dep[v]) swap(u,v);
while(dep[u]<dep[v])
v=f[v][lg[dep[v]-dep[u]]];
for(int i=lg[dep[u]];i>=0;i--)
if(f[u][i]!=f[v][i])
u=f[u][i],v=f[v][i];
if(u!=v) u=f[u][0];
return u;
}
void dfs1(int fa,int u)
{
for(int e=he[u];e;e=nxt[e])
{
int v=to[e];
if(v!=fa)
dfs1(u,v),a[u]+=a[v];
}
}
void dfss(int fa,int u)
{
for(int e=he[u];e;e=nxt[e])
{
int v=to[e];
if(v!=fa)
{
if(!a[v]) ans+=m;
else if(a[v]==1)ans++;
dfss(u,v);
}
}
}
int main()
{
n=read(),m=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
add(u,v); add(v,u);
}
lg[0]=-1;
for(int i=1;i<=n;i++)
lg[i]=lg[i>>1]+1;
dfs(0,1);
for(int i=1;i<=m;i++)
{
int u=read(),v=read();
int lca=LCA(u,v);
a[u]++,a[v]++,a[lca]--,a[lca]--;
}
ans=0;
dfs1(0,1);
dfss(0,1);
printf("%lld\n",ans);
return 0;
}