「LOJ6073」「2017 山东一轮集训 Day5」距离-主席树+树链剖分

Description

给你一棵 n n n 个点的树和一个排列 p p p,边有边权,记 d i s t ( u , v ) dist(u, v) dist(u,v) 表示 u u u v v v 的距离, p a t h ( u , v ) path(u, v) path(u,v) 表示 u u u v v v 路径上所有点组成的集合,现在有 q q q 次询问,每次给出 u i u_i ui , v i v_i vi , k i k_i ki ,问:
∑ i ∈ p a t h d i s t ( p i , k ) \sum_{i\in path} dist(p_i,k) ipathdist(pi,k)
n , q ≤ 2 × 1 0 5 n, q \leq 2 × 10^5 n,q2×105 ,强制在线。时间限制 4 s 4s 4s,空间限制 1 G B 1GB 1GB

Solution

首先可以把路径转化为两者到根再相减。

然后考虑维护 ∑ i ∈ p a t h ( u , r o o t ) d i s ( i , k ) \sum_{i \in path(u,root)}dis(i,k) ipath(u,root)dis(i,k)。首先可以把 d i s dis dis转化为两者深度减去 l c a lca lca深度的两倍。而 l c a lca lca的深度为两者链交的长度。所以用树剖+主席树维护,每次添加一个点 u u u时,把 p u p_u pu到根的路径加 1 1 1。查询时查询点 k k k r o o t root root的权值和即可。

其实就是一个维护一个点集与任意一个点的 l c a lca lca的深度之和的套路。

#include <bits/stdc++.h>
using namespace std;

typedef long long lint;
const int maxn = 200005;

int n, q, type, p[maxn];

struct edge
{
	int to, next, w;
} e[maxn * 2];
int h[maxn], tot, top[maxn], fa[maxn], dfn[maxn], ord[maxn], Time, siz[maxn], dep[maxn], w[maxn], son[maxn];
lint dis[maxn], pre_d[maxn];

int rt[maxn], cnt, lch[maxn * 60], rch[maxn * 60];
lint lzy[maxn * 60], sum[maxn * 60], sum_c[maxn * 60];

inline int gi()
{
	char c = getchar();
	while (c < '0' || c > '9') c = getchar();
	int sum = 0;
	while ('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
	return sum;
}

inline void add(int u, int v, int w)
{
	e[++tot] = (edge) {v, h[u], w}; h[u] = tot;
	e[++tot] = (edge) {u, h[v], w}; h[v] = tot;
}

void dfs1(int u)
{
	siz[u] = 1;
	for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
		if (v != fa[u]) {
			fa[v] = u; dep[v] = dep[u] + 1; w[v] = e[i].w; dis[v] = dis[u] + e[i].w;
			dfs1(v);
			siz[u] += siz[v];
			if (siz[v] > siz[son[u]]) son[u] = v;
		}
}

void dfs2(int u)
{
	ord[dfn[u] = ++Time] = u;
	if (son[u]) top[son[u]] = top[u], dfs2(son[u]);
	for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
		if (v != fa[u] && v != son[u]) top[v] = v, dfs2(v);
}

int lca(int u, int v)
{
	while (top[u] != top[v]) {
		if (dep[top[u]] > dep[top[v]]) u = fa[top[u]];
		else v = fa[top[v]]; 
	}
	return dep[u] < dep[v] ? u : v;
}

#define mid ((l + r) >> 1)

void build(int &s, int l, int r)
{
	s = ++cnt;
	if (l == r) return sum_c[s] = w[ord[l]], void();
	build(lch[s], l, mid);
	build(rch[s], mid + 1, r);
	sum_c[s] = sum_c[lch[s]] + sum_c[rch[s]];
}

void insert(int &s, int l, int r, int x, int y)
{
	++cnt;
	sum[cnt] = sum[s]; sum_c[cnt] = sum_c[s]; lzy[cnt] = lzy[s];
	lch[cnt] = lch[s]; rch[cnt] = rch[s];
	s = cnt;

	if (x <= l && r <= y) return ++lzy[s], sum[s] += sum_c[s], void();
	if (x <= mid) insert(lch[s], l, mid, x, y);
	if (y >= mid + 1) insert(rch[s], mid + 1, r, x, y);

	sum[s] = sum[lch[s]] + sum[rch[s]] + lzy[s] * sum_c[s];
}

pair<lint, lint> operator + (const pair<lint, lint> &a, const pair<lint, lint> &b)
{
	return make_pair(a.first + b.first, a.second + b.second);
}

pair<lint, lint> query(int &s, int l, int r, int x, int y)
{
	if (x <= l && r <= y) return make_pair(sum[s], sum_c[s]);
	pair<lint, lint> res = make_pair(0, 0);
	if (x <= mid) res = res + query(lch[s], l, mid, x, y);
	if (y >= mid + 1) res = res + query(rch[s], mid + 1, r, x, y);
	res.first += res.second * lzy[s];
	return res;
}

void insert(int u)
{
	int k = u;
	rt[k] = rt[fa[u]]; u = p[u];
	while (u) {
		insert(rt[k], 1, n, dfn[top[u]], dfn[u]);
		u = fa[top[u]];
	}
}

lint query(int u, int k)
{
	if (!k) return 0;
	lint res = pre_d[k] + (dep[k] + 1) * dis[u];
	while (u) {
		res -= query(rt[k], 1, n, dfn[top[u]], dfn[u]).first << 1;
		u = fa[top[u]];
	}
	return res;
}

void dfs3(int u)
{
	pre_d[u] = pre_d[fa[u]] + dis[p[u]];
	insert(u);
	for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
		if (v != fa[u]) dfs3(v);
}

int main()
{
	type = gi();
	n = gi(); q = gi();
	for (int i = 1, u, v, w; i < n; ++i) u = gi(), v = gi(), w = gi(), add(u, v, w);
	for (int i = 1; i <= n; ++i) p[i] = gi();

	dfs1(1);
	top[1] = 1; dfs2(1);

	build(rt[0], 1, n);
	dfs3(1);

	lint lstans = 0;
	for (int i = 1, u, v, k; i <= q; ++i) {
		u = gi(); v = gi(); k = gi();
		u ^= lstans * type; v ^= lstans * type; k ^= lstans * type;
		printf("%lld\n", lstans = (query(k, u) + query(k, v) - query(k, lca(u, v)) - query(k, fa[lca(u, v)])));
	}
	
	return 0;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值