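// Problem: given a tree of N nodes, each with a weight a[i], find a path whose
// product of node weights (mod 10^6 + 3) equals k, and print its endpoints as the
// lexicographically smallest pair (u, v) with u < v, or "No solution" if none exists.
// Approach: centroid decomposition with hashing — for each centroid, store the
// path product from it to every node of its subtrees, indexed by that product.
// Since a * b = k can be rewritten as a = k / b, and all arithmetic is done mod MOD,
// precompute every modular inverse up front so each lookup is a single multiply.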
#pragma comment(linker,"/STACK:102400000,102400000")
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
#include <stack>
#include <cmath>
#include <bitset>
#include <map>
using namespace std;
// Uncomment to redirect stdin/stdout to the files "input"/"output" (see main).
//#define ACM_LOCAL

typedef long long ll;
typedef long double ld;
typedef pair<int, int> PII;

// Capacity for node-indexed arrays; the edge pool holds twice as many entries.
const int N = 200005;
// Sentinel meaning "no value yet" (also used as mx[0] so any real node wins).
const int INF = 0x3f3f3f3f;
// Modulus for all path products (10^6 + 3).
const int MOD = 1000003;

// ---- Input ----
int n;          // number of nodes in the current test case
int m;          // not used anywhere in this file
int k;          // target product
int a[N];       // node weights, 1-indexed

// ---- Adjacency list ----
int cnt;        // number of edges stored so far in e[]
int h[N];       // index of the first edge leaving each node, -1 when none

// ---- Centroid search (get_rt / work) ----
int rt;         // best centroid candidate found so far
int sz[N];      // subtree sizes from the most recent get_rt DFS
int mx[N];      // largest piece left if the node were removed
int vis[N];     // 1 once a node has been used as a centroid
int sum;        // size of the component currently being searched
int ans;        // reset in solve(), otherwise unused

// ---- Path collection (get_d / cal) ----
int id[N];      // id[j]: node whose path product is d[j]
ll d[N];        // d[0]: count; d[1..d[0]]: collected path products
ll dep[N];      // product of weights from the subtree root down to the node

// pd[p]: smallest node whose path product is p among already-processed
// subtrees of the current centroid (0 = none). Sized for every residue.
int pd[1000010];
// inv[i]: modular inverse of i modulo MOD (filled by get_inv).
ll inv[1000010];

// Best answer so far, kept with node1 <= node2; both INF when none found.
int node1, node2;

// One directed half of an undirected tree edge in the head-insertion list.
struct edge {
    int to;    // destination node
    int next;  // index of the next edge from the same source, -1 ends the list
} e[N << 1];
void add(int u, int v) {
e[cnt].to = v;
e[cnt].next = h[u];
h[u] = cnt++;
}
// Centroid search over the component containing x. The DFS does not cross
// nodes with vis set and does not go back to fa.
// Fills sz[] with subtree sizes for this DFS and mx[] with the largest piece
// left if the node were removed. That includes the "upward" part sum - sz[x],
// where sum is the component size. The node with the smallest mx seen so far
// is kept in the global rt.
void get_rt(int x, int fa) {
    sz[x] = 1;
    mx[x] = 0;
    for (int i = h[x]; i != -1; i = e[i].next) {
        const int child = e[i].to;
        if (child == fa || vis[child]) continue;
        get_rt(child, x);
        sz[x] += sz[child];
        if (sz[child] > mx[x]) mx[x] = sz[child];
    }
    const int above = sum - sz[x];  // part of the component on the parent side
    if (above > mx[x]) mx[x] = above;
    if (mx[x] < mx[rt]) rt = x;
}
void get_d(int x, int fa) {
d[++d[0]] = dep[x];
id[d[0]] = x;
for (int i = h[x]; ~i; i = e[i].next) {
int y = e[i].to;
if (vis[y] || y == fa) continue;
dep[y] = dep[x] * a[y] % MOD;
get_d(y, x);
}
}
void cal(int x, int fa) {
queue<int> que;
dep[x] = a[x];
for (int i = h[x]; ~i; i = e[i].next) {
int y = e[i].to;
if (vis[y] || y == fa) continue;
dep[y] = a[y], d[0] = 0;
get_d(y, -1);
for (int j = 1; j <= d[0]; j++) {
int temp = (1ll*k * inv[d[j] * dep[x] % MOD]) % MOD;
if (pd[temp]) {
if (min(id[j], pd[temp]) < min(node1, node2)) {
node1 = id[j], node2 = pd[temp];
if (node1 > node2) swap(node1, node2);
}
else if (min(id[j], pd[temp]) == min(node1, node2) && max(id[j], pd[temp]) <= max(node1, node2)) {
node1 = id[j], node2 = pd[temp];
if (node1 > node2) swap(node1, node2);
}
}
}
for (int j = 1; j <= d[0]; j++) {
if (!pd[d[j]])
pd[d[j]] = id[j];
else
pd[d[j]] = min(pd[d[j]], id[j]);
que.push(d[j]);
}
}
while (que.size()) {
pd[que.front()] = 0;
que.pop();
}
}
// Process the component whose centroid is x:
//   - mark x as removed;
//   - register x in pd[1] as the "empty path" partner (product 1 before a[x]);
//   - find the paths through x via cal;
//   - recurse into the centroid of each remaining neighbouring component.
// NOTE(review): sz[y] comes from the get_rt DFS that selected x. If y was x's
// parent in that DFS, it overstates y's component size. get_rt still picks a
// node of y's component, so results stay correct; only centroid quality may
// suffer.
void work(int x) {
    vis[x] = 1;
    pd[1] = x;
    cal(x, -1);
    for (int i = h[x]; i != -1; i = e[i].next) {
        const int y = e[i].to;
        if (vis[y]) continue;
        sum = sz[y];
        rt = 0;
        get_rt(y, -1);
        work(rt);
    }
}
void get_inv() {
inv[1] = 1;
for(int i = 2; i <= MOD; i++)
inv[i] = ((MOD - MOD / i) * inv[MOD % i] + MOD) % MOD;
}
// Read test cases until scanf reports end of input. Each case is:
//   - n and k;
//   - n node weights;
//   - n - 1 undirected edges.
// Prints the best endpoint pair, or "No solution" when none was found.
void solve() {
    get_inv();
    while (scanf("%d %d", &n, &k) != EOF) {
        // Reset per-case state.
        memset(h, -1, sizeof h);
        memset(vis, 0, sizeof vis);
        node1 = node2 = INF;
        cnt = 0;

        for (int v = 1; v <= n; ++v) scanf("%d", &a[v]);
        for (int remaining = n - 1; remaining > 0; --remaining) {
            int u, w;
            scanf("%d %d", &u, &w);
            add(u, w);
            add(w, u);
        }

        // First centroid of the whole tree, then decompose recursively.
        rt = 0;
        sum = n;
        mx[0] = INF;
        ans = 0;
        get_rt(1, -1);
        work(rt);

        if (node1 == INF && node2 == INF) printf("No solution\n");
        else printf("%d %d\n", node1, node2);
    }
}
// Entry point. All actual I/O in this program goes through C stdio
// (scanf/printf); the stream settings below do not affect it.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);
    std::cout.tie(NULL);
#ifdef ACM_LOCAL
    // Local testing: read from "input" and write to "output".
    freopen("input", "r", stdin);
    freopen("output", "w", stdout);
#endif
    solve();
}