题目地址
题意:你有n个节点,n-1条边构成一棵树,有m个猴子,猴子要在节点上,并且每个猴子至少要和另一只猴子有边相连,求最多需要多少条边
思路:最好的应该是猴子都是两两匹配的,因为每2个猴子相连的话,每只猴子的代价为0.5条边,但是超过的话就是多出来的每只猴子的代价为1条边。所以就能想到可以用二分匹配求出二个猴子的集团数有多少个,然后剩下的猴子随便加一条边连上一个集团就好了(因为代价都是为1条边)。虽然图的二分匹配的复杂度不能达到,但是因为是树,就只有一个联通块,所以dfs一次O(n)就可以了。
吐槽:这题一定要用fread,要不然就是TLE,有点搞不懂出题人。
#include <iostream>
#include <cstring>
#include <string>
#include <queue>
#include <vector>
#include <map>
#include <set>
#include <stack>
#include <cmath>
#include <cstdio>
#include <algorithm>
#define N 100010
#define LL __int64
#define inf 0x3f3f3f3f
#define lson l,mid,ans<<1
#define rson mid+1,r,ans<<1|1
#define getMid (l+r)>>1
#define movel ans<<1
#define mover ans<<1|1
using namespace std;
const LL mod = 1e9 + 7;
int n, m;
namespace fastIO {
#define BUF_SIZE 1000000
//fread -> read
bool IOerror = 0;
inline char nc() {
static char buf[BUF_SIZE], *p1 = buf + BUF_SIZE, *pend = buf + BUF_SIZE;
if (p1 == pend) {
p1 = buf;
pend = buf + fread(buf, 1, BUF_SIZE, stdin);
if (pend == p1) {
IOerror = 1;
return -1;
}
}
return *p1++;
}
inline bool blank(char ch) {
return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t';
}
inline void read(int &x) {
char ch;
while (blank(ch = nc()));
if (IOerror)
return;
for (x = ch - '0'; (ch = nc()) >= '0' && ch <= '9'; x = x * 10 + ch - '0');
}
#undef BUF_SIZE
};
using namespace fastIO;
int head[N];
struct node {
int to, next;
}edge[N << 1];
int cnt;
struct Hungarian {
int mark[N];//该边与哪个点构成的边为匹配边
void init() {
cnt = 0;
memset(head, -1, sizeof(head));
memset(mark, 0, sizeof(mark));
}
void add(int a, int b) {
edge[cnt].to = b;
edge[cnt].next = head[a];
head[a] = cnt++;
edge[cnt].to = a;
edge[cnt].next = head[b];
head[b] = cnt++;
}
void dfs(int u, int pre) {
int ans = 0, uu = 0;
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].to;
if (pre == v) {
continue;
}
dfs(v, u);
ans++;//多少个子节点
uu += mark[v];//子节点有多少个已经匹配
}
if (ans - uu >= 1) cnt++, mark[u] = 1;
}
}hungarian;
int main() {
int a, b;
int T;
read(T);
//cin >> T;
while (T--) {
//cin >> n >> m;
read(n);
read(m);
hungarian.init();
for (int i = 2; i <= n; i++) {
//cin >> a;
read(a);
hungarian.add(i, a);
}
cnt = 0;
hungarian.dfs(1, 0);
if (cnt * 2 >= m) {
printf("%d\n", (m + 1) / 2);
}
else {
printf("%d\n", cnt + m - cnt * 2);
}
}
return 0;
}