题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5977
题目要求在树上找到一条链使得这条链上的点的乘积模mod等于k,求链首尾字典序最小的一条
看到题目就能知道这是一道点分治的题目:将树按照重心分治之后,需要统计以重心为根的子树中,经过树根且乘积模mod等于k的链里字典序最小的一条,这里的统计必须在O(n)的时间复杂度内完成才能通过
代码:
#pragma comment(linker,"/STACK:102400000,102400000")
#include <bits/stdc++.h>
using namespace std;
// maxn: capacity of every per-vertex array in this file (100000 vertices plus slack).
// mod:  modulus applied to all path products; also the length of the inverse
//       and hash tables. 1000003 is prime, which the inverse recurrence in
//       Init_Inv depends on.
// INF:  sentinel meaning "nothing found yet" for the answer pair and for the
//       best centroid balance.
const int maxn = 100000 + 50,mod = 1000000 + 3,INF = 0x3f3f3f3f;
// 64-bit integer used wherever two residues are multiplied before reduction.
typedef long long LL;
// One directed edge of the adjacency list.
//   v   - vertex the edge points to
//   pre - index in `edges` of the previously inserted edge leaving the same
//         source vertex, or -1 when there is none
// The pool holds maxn * 2 entries because main() inserts every undirected tree
// edge in both directions.
struct Edge{
int v,pre;
}edges[maxn * 2];
// head[u] is the index of the most recently inserted edge leaving u (-1 if u
// has no edges yet); tot is the number of edge slots already handed out.
int head[maxn],tot;

// Empties the adjacency list: no edge slots in use, every vertex's list empty.
void Edge_Init(){
    tot = 0;
    memset(head,-1,sizeof head);
}
void Insert_Edge(int u,int v){
edges[tot].v = v;
edges[tot].pre = head[u];
head[u] = tot++;
}
// Number of vertices in the current test case (vertices are read as 1..n).
int n;
// Target value that a chain's product must equal modulo `mod`.
LL k;
// node_value[i] is the value attached to vertex i, as read from the input.
LL node_value[maxn];
// inv[x] is the modular inverse of x modulo `mod` for 1 <= x < mod, filled by
// Init_Inv(). inv[0] is never written and stays 0.
LL inv[mod];
// _hash[x]: smallest vertex index at which a path with product x ends, among
//           the child subtrees already merged in the current centroid round.
// flag[x]:  the `lable` stamp of the round that last wrote _hash[x]; the entry
//           is only valid while flag[x] == lable.
int _hash[mod],flag[mod];
void Init_Inv(){
inv[1] = 1;
for(int i = 2;i < mod;++i) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
// size[v]:     number of not-yet-removed vertices in v's subtree, as computed
//              by the latest getInfo() pass.
// center:      best centroid candidate found so far by getCenter().
// center_size: size of the largest piece left when `center` is removed;
//              reset to INF before each search.
// NOTE(review): this file has `using namespace std;`, so an unqualified use of
// `size` can be ambiguous with std::size when built as C++17 or later —
// confirm the target language standard.
int size[maxn],center,center_size;
// vis[v] != 0 once v has been chosen as a centroid and removed from the tree.
int vis[maxn];
void getInfo(int rt,int fa){
size[rt] = 1;int v;
for(int i = head[rt];~i;i = edges[i].pre){
v = edges[i].v;
if(v == fa || vis[v]) continue;
getInfo(v,rt);
size[rt] += size[v];
}
}
void getCenter(int rt,int fa,const int tot_size){
int tmp_size = tot_size - size[rt],v;
for(int i = head[rt];~i;i = edges[i].pre){
v = edges[i].v;
if(v == fa || vis[v]) continue;
getCenter(v,rt,tot_size);
tmp_size = max(tmp_size,size[v]);
}
if(tmp_size < center_size) center = rt,center_size = tmp_size;
}
// Best chain found so far, stored as (smaller endpoint, larger endpoint).
int ans_a,ans_b;

// Offers the chain with endpoints a and b as a candidate answer. The pair is
// put in ascending order and replaces (ans_a, ans_b) only when it is
// lexicographically smaller.
void Update_Ans(int a,int b){
    const int lo = a < b ? a : b;
    const int hi = a < b ? b : a;
    const bool better = lo < ans_a || (lo == ans_a && hi < ans_b);
    if(better){
        ans_a = lo;
        ans_b = hi;
    }
}
LL path_value[maxn];
int path_value_node[maxn],path_cnt;
void solve_dfs(int rt,int fa,LL pre_value){
path_value[path_cnt] = pre_value * node_value[rt] % mod,path_value_node[path_cnt++] = rt;
int v;LL tmp = path_value[path_cnt - 1];
for(int i = head[rt];~i;i = edges[i].pre){
v = edges[i].v;
if(v == fa || vis[v]) continue;
solve_dfs(v,rt,tmp);
}
}
int lable;
void solve(int rt){
//得到重心
center_size = INF;
getInfo(rt,0),getCenter(rt,0,size[rt]);
rt = center;
//求解子树
vis[rt] = 1;
int v,i,j;
LL next_value;
for(i = head[rt];~i;i = edges[i].pre){
v = edges[i].v;
if(vis[v]) continue;
path_cnt = 0;
solve_dfs(v,rt,1);
for(j = 0;j < path_cnt;++j){
if(path_value[j] * node_value[rt] % mod == k) Update_Ans(rt,path_value_node[j]);
next_value = k * inv[path_value[j] * node_value[rt] % mod] % mod;
// printf("v %d end %d v %lld next_v %lld\n",v,path_value_node[j],path_value[j],next_value);
// printf("%lld %lld %lld %lld \n",k,path_value[j],node_value[rt],next_value * path_value[j] * node_value[rt] % mod);
if(flag[next_value] == lable) Update_Ans(_hash[next_value],path_value_node[j]);
}
for(int j = 0;j < path_cnt;++j){
if(flag[path_value[j]] != lable || (flag[path_value[j]] == lable && path_value_node[j] < _hash[path_value[j]]))
flag[path_value[j]] = lable,_hash[path_value[j]] = path_value_node[j];
}
}
lable++;
for(i = head[rt];~i;i = edges[i].pre){
v = edges[i].v;
if(vis[v]) continue;
solve(v);
}
}
// Reads test cases until the input ends. Each case gives n and k, then the n
// vertex values, then n - 1 tree edges. Prints the lexicographically smallest
// endpoint pair of a chain whose product is k modulo `mod`, or "No solution".
int main(){
    // freopen("read.txt","r",stdin);   // local debugging: read from a file
    Init_Inv();
    // Require both fields to be read. Comparing against EOF alone would accept
    // a matching failure (return value 0 or 1) and loop forever on bad input.
    while(scanf("%d %lld",&n,&k) == 2){
        Edge_Init();
        for(int i = 1;i <= n;++i) scanf("%lld",&node_value[i]);
        for(int i = 1;i < n;++i){
            int u,v;
            scanf("%d %d",&u,&v);
            // The tree is undirected, so store both directions.
            Insert_Edge(u,v),Insert_Edge(v,u);
        }
        // Reset per-test state: removed-vertex marks, table stamps, answer.
        memset(vis,0,sizeof vis);
        memset(flag,0,sizeof flag);
        ans_a = ans_b = INF;lable = 1;
        solve(1);
        if(ans_a != INF) printf("%d %d\n",ans_a,ans_b);
        else printf("No solution\n");
    }
    return 0;
}