Codeforces Round 864 (Div. 2) D 题解

文章介绍了如何解决CodeforcesRound864分类2的D题——LiHua和树。通过使用深度优先搜索(DFS)或广度优先搜索(BFS)来计算子树的大小和重要性,并利用重儿子旋转的概念。文章提到了如何使用SortedList数据结构在Python中以O(1)的时间复杂度进行操作,同时提供了问题的模拟解决方案。
摘要由CSDN通过智能技术生成

Codeforces Round 864 (Div. 2) D题解

Problem D - Li Hua and Tree

原题
我们可以通过dfs求出所有子树的大小,重要性 以及父亲,记为sz数组,imp数组,pa数组

我们令操作2的根节点为x,他的父亲节点为fa,x的heavySon 为y。

然后我们发现在进行子树旋转的时候,其实对于他的父亲节点fa的sz和imp本质上不影响,影响的只是fa的heavySon。而后对于x,他的sz大小应该减少了y的sz大小,imp也少了y的imp大小,并且x的heavySon 应该除去y。而对于y来说,他变成了fa的儿子,并且他的sz变成了原来x的大小,imp也变成了原来x的大小,且y的heavySon 又加上了x这个子树。

旋转后的变化比较清楚。问题在于如何快速求得某个节点的heavySon。
我们可以利用SortedList 以O(1)的时间快速求得,并且SortedList的插入也只耗时O(logN),删除为O(1)
(Java中可以用TreeSet,cpp中用set)
这样我们可以用SortedList去维护每一个节点的有序序列。
接着就是模拟操作就好了

由于CodeForces中对于python的dfs不太友好,大一点深度的dfs会被判RE。
因此我们需要用bfs来代替dfs求sz,imp和pa

import math
import sys
from bisect import bisect_left, bisect_right
from collections import Counter, defaultdict, deque
from itertools import permutations

input = lambda: sys.stdin.readline().rstrip("\r\n")


def I():
    return input()


def ii():
    return int(input())


def li():
    return list(input().split())


def mi():
    return map(int, input().split())


def lii():
    return list(map(int, input().split()))

n,m = mi()
a = [-1] + lii()
adj = [[] for i in range(n + 1)]
hp = [SortedList() for i in range(n + 1)]
for i in range(n - 1):
    u,v = lii()
    adj[u].append(v)
    adj[v].append(u)

sz = [0] * (n + 1)
imp = [0] * (n + 1)
pa = [0] * (n + 1)

stk = []
vis = [0] * (n + 1)
vis[1] = 1
cnt = [0] * (n + 1)
stk.append([1,0])
stk1 = []
while stk:
    o,fa = stk.pop()
    pa[o] = fa
    f = 0
    for nxt in adj[o]:
        if not vis[nxt]:
            cnt[o] += 1
            f = 1
            vis[nxt] = 1
            stk.append([nxt,o])
    if not f:
        stk1.append(o)
while stk1:
    o= stk1.pop()
    sz[o] += 1
    imp[o] += a[o]
    fa = pa[o]
    sz[fa] += sz[o]
    imp[fa] += imp[o]
    cnt[fa]-=1
    hp[fa].add((-sz[o],o,imp[o]))
    if cnt[fa] == 0:
        stk1.append(fa)
# dfs会RE
# def dfs(o,fa):
#     pa[o] = fa
#     ans0 = 0
#     ans1 = 0
#     for nxt in adj[o]:
#         if fa == nxt:continue
#         dfs(nxt,o)
#         ans0 += sz[nxt]
#         ans1 += imp[nxt]
#         hp[o].add((-sz[nxt],nxt,imp[nxt]))
#     sz[o] = ans0 + 1
#     imp[o] = ans1 + a[o]
#
# dfs(1,0)

for _ in range(m):
    f,node = lii()
    if f == 1:
        print(imp[node])
    else:
        if not len(hp[node]):continue
        heavySz,heavyNode,heavyImp = hp[node][0]
        heavySz = -heavySz
        allImp = imp[node]
        allSz = sz[node]
        fa = pa[node]
        hp[fa].discard((-sz[node],node,imp[node]))
        hp[fa].add((-sz[node],heavyNode,allImp))
        hp[node].discard(hp[node][0])
        otherImp = allImp - imp[heavyNode]
        hp[heavyNode].add((-(allSz - heavySz),node,otherImp))
        sz[heavyNode] = allSz
        sz[node] = allSz - heavySz
        pa[heavyNode] = fa
        pa[node] = heavyNode
        imp[node] = otherImp
        imp[heavyNode] = allImp

另附SortedList板子

class SortedList:
    def __init__(self, iterable=[], _load=200):
        """Initialize sorted list instance."""
        values = sorted(iterable)
        self._len = _len = len(values)
        self._load = _load
        self._lists = _lists = [values[i:i + _load] for i in range(0, _len, _load)]
        self._list_lens = [len(_list) for _list in _lists]
        self._mins = [_list[0] for _list in _lists]
        self._fen_tree = []
        self._rebuild = True

    def _fen_build(self):
        """Build a fenwick tree instance."""
        self._fen_tree[:] = self._list_lens
        _fen_tree = self._fen_tree
        for i in range(len(_fen_tree)):
            if i | i + 1 < len(_fen_tree):
                _fen_tree[i | i + 1] += _fen_tree[i]
        self._rebuild = False

    def _fen_update(self, index, value):
        """Update `fen_tree[index] += value`."""
        if not self._rebuild:
            _fen_tree = self._fen_tree
            while index < len(_fen_tree):
                _fen_tree[index] += value
                index |= index + 1

    def _fen_query(self, end):
        """Return `sum(_fen_tree[:end])`."""
        if self._rebuild:
            self._fen_build()

        _fen_tree = self._fen_tree
        x = 0
        while end:
            x += _fen_tree[end - 1]
            end &= end - 1
        return x

    def _fen_findkth(self, k):
        """Return a pair of (the largest `idx` such that `sum(_fen_tree[:idx]) <= k`, `k - sum(_fen_tree[:idx])`)."""
        _list_lens = self._list_lens
        if k < _list_lens[0]:
            return 0, k
        if k >= self._len - _list_lens[-1]:
            return len(_list_lens) - 1, k + _list_lens[-1] - self._len
        if self._rebuild:
            self._fen_build()

        _fen_tree = self._fen_tree
        idx = -1
        for d in reversed(range(len(_fen_tree).bit_length())):
            right_idx = idx + (1 << d)
            if right_idx < len(_fen_tree) and k >= _fen_tree[right_idx]:
                idx = right_idx
                k -= _fen_tree[idx]
        return idx + 1, k

    def _delete(self, pos, idx):
        """Delete value at the given `(pos, idx)`."""
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens

        self._len -= 1
        self._fen_update(pos, -1)
        del _lists[pos][idx]
        _list_lens[pos] -= 1

        if _list_lens[pos]:
            _mins[pos] = _lists[pos][0]
        else:
            del _lists[pos]
            del _list_lens[pos]
            del _mins[pos]
            self._rebuild = True

    def _loc_left(self, value):
        """Return an index pair that corresponds to the first position of `value` in the sorted list."""
        if not self._len:
            return 0, 0

        _lists = self._lists
        _mins = self._mins

        lo, pos = -1, len(_lists) - 1
        while lo + 1 < pos:
            mi = (lo + pos) >> 1
            if value <= _mins[mi]:
                pos = mi
            else:
                lo = mi

        if pos and value <= _lists[pos - 1][-1]:
            pos -= 1

        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value <= _list[mi]:
                idx = mi
            else:
                lo = mi

        return pos, idx

    def _loc_right(self, value):
        """Return an index pair that corresponds to the last position of `value` in the sorted list."""
        if not self._len:
            return 0, 0

        _lists = self._lists
        _mins = self._mins

        pos, hi = 0, len(_lists)
        while pos + 1 < hi:
            mi = (pos + hi) >> 1
            if value < _mins[mi]:
                hi = mi
            else:
                pos = mi

        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value < _list[mi]:
                idx = mi
            else:
                lo = mi

        return pos, idx

    def add(self, value):
        """Add `value` to sorted list."""
        _load = self._load
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens

        self._len += 1
        if _lists:
            pos, idx = self._loc_right(value)
            self._fen_update(pos, 1)
            _list = _lists[pos]
            _list.insert(idx, value)
            _list_lens[pos] += 1
            _mins[pos] = _list[0]
            if _load + _load < len(_list):
                _lists.insert(pos + 1, _list[_load:])
                _list_lens.insert(pos + 1, len(_list) - _load)
                _mins.insert(pos + 1, _list[_load])
                _list_lens[pos] = _load
                del _list[_load:]
                self._rebuild = True
        else:
            _lists.append([value])
            _mins.append(value)
            _list_lens.append(1)
            self._rebuild = True

    def discard(self, value):
        """Remove `value` from sorted list if it is a member."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_right(value)
            if idx and _lists[pos][idx - 1] == value:
                self._delete(pos, idx - 1)

    def remove(self, value):
        """Remove `value` from sorted list; `value` must be a member."""
        _len = self._len
        self.discard(value)
        if _len == self._len:
            raise ValueError('{0!r} not in list'.format(value))

    def pop(self, index=-1):
        """Remove and return value at `index` in sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        value = self._lists[pos][idx]
        self._delete(pos, idx)
        return value

    def bisect_left(self, value):
        """Return the first index to insert `value` in the sorted list."""
        pos, idx = self._loc_left(value)
        return self._fen_query(pos) + idx

    def bisect_right(self, value):
        """Return the last index to insert `value` in the sorted list."""
        pos, idx = self._loc_right(value)
        return self._fen_query(pos) + idx

    def count(self, value):
        """Return number of occurrences of `value` in the sorted list."""
        return self.bisect_right(value) - self.bisect_left(value)

    def __len__(self):
        """Return the size of the sorted list."""
        return self._len

    def __getitem__(self, index):
        """Lookup value at `index` in sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        return self._lists[pos][idx]

    def __delitem__(self, index):
        """Remove value at `index` from sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        self._delete(pos, idx)

    def __contains__(self, value):
        """Return true if `value` is an element of the sorted list."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_left(value)
            return idx < len(_lists[pos]) and _lists[pos][idx] == value
        return False

    def __iter__(self):
        """Return an iterator over the sorted list."""
        return (value for _list in self._lists for value in _list)

    def __reversed__(self):
        """Return a reverse iterator over the sorted list."""
        return (value for _list in reversed(self._lists) for value in reversed(_list))

    def __repr__(self):
        """Return string representation of sorted list."""
        return 'SortedList({0})'.format(list(self))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值