首先定义 m+1 棵树 T0 到 Tm 。最开始 T0 只有一个点,编号为 0 。
接着,对于每一棵树
在
Ti
中,保持
Tai
中的所有节点编号不变,如果
Tai
中有
s
个节点,把
假设有一颗大小为
其中
d(i,j)
为这棵树中
i
到
求出 F(Ti) 的值并输出其对 109+7 取模后的结果 (1≤i≤m) 。
首先,我们观察如何通过一种简单的方法计算 F(Ti)
F(Ti)=F(Tai)+F(Tbi)+size[ai]∗size[bi]∗li+(∑d(x,ci)∈Tai)∗size[bi]+(∑d(x,di)∈Tbi)∗size[ai]
时间复杂度: O(m∗2m)
以上算法的瓶颈在于求 ∑d(x,p)∈T
我们定义函数
count(T,p)
为树
T
中所有结点到
那么
现在考虑 p<ai 的情况 ( p≥ai 类似 ) :
count(Ti,p)=count(Tai,p)+count(Tbi,di)+(li+lence(Tai,p,ci))∗size[bi]
对于求
F(Ti)
,递归不超过
m
层,而
所以 count() 函数调用次数复杂度是 O(m)
现在分析
lence(T,x,y)
函数,注意到如果
x
与
如果
x
,
lence(T,ai∈Ti,bj∈Tj) 是可以在 O(m3) 的时间内计算的,同样可以通过预处理或者记忆化搜索来剪枝
所以 lence() 函数调用次数复杂度是 O(m)
综上所述,总时间复杂度: O(m3)
//hdu 5401
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <ctime>
#include <vector>
#include <utility>
#include <stack>
#include <queue>
#include <map>
#include <iostream>
#include <algorithm>
template<class Num>void read(Num &x)
{
char c; int flag = 1;
while((c = getchar()) < '0' || c > '9')
if(c == '-') flag *= -1;
x = c - '0';
while((c = getchar()) >= '0' && c <= '9')
x = (x<<3) + (x<<1) + (c-'0');
x *= flag;
return;
}
template<class Num>void write(Num x)
{
if(!x) {putchar('0');return;}
if(x < 0) putchar('-'), x = -x;
static char s[20];int sl = 0;
while(x) s[sl++] = x%10 + '0',x /= 10;
while(sl) putchar(s[--sl]);
}
const int maxm = 65, Mod = 1e9 + 7;
int m, a[maxm], b[maxm];
long long c[maxm], d[maxm], l[maxm];
long long size[maxm];
std::map<std::pair<int,long long>,long long> memo;
std::map<std::pair<long long,long long>,long long> dist[maxm];
long long get_lence(int s,long long x,long long y)
{
if(s == 0 || x == y) return 0;
if(x > y) std::swap(x, y);
if(x < size[a[s]] && y < size[a[s]])
return get_lence(a[s], x, y);
else if(x >= size[a[s]] && y >= size[a[s]])
return get_lence(b[s], x - size[a[s]], y - size[a[s]]);
else
{
std::pair<long long,long long> p = std::make_pair(x, y);
if(dist[s].count(p)) return dist[s][p];
dist[s][p] = (get_lence(a[s], c[s], x) + l[s] + get_lence(b[s], d[s], y - size[a[s]])) % Mod;
return dist[s][p];
}
}
long long count(int t,long long p)
{
if(t == 0) return 0;
std::pair<int,long long> s = std::make_pair(t, p);
if(memo.count(s)) return memo[s];
long long ret;
if(p < size[a[t]])
ret = (count(a[t], p) + count(b[t], d[t]) + (l[t] + get_lence(a[t], p, c[t])) * (size[b[t]] % Mod) % Mod) % Mod;
else
ret = (count(b[t], p - size[a[t]]) + count(a[t], c[t]) + (l[t] + get_lence(b[t], p - size[a[t]], d[t]))* (size[a[t]] % Mod) % Mod) % Mod;
memo[s] = ret;
return ret;
}
void solve()
{
static long long ans[maxm];
for(int i = 1; i <= m; i++)
{
ans[i] = ans[a[i]] + ans[b[i]];
ans[i] += (size[a[i]]%Mod)*(size[b[i]]%Mod)%Mod * l[i] % Mod;
ans[i] += size[b[i]]%Mod * count(a[i], c[i]) % Mod;
ans[i] += size[a[i]]%Mod * count(b[i], d[i]) % Mod;
write(ans[i] %= Mod), puts("");
}
}
bool init()
{
if(scanf("%d",&m) == EOF) return false;
for(int i = 1; i <= m; i++)
dist[i].erase(dist[i].begin(), dist[i].end());
memo.erase(memo.begin(), memo.end());
size[0] = 1;
for(int i = 1; i <= m; i++)
{
read(a[i]), read(b[i]), read(c[i]);
read(d[i]), read(l[i]);
size[i] = size[a[i]] + size[b[i]];
}
return true;
}
int main()
{
while(init()) solve();
return 0;
}