[hdu 5401] Persistent Link/cut Tree


题目链接

首先定义 m+1 棵树 T0Tm 。最开始 T0 只有一个点,编号为 0

接着,对于每一棵树 Ti,在第 Tai棵树的第 ci 个点和第 Tbi 棵树的第 di 个点之间连一条长度为 li 的边。

Ti 中,保持 Tai中的所有节点编号不变,如果 Tai中有 s 个节点,把 Tbi中的所有节点的编号加上 s

假设有一颗大小为 n 的树 T, 定义函数 F(T)=n1i=0n1j=i+1d(i,j)
其中 d(i,j) 为这棵树中 ij 的最短距离。

求出 F(Ti) 的值并输出其对 109+7 取模后的结果 (1im)


首先,我们观察如何通过一种简单的方法计算 F(Ti)

F(Ti)=F(Tai)+F(Tbi)+size[ai]size[bi]li+(d(x,ci)Tai)size[bi]+(d(x,di)Tbi)size[ai]

时间复杂度:O(m2m)


以上算法的瓶颈在于求 d(x,p)T

我们定义函数 count(T,p) 为树 T 中所有结点到 p 的距离和,lence(T,x,y) 为树 Tx 点到 y 点的距离。


那么 F(Ti)=F(Tai)+F(Tbi)+size[ai]size[bi]li+count(Tai,p)size[bi]+count(Tbi,di)size[ai]

现在考虑 p<ai 的情况 ( pai 类似 ) :

count(Ti,p)=count(Tai,p)+count(Tbi,di)+(li+lence(Tai,p,ci))size[bi]

对于求 F(Ti),递归不超过 m 层,而 count(Tai,ci)count(Tbi,di),可以通过预处理或者记忆化搜索来剪枝

所以 count() 函数调用次数复杂度是 O(m)


现在分析 lence(T,x,y) 函数,注意到如果 xy 在同一棵树内,复杂度是 O(m)

如果 xy 不在同一棵树内,我们可以发现 x, y 中至少有一个是 aibi

lence(T,aiTi,bjTj) 是可以在 O(m3) 的时间内计算的,同样可以通过预处理或者记忆化搜索来剪枝

所以 lence() 函数调用次数复杂度是 O(m)


综上所述,总时间复杂度: O(m3)


//hdu 5401

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <ctime>
#include <vector>
#include <utility>
#include <stack>
#include <queue>
#include <map>
#include <iostream>
#include <algorithm>

template<class Num>void read(Num &x)
{
    char c; int flag = 1;
    while((c = getchar()) < '0' || c > '9')
        if(c == '-') flag *= -1;
    x = c - '0';
    while((c = getchar()) >= '0' && c <= '9')
        x = (x<<3) + (x<<1) + (c-'0');
    x *= flag;
    return;
}
template<class Num>void write(Num x)
{
    if(!x) {putchar('0');return;}
    if(x < 0) putchar('-'), x = -x;
    static char s[20];int sl = 0;
    while(x) s[sl++] = x%10 + '0',x /= 10;
    while(sl) putchar(s[--sl]);
}

const int maxm = 65, Mod = 1e9 + 7;

int m, a[maxm], b[maxm];
long long c[maxm], d[maxm], l[maxm];
long long size[maxm];
std::map<std::pair<int,long long>,long long> memo;
std::map<std::pair<long long,long long>,long long> dist[maxm];

long long get_lence(int s,long long x,long long y)
{
    if(s == 0 || x == y) return 0;

    if(x > y) std::swap(x, y);

    if(x < size[a[s]] && y < size[a[s]])
        return get_lence(a[s], x, y);
    else if(x >= size[a[s]] && y >= size[a[s]])
        return get_lence(b[s], x - size[a[s]], y - size[a[s]]);
    else
    {
        std::pair<long long,long long> p = std::make_pair(x, y);
        if(dist[s].count(p)) return dist[s][p];
        dist[s][p] = (get_lence(a[s], c[s], x) + l[s] + get_lence(b[s], d[s], y - size[a[s]])) % Mod;
        return dist[s][p];
    }
}
long long count(int t,long long p)
{
    if(t == 0) return 0;
    std::pair<int,long long> s = std::make_pair(t, p);
    if(memo.count(s)) return memo[s];
    long long ret;
    if(p < size[a[t]])
        ret = (count(a[t], p) + count(b[t], d[t]) + (l[t] + get_lence(a[t], p, c[t])) * (size[b[t]] % Mod) % Mod) % Mod;
    else
        ret = (count(b[t], p - size[a[t]]) + count(a[t], c[t]) + (l[t] + get_lence(b[t], p - size[a[t]], d[t]))* (size[a[t]] % Mod) % Mod) % Mod;
    memo[s] = ret;
    return ret; 
}
void solve()
{
    static long long ans[maxm];
    for(int i = 1; i <= m; i++)
    {
        ans[i] = ans[a[i]] + ans[b[i]];
        ans[i] += (size[a[i]]%Mod)*(size[b[i]]%Mod)%Mod * l[i] % Mod;
        ans[i] += size[b[i]]%Mod * count(a[i], c[i]) % Mod;
        ans[i] += size[a[i]]%Mod * count(b[i], d[i]) % Mod;
        write(ans[i] %= Mod), puts("");
    }
}
bool init()
{
    if(scanf("%d",&m) == EOF) return false;

    for(int i = 1; i <= m; i++)
        dist[i].erase(dist[i].begin(), dist[i].end());
    memo.erase(memo.begin(), memo.end());
    size[0] = 1;
    for(int i = 1; i <= m; i++)
    {
        read(a[i]), read(b[i]), read(c[i]);
        read(d[i]), read(l[i]);
        size[i] = size[a[i]] + size[b[i]];
    }
    return true;
}
int main()
{   
    while(init()) solve();

    return 0;
}
阅读更多 登录后自动展开
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页