D Tree
https://vjudge.net/problem/48318/origin
There is a skyscraping tree standing on the playground of Nanjing University of Science and Technology. On each branch of the tree is an integer (the tree can be treated as a connected graph with N vertices, where each branch is a vertex). Today the students under the tree are considering a problem: can we find a chain on the tree such that the product of all integers on the chain (mod $10^6 + 3$) equals $K$?
Can you help them in solving this problem?
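题意:给你一棵树,每个节点有一个权值val,问是否存在一个点对(a,b)使得路径上所有节点的权值的乘积(模 $10^6+3$)等于k,如果存在,输出字典序最小的(a,b)。(Given a tree whose vertices carry weights val, decide whether some pair (a,b) has a path whose weight product mod $10^6+3$ equals k; if so, print the lexicographically smallest such pair.)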
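思路:点分治。在处理某个重心 u 时,先算出它的一棵子树上的所有节点到根节点 u 的路径乘积 dis[i](包含 val[u]),然后寻找一个 x 使得

$$dis[i] \cdot x \equiv k \cdot val[u] \pmod{10^6+3}$$

其中 dis[i] 为某一个乘积,需要在另一棵子树(或根节点自身)上找到一条乘积为 x 的路径,使得两段合起来的乘积为 $k \cdot val[u]$(因为根节点被乘了两遍),即

$$x \equiv k \cdot val[u] \cdot inv[dis[i]] \pmod{10^6+3}$$

一开始把所有逆元预处理出来,然后查找是否在其他子树出现过 x 即可,出现过就对点对进行更新。

(Idea: centroid decomposition. For centroid u, each path product dis[i] from u into one subtree, including val[u], is matched against the products already seen from earlier subtrees. The value needed is $x = k \cdot val[u] \cdot dis[i]^{-1}$. The extra val[u] factor compensates for the root being counted in both halves. All modular inverses are precomputed once.)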
#include<bits/stdc++.h>
#define MAXN 100010
#define INF 0x3f3f3f3f
using namespace std;
const int MOD = 1e6+3;
int inv[MOD+5];
/*
 * Fills inv[x] with the multiplicative inverse of x modulo MOD for every
 * x in [1, MOD). inv[0] is left untouched.
 *
 * Uses the linear recurrence obtained from MOD = q*x + r:
 *   inv[x] = -(MOD / x) * inv[MOD % x]  (mod MOD),
 * written with (MOD - q) so the value stays non-negative. This requires MOD
 * to be prime, which 10^6 + 3 is.
 */
inline void getInv()
{
    inv[1] = 1;
    for(int x = 2;x < MOD;++x)
    {
        const int q = MOD / x;
        const int r = MOD % x;   // r < x, so inv[r] is already known
        inv[x] = (long long)(MOD - q) * inv[r] % MOD;
    }
}
int head[MAXN],tot;
// Forward-star adjacency list: edg[e].v is the target of edge e and
// edg[e].nxt the next edge out of the same source (-1 terminates the list).
struct edge
{
    int v,nxt;
}edg[MAXN << 1];
/*
 * Adds the directed edge u -> v by pushing it onto the front of u's list.
 * Callers add both directions for an undirected tree edge. Assumes head[u]
 * was reset to -1 beforehand (init() does this for vertices 0..n).
 */
inline void addedg(int u,int v)
{
    edge &slot = edg[tot];
    slot.v = v;
    slot.nxt = head[u];
    head[u] = tot;
    ++tot;
}
int n,val[MAXN],k;
int mx,root,Size,sz[MAXN];
bool vis[MAXN];
inline void getroot(int u,int f)
{
sz[u] =1;
int v,mson = 0;
for(int i = head[u];i != -1;i = edg[i].nxt)
{
v = edg[i].v;
if(v == f || vis[v])
continue;
getroot(v,u);
sz[u] += sz[v];
mson = max(mson,sz[v]);
}
mson = max(mson,Size-sz[u]);
if(mson < mx)
mx = mson,root = u;
}
int ans1,ans2;
int dis[MAXN],id[MAXN],cnt;
int visdis[MOD+5],visid[MOD+5],cc,cdis[MAXN];
inline void getdis(int u,int f,int d)
{
int w = 1ll * d * val[u]%MOD;
dis[++cnt] = w,id[cnt] = u;
int v;
for(int i = head[u];i != -1;i = edg[i].nxt)
{
v = edg[i].v;
if(v == f || vis[v])
continue;
getdis(v,u,w);
}
}
inline void solve(int u,int ssize)
{
cc = 0,vis[u] = 1;
visdis[val[u]%MOD] = 1;
visid[val[u]%MOD] = u;
cdis[++cc] = val[u]%MOD;
int v;
for(int i = head[u];i != -1;i = edg[i].nxt)
{
v = edg[i].v;
if(vis[v]) continue;
cnt = 0;
getdis(v,v,val[u]%MOD);
for(int i = 1;i <= cnt;++i)
{
int tmp = 1ll*inv[dis[i]]*k%MOD*val[u]%MOD;
if(visdis[tmp])
{
int uu = visid[tmp],vv = id[i];
if(uu > vv)
swap(vv,uu);
if(uu < ans1 || (uu == ans1 && vv < ans2))
ans1 = uu,ans2 = vv;
}
}
for(int i = 1;i <= cnt;++i)
{
if(!visdis[dis[i]])
visdis[dis[i]] = 1,visid[dis[i]] = id[i],cdis[++cc] = dis[i];
else if(id[i] < visid[dis[i]])
visid[dis[i]] = id[i];
}
}
for(int i = 1;i <= cc;++i)
visdis[cdis[i]] = 0;
for(int i = head[u];i != -1;i = edg[i].nxt)
{
v = edg[i].v;
if(vis[v]) continue;
Size = sz[v] < sz[u]?sz[v]:ssize-sz[u];
mx = INF;
getroot(v,v);
solve(root,Size);
}
}
/*
 * Resets per-test-case state before reading a new tree of n vertices:
 * clears the edge counter, the centroid search bound, and the best answer.
 * Also marks vertices 0..n as not removed and empties their adjacency lists.
 */
inline void init()
{
    tot = 0;
    mx = INF;
    Size = n;
    ans2 = INF;
    ans1 = ans2;
    std::fill(vis,vis + n + 1,false);
    std::fill(head,head + n + 1,-1);
}
/*
 * Reads test cases until EOF. Each case gives n and k, then n vertex
 * weights, then n-1 undirected edges. Prints the lexicographically smallest
 * pair "a b" whose path product mod MOD equals k, or "No solution".
 */
int main()
{
    getInv();
    while(scanf("%d%d",&n,&k) != EOF)
    {
        init();
        for(int node = 1;node <= n;++node)
            scanf("%d",val + node);
        for(int remaining = n - 1;remaining > 0;--remaining)
        {
            int a,b;
            scanf("%d%d",&a,&b);
            addedg(a,b);
            addedg(b,a);
        }
        // Start the decomposition from a centroid of the whole tree.
        getroot(1,1);
        solve(root,Size);
        if(ans1 != INF)
            printf("%d %d\n",ans1,ans2);
        else
            printf("No solution\n");
    }
    return 0;
}