题意:
一棵树,每个点有一个权值,求一对点的路径上的点的乘积取模后恰好为k,输出标号最小的点对
思路:
点分治
首先我们要求的是 (a*b)%mod=k 如果我们知道a,那么b就是(k*inv[a])%mod
而我们点分治的时候,先求出重心,然后处理出一颗子树上的点到root的乘积,然后乘root取模找逆元,看看map里面有没有,再把这些数存到map里。这样处理完一颗子树,然后把这棵子树的东西放进map,就能防止出现不合法的情况(路径没经过重心)
先求出所有逆元,然后重复上面步骤即可
错误及反思:
忘了换点对的大小结果wa了几发。。。
代码:
#include<bits/stdc++.h>
#define fi first
#define se second
using namespace std;
const long long mod = 1e6 + 3;
const int N = 100100;
struct EDGE{
int to,next;
}e[N*2];
int first[N],n,tot,si[N],maxn[N],k;
bool did[N];
long long ny[1001000],val[N];
map<long long,int> m;
pair<int,int> ans;
vector<pair<long long,int> > v;
void addedge(int x,int y){
e[tot].to=y;
e[tot].next=first[x];
first[x]=tot++;
e[tot].to=x;
e[tot].next=first[y];
first[y]=tot++;
}
void dfs_size(int now,int fa){
si[now]=1;
maxn[now]=0;
for(int i=first[now];i!=-1;i=e[i].next)
if(e[i].to!=fa&&!did[e[i].to]){
dfs_size(e[i].to,now);
si[now]+=si[e[i].to];
maxn[now]=max(maxn[now],si[e[i].to]);
}
}
void dfs_root(int now,int fa,int& root,int& nu,int t){
int MA=max(maxn[now],si[t]-si[now]);
if(MA<nu){
nu=MA;
root=now;
}
for(int i=first[now];i!=-1;i=e[i].next)
if(e[i].to!=fa&&!did[e[i].to])
dfs_root(e[i].to,now,root,nu,t);
}
void dfs2(int now,int fa,long long tlen){
tlen%=mod;
for(int i=first[now];i!=-1;i=e[i].next)
if(e[i].to!=fa&&!did[e[i].to])
dfs2(e[i].to,now,tlen*val[e[i].to]);
v.push_back({tlen,now});
}
void solve(int now){
int root,nu=1e9;
m.clear();
dfs_size(now,-1);
dfs_root(now,-1,root,nu,now);
did[root]=true;
m[1]=root;
for(int i=first[root];i!=-1;i=e[i].next){
if(!did[e[i].to]){
v.clear();
dfs2(e[i].to,root,val[e[i].to]);
for(int j=0;j<v.size();j++){
long long ty=v[j].fi*val[root]%mod;
int tx=m[ny[ty]*k%mod];
if(tx){
int minn=min(v[j].se,tx);
int maxx=max(v[j].se,tx);
if(ans.fi==0){ ans.fi=minn,ans.se=maxx;}
else{
if(ans.fi>minn){
ans.fi=minn;
ans.se=maxx;
}
else if(ans.fi==minn)
ans.se=min(ans.se,maxx);
}
}
}
for(int j=0;j<v.size();j++){
if(m[v[j].fi]==0)
m[v[j].fi]=v[j].se;
else m[v[j].fi]=min(v[j].se,m[v[j].fi]);
}
}
}
for(int i=first[root];i!=-1;i=e[i].next)
if(!did[e[i].to])
solve(e[i].to);
}
void init(){
memset(first,-1,sizeof(first));
tot=0;
ans.fi=0;
memset(did,false,sizeof(did));
}
int main(){
ny[1] = 1;
for(int i = 2; i < 1000100; i ++){
ny[i] = (mod - mod / i) * 1ll * ny[mod % i] % mod;
}
while(scanf("%d%d",&n,&k)!=EOF){
init();
for(int i=1;i<=n;i++) scanf("%lld",&val[i]);
for(int i=0,u,v;i<n-1;i++){
scanf("%d%d",&u,&v);
addedge(u,v);
}
solve(1);
if(ans.fi) printf("%d %d\n",ans.fi,ans.se);
else printf("No solution\n");
}
}