【湖南集训 】谈笑风生题解 线段树合并

题目:https://www.luogu.org/problemnew/show/P3899
蒟蒻第一道1A的线段树合并,写个博客纪念一下。
读完题,发现题目要求我们对于给定的一个a,求出有多少组b,c,满足a和b是c的祖先,且b和a距离不超过k。
分析一下,a和b都是c的祖先,那么a和b一定出现在一条链上,换言之,a一定是b的祖先节点,或在b的子树中。那么,我们不妨对于这两种情况分开来考虑。首先,对于b是a祖先的情况,显然,c在a的子树中。那么,这种情况对答案的贡献就是min(k , dep[a] - 1) * (sz[a] - 1) 。(这个不需要我解释了吧)(对了,根节点初始深度赋为1,不然dep[a]就不用减一)
接下来,我们来考虑另一种情况,即对于a是b的祖先的情况。对于这种情况,我们可以发现,
∀b , (b在a的子树中,dep[b] - dep[a] <= k),它对答案的贡献为sz[b] - 1。那么,我们现在问题只剩下对于一个点,如何维护它子树中所有点的sz了。显然,对于每一个点,如果暴力更新它的所有子树中的节点,时间复杂度和空间复杂度都是不可接受的。但是,我们考虑到,每个点都跟它的子节点有一部分维护的sz相同,那么,我们会考虑线段树合并。(不会的同志这里看:https://blog.csdn.net/weixin_43790474
那么,这道题就相当愉快地解决了。好吧,还剩下一些细节没说清。(虽然我觉得这个自己想都可以)我们清楚,最终查询的时候是在一个深度区间中查询,那么,我们毫无疑问,以每一个点建一颗线段树,它的下标为深度,维护的是当前深度下的sz的和。(由于是由子节点一个一个合并而来,所以可以保证在当前的线段树中,只保存有它的子节点的sz)还不能理解的话看看代码吧:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 400010;
typedef long long ll;
inline char get_char()
{
	static char buf[100000] , *p1 = buf , *p2 = buf;
	if (p1 == p2)
	{
		p2 = (p1 = buf) + fread(buf , 1 , 100000 , stdin);
		if (p1 == p2)
		{
			return EOF;
		}
	}
	return *p1++;
}
inline int read()
{
	int res;
	char ch;
	int f = 1;
	while (!isdigit(ch = get_char()))
	{
		if (ch == '-')
		{
			f = -1;
		}
	}
	res = ch - '0';
	while (isdigit(ch = get_char()))
	{
		res = res * 10 + ch - '0';
	}
	return res * f;
}
int n , m , cnt , maxdep , tot;
struct edge
{
	int v , next;
}E[2 * maxn];
int len , head[maxn];
void add(int u , int v)
{
	E[len].v = v , E[len].next = head[u];
	head[u] = len++;
}
int dep[maxn] , tid[maxn] , ls[20 * maxn] , rs[20 * maxn] , rt[maxn];
ll sz[maxn] , sum[20 * maxn];
void update(int &id , int l , int r , int pos , ll val , int last)
{
	id = ++tot;
	sum[id] = sum[last] + val;
	ls[id] = ls[last];
	rs[id] = rs[last];
	if (l == r)
	{
		return;
	}
	int mid = (l + r) >> 1;
	if (pos <= mid)
	{
		update(ls[id] , l , mid  , pos , val , ls[last]);
	}
	else
	{
		update(rs[id] , mid + 1 , r  , pos , val , rs[last]);
	}
}
ll query(int s , int t , int l , int r , int x , int y)
{
	if (!s && !t || l > r)
	{
		return 0;
	}
	if (l >= x && r <= y)
	{
		return sum[t] - sum[s];
	}
	int mid = (l + r) >> 1;
	ll res = 0;
	if (x <= mid)
	{
		res += query(ls[s] , ls[t] , l , mid , x , y);
	}
	if (y > mid)
	{
		res += query(rs[s] , rs[t] , mid + 1 , r , x , y);
	}
	return res;
}
void dfs(int u , int fa)
{
	tid[u] = ++cnt;
	sz[u] = 1;
	for (int i = head[u]; ~i; i = E[i].next)
	{
		int v = E[i].v;
		if (v != fa)
		{
			dep[v] = dep[u] + 1;
			maxdep = max(maxdep , dep[v]);
			dfs(v , u);
			sz[u] += sz[v];
		}
	}
}
void dfs2(int u , int fa)
{
	update(rt[tid[u]] , 1 , maxdep , dep[u] , sz[u] - 1 , rt[tid[u] - 1]);
	for (int i = head[u]; ~i; i = E[i].next)
	{
		int v = E[i].v;
		if (v != fa)
		{
			dfs2(v , u);
		}
	}
}
ll ans;
int main()
{
	//freopen("data.in" , "r" , stdin);
	memset(head , -1 , sizeof(head));
	n = read() , m = read();
	for (int i = 1; i < n; i++)
	{
		int u = read() , v = read();
		add(u , v);
		add(v , u);
	}
	dep[1] = 1;
	dfs(1 , -1);
	dfs2(1 , -1);
	for (int i = 1; i <= m; i++)
	{
		int a = read() , k = read();
		int top = min(k , dep[a] - 1);
		ans = 1ll * top * (sz[a] - 1);
		if (sz[a] == 1)
		{
			printf("%lld\n" , ans);
			continue;
		}
		ans += query(rt[tid[a]] , rt[tid[a] + sz[a] - 1] , 1 , maxdep , dep[a] + 1 , min(maxdep , dep[a] + k));
		printf("%lld\n" , ans);
	}
}

看完了确定不点个赞?

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值