前言
虽然这道题并不难,但我还是挺激动的,因为我用上了牛客比赛上学的 一直吃灰 的换根
d
p
dp
dp,一激之下,猛敲出了这篇题解 (虽然正解比这简单多了…)
题解
我们可以这样翻译题面。
给出一个无向图,有的点标记为 'A', 有的点标记为 'B',要求断开任何一条割边后,分
出的两个联通块中,有一个块不包含 'A' 或 不包含 'B',求这样的割边的个数
我们先缩点,将每一个双联通分量收缩成一个点,由于图联通,我们可以得到一棵树,就像这样。
所以我们只需要枚举缩点后的每一条边(
T
a
Ta
Ta一定是割边),再统计两棵子树所含
′
A
′
'A'
′A′和
′
B
′
'B'
′B′ 的数量即可,暴力跑的话最坏时间复杂度
O
(
n
2
)
O (n ^ 2)
O(n2) ,一定超时,这时我们就要用上换根
d
p
dp
dp了。
我们只考虑计数
′
A
′
'A'
′A′, 因为
′
B
′
'B'
′B′同理
首先任选一个根节点。
记
d
p
d
o
n
w
[
i
]
dpdonw[i]
dpdonw[i]表示以
i
i
i为根节点的子树的
′
A
′
'A'
′A′ 的数量
记
d
p
u
p
[
i
]
dpup[i]
dpup[i]表示
i
i
i节点的所有非子孙节点的
′
A
′
'A'
′A′ 的数量。
状态转移方程:
第一种推导方法:根据图我们发现
d
p
u
p
[
u
]
=
d
p
d
o
w
n
[
f
a
]
+
d
p
u
p
[
f
a
]
−
d
p
d
o
w
n
[
u
]
dpup[u] = dpdown[fa] + dpup[fa] - dpdown[u]
dpup[u]=dpdown[fa]+dpup[fa]−dpdown[u] (规律大法好)
第二种推导方法:
对于此状态转移,我们还可以根据定义来走
d
p
d
o
w
n
[
f
a
]
+
d
p
u
p
[
f
a
]
dpdown[fa] + dpup[fa]
dpdown[fa]+dpup[fa]表示整棵树的
′
A
′
'A'
′A′ 节点的数量
根据容斥原理:整棵树的
′
A
′
'A'
′A′ 数量减去以
u
u
u 为根节点的子树的
′
A
′
'A'
′A′ 数量即为
d
p
u
p
[
u
]
dpup[u]
dpup[u]
统计分开的两颗子树也同理可以推出:
一棵子树是:
d
p
d
o
w
n
[
u
]
dpdown[u]
dpdown[u]
另一棵子树是:
d
p
u
p
[
f
a
]
+
d
p
d
o
w
n
[
f
a
]
−
d
p
d
o
w
n
[
u
]
dpup[fa] + dpdown[fa] - dpdown[u]
dpup[fa]+dpdown[fa]−dpdown[u]
#include <map>
#include <set>
#include <cmath>
#include <stack>
#include <queue>
#include <vector>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define LL long long
#define ULL unsigned long long
template <typename T> int read (T &x) {x = 0; T f = 1;char tem = getchar ();while (tem < '0' || tem > '9') {if (tem == '-') f = -1;tem = getchar ();}while (tem >= '0' && tem <= '9') {x = (x << 1) + (x << 3) + tem - '0';tem = getchar ();}x *= f; return 1;}
template <typename T> void write (T x) {if (x < 0) {x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}
template <typename T> T Max (T x, T y) { return x > y ? x : y; }
template <typename T> T Min (T x, T y) { return x < y ? x : y; }
template <typename T> T Abs (T x) { return x > 0 ? x : -x; }
const int Maxn = 1e5;
const int Maxm = 1e6;
int n, m, k, l;
bool vis_a[Maxn + 5], vis_b[Maxn + 5];
int len = 1, Head[Maxn + 5];
struct edge {
int to, Next;
}e[Maxm * 2 + 5];
void add (int x, int y) {
e[++len].to = y;
e[len].Next = Head[x];
Head[x] = len;
}
int dfn[Maxn + 5], low[Maxn + 5], timestamp;
int st[Maxn + 5], Top;
int id[Maxn + 5], dcc_cnt;
int have_a[Maxn + 5], have_b[Maxn + 5];
void Tarjan (int u, int fa) {//缩点
dfn[u] = low[u] = ++timestamp;
st[++Top] = u;
for (int i = Head[u]; i; i = e[i].Next) {
int v = e[i].to;
if (dfn[v] == 0) {
Tarjan (v, i);
low[u] = Min (low[u], low[v]);
}
else if ((i ^ 1) != fa)
low[u] = Min (low[u], dfn[v]);
}
if (dfn[u] == low[u]) {
dcc_cnt++;
int v;
do {
v = st[Top--];
id[v] = dcc_cnt;
have_a[dcc_cnt] += vis_a[v];
have_b[dcc_cnt] += vis_b[v];
}while (u != v);
}
}
int depth[Maxn + 5];
int dp_down[3][Maxn + 5], dp_up[3][Maxn + 5];
//1,2分别求的是 'A' 的数量,和 'B' 的数量
vector <int> g[Maxn + 5];
void Dp_Down (int u, int fa) {//求出 dp_down
dp_down[1][u] += have_a[u];
dp_down[2][u] += have_b[u];
for (int i = 0; i < g[u].size (); i++) {
int v = g[u][i];
if (v == fa) continue;
depth[v] = depth[u] + 1;
Dp_Down (v, u);
dp_down[1][u] += dp_down[1][v];
dp_down[2][u] += dp_down[2][v];
}
}
void Dp_Up (int u, int fa) {//求出dp_up
if (fa != -1) {
dp_up[1][u] = dp_up[1][fa] + dp_down[1][fa] - dp_down[1][u];
dp_up[2][u] = dp_up[2][fa] + dp_down[2][fa] - dp_down[2][u];
}
for (int i = 0; i < g[u].size (); i++) {
int v = g[u][i];
if (v == fa) continue;
Dp_Up (v, u);
}
}
struct Dete {
int x, y;
}ed[Maxn + 5];
map <pair <int, int>, bool> mp;
signed main () {
read (n); read (m); read (k); read (l);
for (int i = 1; i <= k; i++) {
int x; read (x);
vis_a[x] = 1;
}
for (int i = 1; i <= l; i++) {
int x; read (x);
vis_b[x] = 1;
}
for (int i = 1; i <= m; i++) {
int x, y; read (x); read (y);
ed[i].x = x; ed[i].y = y;
add (x, y); add (y, x);
}
for (int i = 1; i <= n; i++) {
if (dfn[i] == 0) {
Tarjan (i, -1);
}
}
for (int u = 1; u <= n; u++) {
for (int i = Head[u]; i; i = e[i].Next) {
int v = e[i].to;
if (id[u] == id[v]) continue;
pair <int, int> tem = make_pair (id[u], id[v]);
if (mp.find (tem) != mp.end ()) continue;
mp[tem] = 1;
g[id[u]].push_back (id[v]);
}
}
depth[1] = 1;
Dp_Down (1, -1);
Dp_Up (1, -1);
int ans = 0;
for (int i = 1; i <= m; i++) {
int x = id[ed[i].x], y = id[ed[i].y];
if (x == y) continue;
if (depth[y] >= depth[x]) swap (x, y);
if (dp_down[1][x] == 0 || dp_down[2][x] == 0) ans++;//子树为零
else if (dp_up[1][y] + dp_down[1][y] - dp_down[1][x] == 0) ans++;//去掉子树后'A'为零
else if (dp_up[2][y] + dp_down[2][y] - dp_down[2][x] == 0) ans++;//去掉子树后'B'为零
}
printf ("%d", ans);
return 0;
}