题目大意:
给出一棵树, 顶点数 n <= 100000
每个点有一个权值
给出K
询问这棵树中两点路径上的点的权值乘积对1e6 + 3取模之后等于K的路径的两个端点形成的点对中字典序最小的
大致思路:
点分治第二题
用f[i]表示从当前的子树根开始到某个点的路径乘积为i的点的最小标号
于是对于每一次分治, 去掉重心x之后, 依次处理每一棵子树来使得f的来源不一样
注意要预处理逆元
代码如下:
Result : Accepted Memory : 24420 KB Time : 2823 ms
/*
* Author: Gatevin
* Created Time: 2015/10/14 9:59:26
* File Name: Sakura_Chiyo.cpp
*/
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;
#define maxn 100010
struct Edge
{
int u, v, nex;
Edge(){}
Edge(int _u, int _v, int _nex)
{
u = _u, v = _v, nex = _nex;
}
};
Edge edge[maxn << 1];
int head[maxn];
int E;
int n;
void add_Edge(int u, int v)
{
edge[++E] = Edge(u, v, head[u]);
head[u] = E;
}
int del[maxn];
int root;
int mx[maxn];
int size[maxn];
int mi;
int N;
lint K;
lint w[maxn];
const lint mod = 1e6 + 3;
pair<int, int> ans;
lint rev[1000010];
void getRev()
{
rev[1] = 1;
for(lint i = 2; i < mod; i++)
rev[i] = (mod - mod / i) * rev[mod % i] % mod;
}
void dfs_size(int now, int father)
{
size[now] = 1;
mx[now] = 1;
for(int i = head[now]; i + 1; i = edge[i].nex)
{
int v = edge[i].v;
if(v != father && !del[v])
{
dfs_size(v, now);
size[now] += size[v];
if(size[v] > mx[now]) mx[now] = size[v];
}
}
}
void dfs_root(int r, int now, int father)
{
if(size[r] - size[now] > mx[now]) mx[now] = size[r] - size[now];
if(mx[now] < mi) mi = mx[now], root = now;
for(int i = head[now]; i + 1; i = edge[i].nex)
{
int v = edge[i].v;
if(v != father && !del[v]) dfs_root(r, v, now);
}
}
lint f[1000010];//f[i]表示距离为i的字典序最小的点的标号+pre
const pair<int, int> inf = make_pair(1e9, 1e9);
lint pre;
const lint bit = 1e5;
void get(int now, int father, lint dis, int flag)
{
dis = dis*w[now] % mod;
if(flag == 1)
{
//dis*x % mod = K -> x = K*dis^(mod - 2) % mod
if(f[K*rev[dis] % mod] > pre)
{
int other = (int)(f[K*rev[dis] % mod] - pre);
pair<int, int> p = other < now ? make_pair(other, now) : make_pair(now, other);
if(ans > p) ans = p;
}
}
else
{
if(f[dis] <= pre) f[dis] = pre + now;
else f[dis] = min(f[dis], pre + now);
}
for(int i = head[now]; i + 1; i = edge[i].nex)
{
int v = edge[i].v;
if(!del[v] && v != father)
get(v, now, dis, flag);
}
}
void dfs(int now)
{
mi = N;
dfs_size(now, 0);
dfs_root(now, now, 0);
del[root] = 1;
pre += bit;
f[w[root]] = pre + root;//pre是因为每次清空f数组是不现实的, 于是用pre/bit表示第几次
for(int i = head[root]; i + 1; i = edge[i].nex)
{
int v = edge[i].v;
if(!del[v])
{
get(v, root, 1, 1);
get(v, root, w[root], 0);
}
}
for(int i = head[root]; i + 1; i = edge[i].nex)
{
int v = edge[i].v;
if(!del[v]) dfs(v);
}
}
void solve()
{
ans = inf;
//pre = 0
//memset(f, 0, sizeof(f));利用pre不清0, f就可以不清0了
dfs(1);
if(ans == inf) puts("No solution");
else printf("%d %d\n", ans.first, ans.second);
}
int main()
{
getRev();//预处理逆元
pre = 0;
while(scanf("%d %I64d", &N, &K) != EOF)
{
E = 0;
memset(head, -1, sizeof(head));
memset(del, 0, sizeof(del));
for(int i = 1; i <= N; i++)
scanf("%I64d", &w[i]);
int u, v;
for(int i = 1; i < N; i++)
{
scanf("%d %d", &u, &v);
add_Edge(u, v);
add_Edge(v, u);
}
solve();
}
return 0;
}