题目
题目描述
题目译自 JOISC 2020 Day4 T1「首都 / Capital City」
在 JOI 的国度有 个小镇,从 到 编号,并由 条双向道路连接。第 条道路连接了 和 这两个编号的小镇。
这个国家的国王现将整个国家分为 个城市,从 到 编号,每个城市都有附属的小镇,其中编号为 的小镇属于编号为 的城市。每个城市至少有一个附属小镇。
国王还要选定一个首都。首都的条件是该城市的任意小镇都只能通过属于该城市的小镇到达。
但是现在可能不存在这样的选址,所以国王还需要将一些城市进行合并。对于合并城市 和 ,指的是将所有属于 的小镇划归给 城。
你需要求出最少的合并次数。
输入格式
输入第一行两个整数 ,为小镇和城市的数量。
接下来的 行,每行两个整数 ,描述了 条道路。
再接下来的 行,每行一个整数 ,表示编号为 的小镇属于编号为 的城市。
输出格式
输出一行一个整数为最少的合并次数。
样例
样例输入 1
6 3
2 1
3 5
6 2
3 4
2 3
1
3
1
2
3
2
样例输出 1
1
样例说明 1
你可以对城市 和 进行合并,然后选定 为首都,因为最初任何城市都无法作为首都。总花费为 。
这个样例满足子任务 。
样例输入 2
8 4
4 1
1 3
3 6
6 7
7 2
2 5
5 8
2
4
3
1
1
2
3
4
样例输出 2
1
样例说明 2
这个样例满足子任务 。
样例输入 3
12 4
7 9
1 3
4 6
2 4
10 12
1 2
2 10
11 1
2 8
5 3
6 7
3
1
1
2
4
3
3
2
2
3
4
4
样例输出 3
2
样例说明 3
这个样例满足子任务 。
数据范围与提示
对于 的数据,,保证:
;
;
从任何一个小镇出发都能到达其他任何小镇;
;
对于每一个 ,存在一个 ,使得 。
详细子任务及附加限制如下表所示:
子任务编号 附加限制 分值
每个小镇最多可通过公路与两个小镇直接相连
无附加限制
思路
如果颜色i的虚树上有颜色j的点,i->j连一条边。
建出图后缩强联通分量,出度为0且点数最小的分量就是答案。
用倍增或者树链剖分优化这个建图即可做到O(n log n)
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vi;
typedef pair<int,int> pii;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define sz(a) int(a.size())
const int N = 4e5 + 10;
int gi() {
int x = 0,o = 1;
char ch = getchar();
while ((ch < '0' || ch > '9') && ch != '-') ch = getchar();
if (ch == '-')
o = -1,ch = getchar();
while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0',ch = getchar();
return x * o;
}
int n,k,siz[N],mn,rt,all,col[N],val[N],cnt[N],now_val[N];
ll sum,ans = 1e9;
vi E[N],vec[N],G[N],P;
bool vis[N],vv[N];
void add(int u,int v) {
E[u].pb(v);
E[v].pb(u);
}
void getroot(int u,int fa) {
siz[u] = 1;
int mx = 0;
for (auto v : E[u])
if (v != fa && !vis[v])
{
getroot(v,u);
siz[u] += siz[v];
mx = max(mx,siz[v]);
}
mx = max(mx,all - siz[u]);
if (mx < mn)
mn = mx,rt = u;
}
void dfs(int u,int fa) {
++cnt[col[u]];
P.pb(u);
if (fa)
G[u].pb(fa);
for (auto v : E[u])
if (v != fa && !vis[v])
dfs(v,u);
}
void dfs2(int u) {
if (vv[u])
return;
vv[u] = 1;
sum += now_val[u];
for (auto v : G[u]) dfs2(v);
}
void work(int u) {
getroot(u,0);
all = siz[u];
mn = 1e9;
getroot(u,0);
vis[u = rt] = 1;
P.clear();
dfs(u,0);
map<int,bool> addd;
for (auto x : P) {
if (cnt[col[x]] < sz(vec[col[x]])) {
now_val[x] = 1e9;
continue;
}
now_val[x] = val[x];
if (!addd.count(col[x])) {
addd[col[x]] = 1;
for (auto t : vec[col[x]]) G[n + col[x]].pb(t);
}
G[x].pb(n + col[x]);
}
sum = 0;
dfs2(u);
ans = min(ans,sum);
for (auto x : P) {
G[x].clear();
G[n + col[x]].clear();
vv[x] = vv[n + col[x]] = 0;
cnt[col[x]] = 0;
}
for (auto v : E[u])
if (!vis[v])
work(v);
}
int main() {
cin >> n >> k;
for (int i = 1,u,v; i < n; i++) u = gi(),v = gi(),E[u].pb(v),E[v].pb(u);
for (int i = 1; i <= n; i++) vec[col[i] = gi()].pb(i);
for (int i = 1; i <= k; i++) val[vec[i].back()] = 1;
work(1);
cout << ans - 1;
}