题意描述:
原先设定第0颗树只有一个节点0,现在要生成第i颗数,选 ai, bi, (ai < i, bi< i) 中两个节点(ci , di)相连接,构成一个新的树,且ai中节点的编号不变, bi中的所有节点编号都要在原来的基础上+ai树的大小,这样保证编号连续,对于每颗树T而言 ,,
F(T)=∑n−1i=0∑n−1j=i+1d(vi,vj)
(
d(vi,vj)
即任意两点之间距离总和。
这是多校题解:
考虑爆搜,树iii生成后,两两点对路径分成两部分,一部分不经过中间的边,那么就是aia_iai和bib_ibi的答案,如果经过中间的边,首先计算中间这条边出现的次数,也就是ai,bia_i,b_iai,bi子树大小的乘积。对于aia_iai,对答案的贡献为所有点到cic_ici的距离和乘上bib_ibi的子树大小。bib_ibi同理。
那么转化为计算在树iii中,所有点到某个点jjj的距离和。假设jjj在aia_iai内,那么就转化成了aia_iai内jjj这个点的距离总和加上bib_ibi内所有点到did_idi的总和加上did_idi到jjj的距离乘上子树bib_ibi的大小,称作第一类询问。
这样就化成了在树iii中两个点jjj和kkk的距离,如果在同一棵子树中,可以递归下去,否则假设jjj在aia_iai中kkk在bib_ibi中,那么距离为jjj到cic_ici的距离加上kkk到did_idi的距离加上lil_ili,称作第二类询问。
然后对两类询问全都记忆化搜索即可。
接着考虑计算一下复杂度。
对于第二类询问,可以考虑询问的过程类似于线段树,只会有两个分支,中间的部分已经记忆化下来,不用再搜,时间复杂度O(m)O(m)O(m)。
我们分析一下复杂度,首先对于第一类询问,在bib_ibi中到did_idi的点距离和已经由前面的询问得到,那么就转化为一个第一类询问和一个第二类询问,最多会被转化成O(m)O(m)O(m)个第二类询问。
所以每个询问复杂度是O(m2)O(m^2)O(m2),总复杂度O(m3)O(m^3)O(m3)。
复杂度计算思考:
对于第一类询问,只会例如sum(a[i], c[i])递归计算时,每个会分成两个第一类询问和一个第二类询问,而两个第一类询问必有一个已经被计算过(可以手动分解看看前后关系)
,所以每次分解成一个第一类和一个第二类,复杂度为m*m。
dis计算也同理。
被记忆的也不会很多,每次最多多记录m*m个。
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <map>
#include <set>
#include <vector>
#include <cctype>
#include <cmath>
#include <queue>
#define ls rt<<1
#define rs rt<<1|1
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define mem(a,n) memset(a,n,sizeof(a))
#define rep(i,n) for(int i=0;i<(int)n;i++)
#define rep1(i,x,y) for(int i=x;i<=(int)y;i++)
using namespace std;
#pragma comment(linker, "/STACK:102400000,102400000")
typedef pair<int,int> pii;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const ll oo = 1e12;
typedef pair<ll,ll> pll;
const int N = 65;
const int mod = 1e9+7;
map<pll,ll> M[N];
map<ll,ll> M2[N];
int n;
ll a[N],b[N],c[N],d[N],siz[N],ms[N],l[N],ans[N];
void init(){
for(int i = 0; i < N;i++)
M[i].clear(),M2[i].clear();
M[0][pll(0,0)]=0;
M2[0][0] = 0;
siz[0] = ms[0] = 1;
}
ll dis(int i,ll j,ll k){
if(j > k) swap(j,k);
if(M[i].count(pll(j,k))) return M[i][pll(j,k)];
if(k < siz[a[i]]) return M[i][pll(j,k)] = dis(a[i],j,k);
if(j >= siz[a[i]]) return M[i][pll(j,k)] = dis(b[i],j-siz[a[i]],k-siz[a[i]]);
return M[i][pll(j,k)] = (dis(a[i],j,c[i])+l[i]+dis(b[i],d[i],k-siz[a[i]]))%mod;
}
ll sum(int i,ll j){
if(M2[i].count(j)) return M2[i][j];
if(j<siz[a[i]]) return M2[i][j]=(sum(a[i],j)+(l[i]+dis(a[i],j,c[i]))*ms[b[i]]+sum(b[i],d[i]))%mod;
if(j>=siz[a[i]]) return M2[i][j]=(sum(a[i],c[i])+(l[i]+dis(b[i],j-siz[a[i]],d[i]))*ms[a[i]]+sum(b[i],j-siz[a[i]]))%mod;
}
ll cal(int i){
siz[i] = siz[a[i]]+siz[b[i]];
ms[i] = siz[i]%mod;
ans[i] = ans[a[i]]+ans[b[i]]+ms[a[i]]*ms[b[i]]%mod*l[i]%mod+ms[b[i]]*sum(a[i],c[i])+ms[a[i]]*sum(b[i],d[i]);
ans[i]=ans[i]%mod;
return ans[i];
}
int main()
{
while(scanf("%d",&n)==1){
init();
for(int i=1;i<=n;i++){
scanf("%I64d %I64d %I64d %I64d %I64d",&a[i],&b[i],&c[i],&d[i],&l[i]);
printf("%I64d\n",cal(i));
}
}
return 0;
}