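// Problem: given a tree whose vertices carry weights, decide whether the tree contains a path (chain)
// whose product of vertex weights, taken modulo 10^6+3, equals exactly k. If solutions exist, output
// the lexicographically smallest pair of endpoints.
// Approach: divide and conquer on the tree (centroid decomposition). Find the centroid, check the
// paths that pass through the centroid, then recurse into each subtree to handle the paths that do
// not pass through the centroid.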
#pragma comment(linker,"/STACK:102400000,102400000")
#include <cstdio>
#include <cstring>
#include <iostream>
#include <cmath>
#include <map>
using namespace std;
const int mod = 1000003;  // prime modulus (10^6 + 3) for every weight product

// A tree vertex together with the head of its adjacency list.
struct Node {
    int fe;        // index in b[] of the first incident edge, -1 if none
    int v;         // vertex weight
    int num;       // subtree size, filled in by getnumber()
    bool visited;  // DFS guard; centroids stay marked once processed
};

// One directed half of an undirected tree edge, chained per source vertex.
struct Edge {
    int t;   // target vertex
    int ne;  // next edge of the same source vertex, -1 at the end
};

int n;     // number of vertices in the current test case
int p;     // next free slot in b[]
int ansa;  // best answer so far: smaller endpoint (0-based)
int ansb;  // best answer so far: larger endpoint (0-based)
int kk;    // target product k
int pp;    // stamp of the current centroid; used[x] == pp marks total[x] as live
int curn;  // number of entries currently stored in cur[]

Node a[100000];              // vertices, 0-based
Edge b[200000];              // both directions of every edge
int ni[mod];                 // ni[x]: inverse of x modulo mod
pair<int, int> cur[100000];  // (path product, endpoint) pairs of one subtree
int total[mod];              // total[x]: smallest endpoint whose path product is x
int used[mod] = {0};         // stamps validating total[] entries
// Offer the endpoint pair {x, y} as a candidate answer. The pair is ordered so
// that the smaller index comes first, and it replaces (ansa, ansb) when it is
// lexicographically smaller. ansa == -1 is treated as "no answer yet".
inline void updateans(int x, int y) {
    const int lo = (x < y) ? x : y;
    const int hi = (x < y) ? y : x;
    const bool noAnswerYet = (ansa == -1);
    const bool better = lo < ansa || (lo == ansa && hi < ansb);
    if (noAnswerYet || better) {
        ansa = lo;
        ansb = hi;
    }
}
// Running minimum: lower `best` to `candidate` when the candidate is smaller.
inline void update(int &best, int candidate) {
    best = (candidate < best) ? candidate : best;
}
// Insert the directed edge x -> y at the front of x's adjacency list.
// The edge occupies slot p of b[], and p is advanced past it.
inline void putedge(int x, int y) {
    const int slot = p++;
    b[slot].t = y;
    b[slot].ne = a[x].fe;
    a[x].fe = slot;
}
// Read the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it (a '-' sign is skipped, not honored).
// Returns 0 if the input ends before any digit is found, instead of
// spinning forever waiting for one.
inline int in() {
    // int, not char: EOF must remain distinguishable from every byte value.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    int ans = 0;
    while (c >= '0' && c <= '9') {
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans;
}
// Walk one subtree hanging off the current centroid, starting at vertex i.
// v is the product (mod `mod`) of the weights on the path from the subtree's
// root down to i; the centroid itself is excluded (it is visited, so the walk
// stops there). Every reached vertex is appended to cur[] as (v, i). If a path
// recorded earlier for this centroid (used[] stamped with pp) has product
// kk / v, that path and this one join into a path of product kk, and the
// endpoint pair is offered to updateans().
void calval(int i, int v) {
    if (a[i].visited) return;  // the vertex we came from, or a removed centroid
    a[i].visited = true;

    cur[curn] = make_pair(v, i);
    ++curn;

    // Partner product that would complete kk: kk * v^{-1} (mod mod).
    const int need = static_cast<int>(static_cast<long long>(kk) * ni[v] % mod);
    if (used[need] == pp) {
        updateans(total[need], i);
    }

    for (int e = a[i].fe; e != -1; e = b[e].ne) {
        const int child = b[e].t;
        const int extended = static_cast<int>(static_cast<long long>(v) * a[child].v % mod);
        calval(child, extended);
    }
    a[i].visited = false;
}
// Return the number of vertices reachable from i without passing through a
// visited vertex. As a side effect, a[x].num becomes the size of x's subtree
// (rooted at i) for every x reached. Visited flags are restored on the way out.
int getnumber(int i) {
    if (a[i].visited) return 0;
    a[i].visited = true;  // stops the recursion from walking back up
    int size = 1;         // i itself
    for (int e = a[i].fe; e != -1; e = b[e].ne) {
        size += getnumber(b[e].t);
    }
    a[i].num = size;
    a[i].visited = false;
    return size;
}
int center;  // vertex selected by the current centroid search
int nn;      // smallest "largest remaining piece" seen so far in that search

// Centroid search over the unvisited vertices reachable from i.
// Expects a[].num to hold subtree sizes computed by getnumber() from the
// search's starting vertex; h is the number of component vertices outside
// i's subtree. Removing i leaves pieces of size h and a[child].num for each
// child; the vertex whose largest piece is smallest is stored in center
// (on ties, the first vertex to finish in post-order wins).
void dfs(int i, int h) {
    if (a[i].visited) return;
    a[i].visited = true;

    int largest = h;
    for (int e = a[i].fe; e != -1; e = b[e].ne) {
        const int child = b[e].t;
        if (a[child].visited) continue;  // parent on the walk, or a removed centroid
        dfs(child, h + a[i].num - a[child].num);
        if (a[child].num > largest) largest = a[child].num;
    }

    if (largest < nn) {
        nn = largest;
        center = i;
    }
    a[i].visited = false;
}
int findcenter(int i) {
int n=getnumber(i);
//printf("Total %d nodes\n",n);
center=-1;
nn=200000;
dfs(i,0);
return center;
}
// One centroid-decomposition step for the component containing vertex i.
// Picks the component's centroid c and checks every path through c whose
// endpoints lie in different child subtrees of c (or at c itself). It then
// recurses into each remaining subtree. c stays marked visited afterwards,
// which removes it from all deeper levels.
// While used[x] == pp, total[x] is the smallest vertex seen so far whose path
// from c (c included) has weight product x.
void getans(int i) {
    const int c = findcenter(i);
    a[c].visited = true;
    ++pp;  // fresh stamp: table entries from earlier centroids become stale

    // The centroid on its own is a path with product a[c].v ending at c.
    const int cw = a[c].v;
    used[cw] = pp;
    total[cw] = c;

    for (int e = a[c].fe; e != -1; e = b[e].ne) {
        const int root = b[e].t;
        if (a[root].visited) continue;

        // First match this subtree's paths against those already recorded,
        // so both endpoints never come from the same subtree ...
        curn = 0;
        calval(root, a[root].v);

        // ... then record them, extended by the centroid, for later subtrees.
        for (int idx = 0; idx < curn; ++idx) {
            const int prod = static_cast<int>(static_cast<long long>(cur[idx].first) * cw % mod);
            const int vertex = cur[idx].second;
            if (used[prod] == pp) {
                update(total[prod], vertex);  // keep the smallest index
            } else {
                used[prod] = pp;
                total[prod] = vertex;
            }
        }
    }

    for (int e = a[c].fe; e != -1; e = b[e].ne) {
        if (!a[b[e].t].visited) getans(b[e].t);
    }
}
int xx, yy;  // Bezout coefficients produced by the most recent egcd() call

// Extended Euclidean algorithm, iterative form.
// Returns g = gcd(a, b) and leaves coefficients in xx, yy with
// a*xx + b*yy == g (so when g == 1, xx is an inverse of a modulo b).
int egcd(int a, int b) {
    // Invariant, with a0/b0 the original arguments:
    //   a == a0*x0 + b0*y0   and   b == a0*x1 + b0*y1
    long long x0 = 1, y0 = 0;
    long long x1 = 0, y1 = 1;
    while (b != 0) {
        const int q = a / b;
        const int r = a % b;
        a = b;
        b = r;
        const long long nx = x0 - q * x1;
        x0 = x1;
        x1 = nx;
        const long long ny = y0 - q * y1;
        y0 = y1;
        y1 = ny;
    }
    xx = static_cast<int>(x0);
    yy = static_cast<int>(y0);
    return a;
}
int main() {
int i,j,x,y;
pp=0;
for (i=1;i<mod;i++) {
egcd(i,mod);
ni[i]=(xx+mod)%mod;
}
while (scanf("%d%d",&n,&kk)!=EOF) {
for (i=0;i<n;i++) {
a[i].fe=-1;
a[i].visited=false;
}
p=0;
for (i=0;i<n;i++) a[i].v=in();
for (i=1;i<n;i++) {
x=in();
y=in();
putedge(x-1,y-1);
putedge(y-1,x-1);
}
ansa=n;
ansb=n;
getans(0);
if (ansa!=n) printf("%d %d\n",ansa+1,ansb+1);
else printf("No solution\n");
}
return 0;
}