题目链接
给
定
一
棵
n
个
点
的
树
,
给
定
m
个
人
(
m
≤
n
)
在
哪
个
点
上
的
信
息
,
每
个
点
可
以
有
任
意
个
人
;
然
后
给
q
个
询
问
,
每
次
问
u
到
v
上
的
路
径
有
的
点
上
编
号
最
小
的
k
(
k
≤
10
)
个
人
(
没
有
那
么
多
人
就
该
有
多
少
人
输
出
多
少
人
)
。
给定一棵n个点的树,给定m个人(m≤n)在哪个点上的信息,每个点可以有任意个人;然后给q个询问,每次问u到v上的路径有的点上编号最小的k(k≤10)个人(没有那么多人就该有多少人输出多少人)。
给定一棵n个点的树,给定m个人(m≤n)在哪个点上的信息,每个点可以有任意个人;然后给q个询问,每次问u到v上的路径有的点上编号最小的k(k≤10)个人(没有那么多人就该有多少人输出多少人)。
解法一:树上倍增法:
解题思路:观察那个数据,因为k的大小只有10,那么我们对于每个点就可以暴力求出前k小的数,用倍增的思想:
在
[
l
,
r
]
区
间
内
部
前
k
小
数
就
是
[
l
,
l
+
2
i
−
1
]
和
[
l
+
2
i
,
r
]
区
间
内
的
第
k
小
值
在[l,r]区间内部前k小数就是[l,l+2^i-1]和[l+2^i,r]区间内的第k小值
在[l,r]区间内部前k小数就是[l,l+2i−1]和[l+2i,r]区间内的第k小值
这里合并有个小技巧:因为每个小区间都是单调的所以我们求第k小的时候我们可以用归并的思想去合并区间的信息
代码
#include<bits/stdc++.h>
#define mid ((l + r) >> 1)
#define Lson rt << 1, l , mid
#define Rson rt << 1|1, mid + 1, r
#define ms(a,al) memset(a,al,sizeof(a))
#define log2(a) log(a)/log(2)
#define _for(i,a,b) for( int i = (a); i < (b); ++i)
#define _rep(i,a,b) for( int i = (a); i <= (b); ++i)
#define for_(i,a,b) for( int i = (a); i >= (b); -- i)
#define rep_(i,a,b) for( int i = (a); i > (b); -- i)
#define lowbit(x) ((-x) & x)
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define INF 0x3f3f3f3f
#define LLF 0x3f3f3f3f3f3f3f3f
#define hash Hash
#define next Next
#define pb push_back
#define f first
#define s second
#define y1 Y
using namespace std;
const int N = 4e5 + 10, mod = 1e9 + 7;
const int maxn = 4e5 + 10;
const long double eps = 1e-5;
const int EPS = 500 * 500;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
typedef pair<ll,ll> PLL;
typedef pair<double,double> PDD;
const int BUF=30000000;
char Buf[BUF],*buf=Buf;
template<typename T> void read(T &a)
{
for(a=0;*buf<48;buf++);
while(*buf>47) a=a*10+ *buf++ -48;
}
template<typename T, typename... Args> void read(T &first, Args& ... args)
{
read(first);
read(args...);
}
int n, m ,q;
vector<int> G[N];
vector<int> pre[N][30];
int fa[N][20], depth[N];
inline vector<int> Merge(vector<int> a, vector<int> b) {//合并前10小的数
vector<int> ans;
int poia = 0, poib = 0;
while(ans.size() < 10 && poia < a.size() && poib < b.size()) {
if(a[poia] < b[poib]) ans.push_back(a[poia ++]);
else ans.push_back(b[poib ++]);
}
while(ans.size() < 10 && poia < a.size()) ans.push_back(a[poia ++]);
while(ans.size() < 10 && poib < b.size()) ans.push_back(b[poib ++]);
return ans;
}
inline void dfs(int u, int f) {
fa[u][0] = f;
for(int i = 1; i <= 17; ++ i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for(int j = 1; j <= 17; ++ j) {
pre[u][j] = Merge(pre[u][j - 1],pre[fa[u][j - 1]][j - 1]);
}
depth[u] = depth[f] + 1;
for(auto it : G[u]) {
if(it == f)continue;
dfs(it,u);
}
}
inline int LCA(int u, int v) {
if(depth[u] > depth[v]) swap(v,u);
int delta = depth[v] - depth[u];
for(int i = 0; i <= 17; ++ i)
if(delta >> i & 1) v = fa[v][i];
if(u == v) return v;
for(int i = 17; i >= 0; -- i)
if(fa[u][i] != fa[v][i])
u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
inline vector<int> slove(int u, int v) {
vector<int> res;
int delta = depth[u] - depth[v];
for(int j = 0; j <= 17 && u; ++ j)
if(delta >> j & 1)
res = Merge(res,pre[u][j]), u = fa[u][j];
return res;
}
int main()
{
fread(Buf,1,BUF,stdin);
read(n,m,q);
for(int i = 0; i < n - 1; ++ i) {
int l, r;
read(l,r);
G[l].push_back(r);
G[r].push_back(l);
}
for(int i = 1; i <= m; ++ i) {
int x;
read(x);
pre[x][0].push_back(i);
}
for(int i = 1; i <= n; ++ i) sort(pre[i][0].begin(),pre[i][0].end());
dfs(1,0);
while(q --) {
int u, v, a;
read(u,v,a);
int lca = LCA(u,v);
vector<int> ans, res;
ans = slove(u,lca);
res = slove(v,lca);
res = Merge(ans,res);
res = Merge(res,pre[lca][0]); //因为LCA还没被计算过
//..........................................................
if(!res.size()) printf("0\n");
else cout << min((int)res.size(),a) << " ";
for(int i = 0; i < min(a,(int)res.size()); ++ i) {
printf("%d",res[i]);
if(i < min(a,(int)res.size()) - 1) printf(" ");
else printf("\n");
}
}
return 0;
}
解法2:主席树,因为树上信息具有可加性所以我们可以按照dfs序对这颗树建立一个主席树就是每个点到根节点的路径区间建立主席树,多次插入,那么要求u和v之间的路径信息的话那么就要 [ r o o t , u ] + [ r o o t , v ] − [ r o o t , l c a ] − [ r o o t , f a [ l c a ] ] [root,u]+[root,v]-[root,lca]-[root,fa[lca]] [root,u]+[root,v]−[root,lca]−[root,fa[lca]],这个可以有倍增去维护
代码:
#include<bits/stdc++.h>
#define mid ((l + r) >> 1)
#define Lson rt << 1, l , mid
#define Rson rt << 1|1, mid + 1, r
#define ms(a,al) memset(a,al,sizeof(a))
#define log2(a) log(a)/log(2)
#define _for(i,a,b) for( int i = (a); i < (b); ++i)
#define _rep(i,a,b) for( int i = (a); i <= (b); ++i)
#define for_(i,a,b) for( int i = (a); i >= (b); -- i)
#define rep_(i,a,b) for( int i = (a); i > (b); -- i)
#define lowbit(x) ((-x) & x)
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define INF 0x3f3f3f3f
#define LLF 0x3f3f3f3f3f3f3f3f
#define hash Hash
#define next Next
#define pb push_back
#define f first
#define s second
#define y1 Y
using namespace std;
const int N = 2e5 + 10, mod = 1e9 + 7;
const int maxn = 4e5 + 10;
const long double eps = 1e-5;
const int EPS = 500 * 500;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
typedef pair<ll,ll> PLL;
typedef pair<double,double> PDD;
template<typename T> void read(T &x)
{
x = 0;char ch = getchar();ll f = 1;
while(!isdigit(ch)){if(ch == '-')f*=-1;ch=getchar();}
while(isdigit(ch)){x = x*10+ch-48;ch=getchar();}x*=f;
}
template<typename T, typename... Args> void read(T &first, Args& ... args)
{
read(first);
read(args...);
}
int n, m, q, maxv;
//.......................建图
struct Node {
/* data */
int to, nxt;
}edge[N];
int head[N], cnt;
vector<int> node[N];
inline void add(int from, int to) {
edge[cnt] = (Node){to,head[from]};
head[from] = cnt ++;
}
//.....................主席树
struct Tree {
int lson, rson, cnt;
}tr[N * 40];
int root[N];
int idx = 0;
inline int build(int l, int r) {
int now = ++ idx;
if(l == r) return now;
tr[now].lson = build(l,mid), tr[now].rson = build(mid+1,r);
return now;
}
inline int insert(int pre, int l, int r, int k) {
int poi = ++ idx;
tr[poi] = tr[pre];
if(l == r) {
tr[poi].cnt ++;
return poi;
}
if(k <= mid) tr[poi].lson = insert(tr[pre].lson,l,mid,k);
else tr[poi].rson = insert(tr[pre].rson,mid+1,r,k);
tr[poi].cnt = tr[tr[poi].lson].cnt + tr[tr[poi].rson].cnt;
return poi;
}
inline int query(int lpoi, int rpoi, int lcapoi, int falcapoi, int l, int r, int k) {
if(l == r)
return l;
int Eps = tr[tr[lpoi].lson].cnt + tr[tr[rpoi].lson].cnt - tr[tr[lcapoi].lson].cnt - tr[tr[falcapoi].lson].cnt;
if(k <= Eps) return query(tr[lpoi].lson,tr[rpoi].lson,tr[lcapoi].lson,tr[falcapoi].lson,l,mid,k);
else return query(tr[lpoi].rson,tr[rpoi].rson,tr[lcapoi].rson,tr[falcapoi].rson,mid+1,r,k-Eps);
}
//...................倍增求LCA
int fa[N][20], depth[N];
inline void dfs(int u, int f) {
if(node[u].size())
{
root[u] = insert(root[f],1,maxv+1,node[u][0]);
for(int i = 1; i < node[u].size(); ++ i)
root[u] = insert(root[u],1,maxv+1,node[u][i]);
}
else root[u] = root[f];
fa[u][0] = f; depth[u] = depth[f] + 1;
for(int i = 1; i <= 17; ++ i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for(int i = head[u]; ~i; i = edge[i].nxt) {
int v = edge[i].to;
if(v == f) continue;
dfs(v,u);
}
}
//.....................................
inline int LCA(int u, int v) {
if(depth[u] > depth[v]) swap(v,u);
int delta = depth[v] - depth[u];
for(int i = 0; i <= 17; ++ i)
if(delta >> i & 1) v = fa[v][i];
if(u == v) return v;
for(int i = 17; i >= 0; -- i)
if(fa[u][i] != fa[v][i])
u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
//.......................................
inline void debug(int rt, int l, int r) {
cout << tr[tr[rt].lson].cnt << " ";
cout << tr[tr[rt].rson].cnt << "\n";
cout << l << " " << r << endl;
cout << "-------------" << endl;
if(l == r)
{
if(tr[rt].cnt)
cout << l << " debug" << endl;
return;
}
debug(tr[rt].lson,l,mid);
debug(tr[rt].rson,mid+1,r);
}
//.......................................
int ans[N], poi;
//.......................................
int main() {
ms(head,-1);
read(n,m,q);
maxv = max(n,m);
for(int i = 0; i < n - 1; ++ i) {
int l, r;
read(l,r);
add(l,r), add(r,l);
}
for(int i = 1; i <= m; ++ i)
{
int x;
read(x);
node[x].push_back(i);
}
root[0] = build(1,maxv+1);
dfs(1,0);
while(q --) {
int u,v,a;
read(u,v,a);
int lca = LCA(u,v);
for(int i = 1; i <= min(a,m); ++ i)
{
int tmp = query(root[u],root[v],root[lca],root[fa[lca][0]],1,maxv+1,i);
if(tmp > maxv) continue;
ans[poi ++] = tmp;
}
printf("%d ",poi);
if(poi == 0) puts("");
for(int i = 0; i < poi; ++ i)
printf("%d%c",ans[i]," \n"[i == poi - 1]);
poi = 0;
}
return 0;
}