Description
真不知道出题人把题意搞那么长有何目的
给定一棵
N
N
个节点,以 为根的树。第
i
i
个节点的初始权值为 。
然后需要修改一些点的权值(可以改成小数),使得对于每个节点
i
i
都满足两点限制条件:
(1) 的所有子节点的权值相等。
(2)
i
i
的权值等于 所有子节点的权值之和。
求最少要修改多少个节点的权值。
Input
第一行是一个正整数
N
N
。接下来 行,每行一个正整数,其中的第
i
i
行表示 。再接下来是
N−1
N
−
1
行,每行两个正整数
a,b
a
,
b
表示
a
a
是 的父亲。
N<500000,A[i]<108
N
<
500000
,
A
[
i
]
<
10
8
Output
输出文件仅包含一行,表示最少要修改权值的节点个数。
Sample Input
5
5
4
3
2
1
1 2
1 3
2 4
2 5
Sample Output
3
HINT
一个最优解是将 A[1] A [ 1 ] 改成 8 8 , 改成 4 4 , 改成 2 2 。这样就能满足所有限制条件。
Solution
假设将根节点的权值修改成 ,
i
i
号节点的子节点个数为 。
那么根的子节点的权值都是
Td[1]
T
d
[
1
]
。
以此类推,得出一个重要的性质:确定了根节点的权值之后整棵树的权值都确定。
另外,也可以得出,设
prod[i]
p
r
o
d
[
i
]
为节点
i
i
到根的路径上(不包括 但包括
1
1
)的节点的 之积,
那么节点
i
i
的权值就是 。
接下来的思路是:一对点
i,j
i
,
j
,假设
i
i
的权值不做修改, 的权值也不做修改,那么满足的条件就是:
存在一个
T
T
,满足
显然 T=prod[i]×A[i] T = p r o d [ i ] × A [ i ] 。于是条件转化为:
移项后得到:
于是,就可以考虑把每个节点按照 prod×A p r o d × A 为关键字排序。
这样答案就是 n−满足prod×A全部相等的最长区间的长度 n − 满 足 p r o d × A 全 部 相 等 的 最 长 区 间 的 长 度 。
考虑到 prod×A p r o d × A 的值很大,对每个 prod×A p r o d × A 都用几个 109 10 9 级别的数取模。此外,模数取质数能够减小冲突的概率。
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Edge(u) for (int e = adj[u], v; e; e = nxt[e])
using namespace std;
inline int read() {
int res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
return bo ? ~res + 1 : res;
}
const int N = 5e5 + 5, MX = 1e9 + 7, PYZ = 998244353, LPF = 1e8 + 7;
int n, a[N], ecnt, nxt[N], adj[N], go[N], d[N], p1[N], p2[N], p3[N], seq[N], ans;
void add_edge(int u, int v) {
nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
}
void dfs(int u, int fu) {
Edge(u) d[u]++; p1[u] = 1ll * p1[fu] * d[fu] % MX;
p2[u] = 1ll * p2[fu] * d[fu] % PYZ; p3[u] = 1ll * p3[fu] * d[fu] % LPF;
Edge(u) dfs(v = go[e], u);
}
inline bool comp(const int &x, const int &y) {
return p1[x] < p1[y] || (p1[x] == p1[y] && (p2[x] < p2[y]
|| (p2[x] == p2[y] && p3[x] < p3[y])));
}
int main() {
int i, x, y, nxt; n = read(); For (i, 1, n) a[seq[i] = i] = read();
p1[0] = p2[0] = p3[0] = d[0] = 1;
For (i, 1, n - 1) x = read(), y = read(), add_edge(x, y);
dfs(1, 0); For (i, 1, n) p1[i] = 1ll * p1[i] * a[i] % MX,
p2[i] = 1ll * p2[i] * a[i] % PYZ, p3[i] = 1ll * p3[i] * a[i] % LPF;
sort(seq + 1, seq + n + 1, comp); for (i = 1; i <= n;) {
nxt = i; while (i <= n && p1[seq[nxt]] == p1[seq[i]] &&
p2[seq[nxt]] == p2[seq[i]] && p3[seq[nxt]] == p3[seq[i]]) i++;
ans = max(ans, i - nxt);
}
cout << n - ans << endl; return 0;
}