题意
有
m
棵树
对于
0≤ai,bi<i,0≤li≤109
1≤m≤60
数据组数
T≤100
Time
Limits:2000ms
Memory
Limits:512M
分析
令
ans[i]
为
Ti
的答案,
size[i]
为
Ti
的节点个数。
易得
ans[i]=ans[ai]+ans[bi]+size[ai]∗size[bi]∗li+topai,ci∗size[bi]+topbi,di∗size[ai]
topk,x
表示
Tk
中所有节点到
x
号节点的距离和。
现在我们看看如何求
考虑把
Tk
分成
Tak,Tbk
。
1°
k=0
时,值为
0
;
2°
其中
disk,x,y
表示
Tk
中
x,y
的距离。
3°
x>=size[ak]
时,
x
原来是在
这样就可以递归求解
topk,x
了。
我们的问题又变成了如何求
disk,x,y
。
同样考虑递归求解。
如果
x,y
同属于组成
Tk
的两棵树中的一棵,我们可以递归下去;否则令
x<y
,
disk,x,y=disak,x,ck+disbk,y,dk+lk
,这样变成了两个子问题,继续递归求解。边界为
k=0或x=y
时,
disk,x,y=0
。
上面 top,dis 的求值过程可以用记忆化搜索。
求
disk,x,y
的过程最多递归
m
层,时间复杂度为
求
topk,x,y
的过程最多有
m
层,每层一个求
所以总时间复杂度是
O(m3)
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <map>
using namespace std;
typedef long long LL;
typedef pair<LL,LL> pa;
const int N = 66;
const LL P = 1e9 + 7;
int n,a[N],b[N],len[N];
LL size[N],ans[N],c[N],d[N];
map< pa , int > dis[N];
void init() {
scanf("%d",&n);
for (int i = 1;i <= n;i ++) {
scanf("%d%d%lld%lld%d",&a[i],&b[i],&c[i],&d[i],&len[i]);
dis[i].clear();
}
}
int get(int k,LL x,LL y) {
if (k == 0 || x == y) return 0;
if (y < x) swap(x,y);
if (y < size[a[k]]) return get(a[k],x,y);
if (x >= size[a[k]]) return get(b[k],x - size[a[k]],y - size[a[k]]);
pa cur = make_pair(x,y);
if (dis[k].find(cur) != dis[k].end()) return dis[k][cur];
int re = (0LL + get(a[k],x,c[k]) + get(b[k],y - size[a[k]],d[k]) + len[k]) % P;
return dis[k][cur] = re;
}
int top(int k,LL x) {
if (k == 0) return 0;
pa cur = make_pair(x,-1);
if (dis[k].find(cur) != dis[k].end()) return dis[k][cur];
int re = 0;
if (x < size[a[k]]) {
re = top(a[k],x);
re = (re + top(b[k],d[k])) % P;
re = (re + 1LL * (get(a[k],x,c[k]) + len[k]) % P * (size[b[k]] % P) % P) % P;
}
else {
re = top(b[k],x - size[a[k]]);
re = (re + top(a[k],c[k])) % P;
re = (re + 1LL * size[a[k]] % P * (get(b[k],x - size[a[k]],d[k]) + len[k]) % P) % P;
}
return dis[k][cur] = re;
}
void solve() {
for (int i = 1;i <= n;i ++) {
int l = a[i],r = b[i];
ans[i] = (ans[l] + ans[r]) % P;
size[i] = size[l] + size[r];
ans[i] = (ans[i] + size[l] % P * (size[r] % P) % P * LL(len[i]) % P) % P;
ans[i] = (ans[i] + 1LL * size[l] % P * top(r,d[i]) % P) % P;
ans[i] = (ans[i] + 1LL * size[r] % P * top(l,c[i]) % P) % P;
printf("%lld\n",ans[i]);
}
}
int main() {
int T;
scanf("%d",&T);
size[0] = 1;
while (T --) {
init();
solve();
}
}