# [hdu 5401] Persistent Link/cut Tree

在 $T_i$ 中，保持 $T_{a_i}$ 中所有节点的编号不变；如果 $T_{a_i}$ 中有 $s$ 个节点，就把 $T_{b_i}$ 中所有节点的编号都加上 $s$。

记 $F(T)$ 为树 $T$ 中所有点对之间的距离之和，$size[j]$ 为 $T_j$ 的节点数，则 $F(T_i) = F(T_{a_i}) + F(T_{b_i}) + size[a_i] \cdot size[b_i] \cdot l_i + size[b_i] \cdot \sum_{x \in T_{a_i}} d(x, c_i) + size[a_i] \cdot \sum_{x \in T_{b_i}} d(x, d_i)$

记 $count(T, p) = \sum_{x \in T} d(p, x)$，则当 $p \in T_{a_i}$ 时，$count(T_i, p) = count(T_{a_i}, p) + count(T_{b_i}, d_i) + (l_i + lence(T_{a_i}, p, c_i)) \cdot size[b_i]$（$p \in T_{b_i}$ 的情况与此对称）。

$lence(T, a_i \in T_i, b_j \in T_j)$（树 $T$ 中两点之间的距离）可以在 $O(m^3)$ 的时间内计算，同样可以通过预处理或记忆化搜索来剪枝。

//hdu 5401

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <ctime>
#include <vector>
#include <utility>
#include <stack>
#include <queue>
#include <map>
#include <iostream>
#include <algorithm>

/// Reads one decimal integer from stdin into `x` using getchar().
///
/// Non-digit characters before the number are skipped. The value is negative
/// only when the character immediately before its first digit is '-'.
/// If end of input is reached before any digit is found, `x` is set to 0
/// (instead of spinning forever, as a `char`-based loop would).
///
/// @param x receives the parsed value.
template<class Num>void read(Num &x)
{
    int c;          // int, not char: getchar() signals end of input with EOF
    int flag = 1;
    while((c = getchar()) != EOF && (c < '0' || c > '9'))
        flag = (c == '-') ? -1 : 1;   // sign comes only from the char right before the digits
    if(c == EOF) { x = 0; return; }
    x = c - '0';
    while((c = getchar()) >= '0' && c <= '9')
        x = x * 10 + (c - '0');
    x *= flag;
}
/// Writes integer `x` to stdout in decimal, with a leading '-' when negative
/// and no trailing newline.
///
/// Digits of a negative value are extracted without negating it, so the most
/// negative value of a signed type (e.g. LLONG_MIN) prints correctly instead
/// of hitting signed-overflow UB in `-x`.
template<class Num>void write(Num x)
{
    if(!x) {putchar('0');return;}
    static char s[24];int sl = 0;   // 20 digits cover the widest 64-bit value
    if(x < 0)
    {
        putchar('-');
        // For negative x, x % 10 lies in [-9, 0] (C++11 division truncates toward zero).
        while(x) s[sl++] = char('0' - x % 10), x /= 10;
    }
    else
    {
        while(x) s[sl++] = char('0' + x % 10), x /= 10;
    }
    while(sl) putchar(s[--sl]);
}

// Array capacity (tree indices 0 .. maxm-1) and the modulus applied to answers.
const int maxm = 65;
const int Mod = 1000000007;

// Number of build operations in the current test case.
int m;
// Tree i is T_{a[i]} followed by T_{b[i]} (whose node ids are shifted by
// size[a[i]]), joined by an edge of length l[i] between node c[i] of T_{a[i]}
// and node d[i] of T_{b[i]}.
int a[maxm];
int b[maxm];
long long c[maxm];
long long d[maxm];
long long l[maxm];
// size[i]: number of nodes in tree i.
long long size[maxm];
// memo[{t, p}] caches count(t, p).
std::map<std::pair<int,long long>,long long> memo;
// dist[s][{x, y}] caches get_lence(s, x, y) for pairs split across tree s's joining edge.
std::map<std::pair<long long,long long>,long long> dist[maxm];

/// Distance (mod Mod) between nodes x and y of tree s.
///
/// Tree s consists of T_{a[s]} (ids 0 .. size[a[s]]-1) followed by T_{b[s]}
/// (ids shifted by size[a[s]]); tree 0 is a single node. Pairs lying on one side
/// are resolved by descending into that side; pairs split across the joining
/// edge c[s] -- d[s] (length l[s]) are computed once and cached in dist[s].
long long get_lence(int s,long long x,long long y)
{
    typedef std::map<std::pair<long long,long long>,long long> PairDist;
    for(;;)
    {
        if(s == 0 || x == y) return 0;
        if(y < x) std::swap(x, y);

        const long long split = size[a[s]];
        const bool x_left = x < split;
        const bool y_left = y < split;
        if(x_left && y_left)
        {
            s = a[s];
            continue;
        }
        if(!x_left && !y_left)
        {
            x -= split;
            y -= split;
            s = b[s];
            continue;
        }

        // Here x < split <= y: the path must cross the joining edge.
        const std::pair<long long,long long> key(x, y);
        PairDist &cache = dist[s];
        PairDist::iterator hit = cache.find(key);
        if(hit != cache.end()) return hit->second;
        const long long through = (get_lence(a[s], c[s], x) + l[s] + get_lence(b[s], d[s], y - split)) % Mod;
        cache[key] = through;
        return through;
    }
}
/// Sum (mod Mod) of distances from node p to every node of tree t, cached in memo.
///
/// Every node on the far side of tree t's joining edge is reached from p through
/// that edge. If p is in T_{a[t]} the total is
///   count(a, p) + count(b, d) + (l + dist_a(p, c)) * size[b],
/// and the case p in T_{b[t]} is symmetric.
long long count(int t,long long p)
{
    if(!t) return 0;

    const std::pair<int,long long> key(t, p);
    std::map<std::pair<int,long long>,long long>::iterator hit = memo.find(key);
    if(hit != memo.end()) return hit->second;

    const long long left_size = size[a[t]];
    long long near_sum, far_sum, bridge, far_size;
    if(p < left_size)
    {
        near_sum = count(a[t], p);
        far_sum = count(b[t], d[t]);
        bridge = l[t] + get_lence(a[t], p, c[t]);
        far_size = size[b[t]];
    }
    else
    {
        const long long local = p - left_size;
        near_sum = count(b[t], local);
        far_sum = count(a[t], c[t]);
        bridge = l[t] + get_lence(b[t], local, d[t]);
        far_size = left_size;
    }

    const long long total = (near_sum + far_sum + bridge * (far_size % Mod) % Mod) % Mod;
    memo[key] = total;
    return total;
}
void solve()
{
static long long ans[maxm];
for(int i = 1; i <= m; i++)
{
ans[i] = ans[a[i]] + ans[b[i]];
ans[i] += (size[a[i]]%Mod)*(size[b[i]]%Mod)%Mod * l[i] % Mod;
ans[i] += size[b[i]]%Mod * count(a[i], c[i]) % Mod;
ans[i] += size[a[i]]%Mod * count(b[i], d[i]) % Mod;
write(ans[i] %= Mod), puts("");
}
}
bool init()
{
if(scanf("%d",&m) == EOF) return false;

for(int i = 1; i <= m; i++)
dist[i].erase(dist[i].begin(), dist[i].end());
memo.erase(memo.begin(), memo.end());
size[0] = 1;
for(int i = 1; i <= m; i++)
{
}