【题目描述】
给定一棵 n 个点的树,每个点有权值 Vi,问是否存在一条路径使得路径上所有点的权值乘积mod(10^6 + 3)为 K,输出路径的首尾标号,若有多解,输出字典序最小的解。【Sample Input】
(多组数据。每组第一行两个数 n,K;第二行 n 个数,表示vi,接下来 n-1 行每行两个数x,y表示一条边)
5 60
2 5 2 3 3
1 2
1 3
2 4
2 5
5 2
2 5 2 3 3
1 2
1 3
2 4
2 5【Sample Output】
(输出两个数 a,b( a< b ),无解输出”No solution”)
3 4
No solution【数据范围】
对于100%的数据,有1≤n≤10^5,0≤K≤10^6+2,1≤vi ≤10^6+2
【题解】树分治+乘法逆元
这是一道裸的树分治。先预处理逆元,在树分治以x为根的时候,保存已处理过的子树上的路径乘积,对于当前路径乘积为s,设s的逆元为ine[s],只需判断保存的信息中是否有K*ine[s] {即k/x,因取模无法做除法}。注意保存字典序最小。
#include <cstdio>
#include <iostream>
#include <cstring>
#define mo 1000003
#define N 100005
struct edge{ int to,nxt;}e[N<<1];
int n,k,tot,mx,rt,cnt,m,st,en,ine[N*10],fi[N],s[N],
a[N][3],dis[N],w[N],g[N*10],dep[N];
bool flag,bo[N],f[N*10];
void add(int u,int v){ e[++cnt].to=v;e[cnt].nxt=fi[u];fi[u]=cnt;}
int qsm(int x,int y)
{
if (y==1) return x;
int p=qsm(x,y>>1);
p=(int)(1ll*p*p%mo);
if (y&1) return (int)(1ll*p*x%mo);
else return p;
}
void findrt(int x,int fa)
{
s[x]=1;int mxx=0;
for (int i=fi[x];i;i=e[i].nxt)
if (!bo[e[i].to] && e[i].to!=fa)
{
findrt(e[i].to,x);
s[x]+=s[e[i].to];
mxx=std::max(mxx,s[e[i].to]);
}
mxx=std::max(mxx,tot-s[x]);
if (mxx<mx) rt=x,mx=mxx;
}
void dfs1(int x,int fa)
{
a[++m][0]=x;a[m][1]=dis[x];
a[m][2]=dep[x];
for (int i=fi[x];i;i=e[i].nxt)
if (!bo[e[i].to] && e[i].to!=fa)
{
dis[e[i].to]=(int)(1ll*dis[x]*w[e[i].to]%mo);
dep[e[i].to]=dep[x]+1;dfs1(e[i].to,x);
}
}
bool work(int x)
{
memset(f,false,sizeof(f));
f[w[x]%mo]=true;g[w[x]%mo]=x;
for (int i=fi[x];i;i=e[i].nxt)
if (!bo[e[i].to])
{
m=0;dis[e[i].to]=w[e[i].to];
dep[e[i].to]=1;dfs1(e[i].to,x);
for (int j=1;j<=m;++j)
{
int t=(int)(1ll*k*ine[a[j][1]]%mo);
if (f[t])
{
int st1=g[t],en1=a[j][0];
if (st1>en1) std::swap(st1,en1);
if (!flag || st1<st || st1==st && en1<en)
{ st=st1;en=en1;}
flag=true;
}
}
for (int j=1;j<=m;++j)
{
int p=(int)(1ll*a[j][1]*w[x]%mo);
if (!f[p]) g[p]=a[j][0];
else g[p]=std::min(g[p],a[j][0]);
f[p]=true;
}
}
return false;
}
void dfs(int x,int y)
{
work(x);
for (int i=fi[x];i;i=e[i].nxt)
if (!bo[e[i].to])
{
if (s[e[i].to]>s[x]) s[e[i].to]=y-s[x];
mx=tot=s[e[i].to];
findrt(e[i].to,x);bo[rt]=true;
dfs(rt,s[e[i].to]);
}
}
int main()
{
scanf("%d%d\n",&n,&k);
for (int i=1;i<=n;++i) scanf("%d",&w[i]);
for (int i=1;i<n;++i)
{
int u,v;scanf("%d%d\n",&u,&v);
add(u,v);add(v,u);
}
for (int i=1;i<=mo;++i)
ine[i]=qsm(i,mo-2);
mx=tot=n;findrt(1,0);bo[rt]=true;
flag=false;dfs(rt,n);
if (flag)
{
if (st>en) std::swap(st,en);
printf("%d %d\n",st,en);
}
else printf("No solution");
return 0;
}