题意:
给定一棵含 $n$ 个结点的树,每个结点上有一个数值 $a_i$,判断是否存在点对 $(u, v)$ 满足其路径上(含两端点)的结点数值乘积模 $10^6+3$ 为 $k$,若存在则输出字典序最小的一对(要求 $u < v$)。($n \leq 10^5$)
链接:
https://vjudge.net/problem/HDU-4812
解题思路:
点分治:处理以 $rt$ 为根的子树答案时,逐一加入 $rt$ 每棵子树的路径信息,利用乘法逆元查表。
参考代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
#define pb push_back
#define sz(a) ((int)(a).size())
typedef long long ll;
typedef std::pair<int, int> pii;
// Two different bounds are needed:
//  * node-indexed storage needs n + 1 slots (n <= 1e5 per the statement);
//  * value-indexed tables need one slot per residue modulo `mod`.
// Using one oversized bound for both would construct millions of unused
// adjacency vectors up front.
const int maxn = 1e5 + 5;          // max node count + slack
const int mod = 1e6 + 3;           // prime modulus from the statement
const int inf = 0x3f3f3f3f;
std::vector<int> G[maxn];          // adjacency lists, nodes are 1-indexed
int vis[maxn], a[maxn], siz[maxn]; // removed-centroid flag, node weight, subtree size
int inv[mod];                      // inv[x] = modular inverse of x, filled for 1 <= x < mod
int mp[mod];                       // mp[p] = smallest node whose path product is p (0 = none)
                                   // assumes 0 < a_i < mod as in the statement
pii dis[maxn];                     // (path product, node) pairs collected by dfs, 1-indexed
int rub[maxn], top;                // residues written into mp, used to reset it cheaply
int n, k, tn, rt, rmn, tot, p1, p2;
// Post-order walk over the still-active component containing u.
// It records subtree sizes (rooted at the walk's start) in siz[] and tracks
// a centroid candidate. For each node, it measures the largest piece that
// would remain if the node were removed, using tn as the component size. The
// node with the smallest such piece is stored in rt, and that size in rmn.
// Only a strict improvement replaces the current choice.
void getRt(int u, int f){
    siz[u] = 1;
    int heaviest = 0;                        // biggest child subtree below u
    for(size_t i = 0; i < G[u].size(); ++i){
        int nxt = G[u][i];
        if(nxt == f || vis[nxt]) continue;   // skip the parent and removed centroids
        getRt(nxt, u);
        siz[u] += siz[nxt];
        if(siz[nxt] > heaviest) heaviest = siz[nxt];
    }
    int upper = tn - siz[u];                 // piece left "above" u once u is removed
    if(upper > heaviest) heaviest = upper;
    if(heaviest < rmn){
        rmn = heaviest;
        rt = u;
    }
}
// Collects every node reachable from u without stepping back into f or into
// a removed centroid. Each visit appends one entry to dis[1..tot]:
// (val for that node, node id). For the starting node, val is whatever the
// caller passes. For each child w, val is the parent's val times a[w],
// reduced mod `mod`.
void dfs(int u, int f, int val){
    ++tot;
    dis[tot].first = val;
    dis[tot].second = u;
    for(int w : G[u]){
        if(w != f && !vis[w])
            dfs(w, u, (int)(1LL * val * a[w] % mod));
    }
}
// Precomputes inv[x] for every 1 <= x < mod using the linear recurrence
//   inv[x] = (mod - mod / x) * inv[mod % x]   (mod `mod`).
// The recurrence relies on `mod` being prime, so that mod % x is never 0.
void init(){
    inv[1] = 1;
    for(int x = 2; x < mod; ++x){
        long long q = mod / x;
        long long r = mod % x;
        inv[x] = (int)((mod - q) * inv[r] % mod);
    }
}
// Offers {x, y} as a candidate answer. The pair is normalised to
// (smaller, larger), and (p1, p2) keeps the lexicographically smallest
// candidate seen so far. A degenerate pair with x == y is ignored.
void check(int x, int y){
    if(x == y) return;
    const int lo = std::min(x, y);
    const int hi = std::max(x, y);
    const bool better = lo < p1 || (lo == p1 && hi < p2);
    if(better){
        p1 = lo;
        p2 = hi;
    }
}
void cal(int u){
mp[1] = u;
for(auto v : G[u]){
if(vis[v]) continue;
tot = 0;
dfs(v, u, a[v]);
for(int i = 1; i <= tot; ++i){
int x = k * 1ll * inv[dis[i].first] % mod * inv[a[u]] % mod;
if(mp[x]) check(dis[i].second, mp[x]);
}
for(int i = 1; i <= tot; ++i){
if(!mp[dis[i].first] || mp[dis[i].first] > dis[i].second) mp[dis[i].first] = dis[i].second;
rub[++top] = dis[i].first;
}
}
mp[1] = 0;
while(top) mp[rub[top--]] = 0;
}
// Centroid-decomposition driver. It marks u as removed, evaluates all paths
// through u with cal(), then recurses into the centroid of each remaining
// piece. vis[u] is cleared on the way out, so vis[] is all-zero again after
// the top-level call and can be reused for the next test case without
// clearing it.
// NOTE(review): siz[v] was last written by the getRt walk that selected u.
// If v was u's parent in that walk, siz[v] overstates the piece's size. That
// only changes which node getRt picks as rt, not the correctness of the
// result.
void dfz(int u){
    vis[u] = 1;
    cal(u);
    for(size_t i = 0; i < G[u].size(); ++i){
        int v = G[u][i];
        if(vis[v]) continue;
        rmn = inf;
        tn = siz[v];
        getRt(v, 0);
        dfz(rt);
    }
    vis[u] = 0;
}
// Reads test cases until input ends. For each one, it builds the tree, runs
// centroid decomposition starting from the whole tree's centroid, and prints
// the lexicographically smallest valid pair, or "No solution" if none exists.
int main(){
    init();
    // Stop on EOF *or* malformed input. Testing only `!= EOF` would loop
    // forever (or reuse a stale k) when scanf matches fewer than two fields.
    while(scanf("%d%d", &n, &k) == 2){
        for(int i = 1; i <= n; ++i) G[i].clear();
        for(int i = 1; i <= n; ++i) scanf("%d", &a[i]);
        for(int i = 1; i < n; ++i){
            int u, v;
            scanf("%d%d", &u, &v);
            G[u].pb(v), G[v].pb(u);
        }
        rmn = inf, tn = n, getRt(1, 0);
        p1 = p2 = inf, dfz(rt);
        if(p1 == inf) printf("No solution\n");
        else printf("%d %d\n", p1, p2);
    }
    return 0;
}