题意:
求一棵树的每对叶子节点之间距离平方的和。
思路:
这个题貌似可以容斥,但是我不会,所以我写了个淀粉质。
要知道,淀粉质的思想就是将子树内部的递归处理,当前这层处理不同子树之间的距离即可,考虑化简式子分别求贡献。
假设
(
a
i
+
a
j
)
2
(a_i+a_j)^2
(ai+aj)2为两点间距离平方和,
a
i
,
a
j
a_i,a_j
ai,aj为叶子到当前找的伪重心的距离,把式子化简出来就是
a
i
2
+
a
j
2
+
2
∗
a
i
∗
a
j
a_i^2+a_j^2+2*a_i*a_j
ai2+aj2+2∗ai∗aj,考虑每一块的贡献。
我们设当前遍历的子树的叶子距离为
a
i
a_i
ai,之前遍历过的子树的叶子距离为
a
j
a_j
aj,个数为
c
n
t
cnt
cnt个,平方和为
s
u
m
1
sum1
sum1,和为
s
u
m
2
sum2
sum2。算这个子树和之前遍历过的子树信息的时候,
a
i
2
a_i^2
ai2的贡献是
c
n
t
∗
a
i
2
cnt*a_i^2
cnt∗ai2,
b
i
2
b_i^2
bi2就是
s
u
m
1
sum1
sum1,
2
∗
a
i
∗
a
j
2*a_i*a_j
2∗ai∗aj的贡献为
2
∗
a
i
∗
s
u
m
2
2*a_i*sum2
2∗ai∗sum2,这样我们就可以分开统计贡献,淀粉质板子套一下就好啦。
//#pragma GCC optimize(2)
#include<cstdio>
#include<iostream>
#include<string>
#include<cstring>
#include<map>
#include<cmath>
#include<cctype>
#include<vector>
#include<set>
#include<queue>
#include<algorithm>
#include<sstream>
#include<ctime>
#include<cstdlib>
#define X first
#define Y second
#define L (u<<1)
#define R (u<<1|1)
#define pb push_back
#define mk make_pair
#define Mid (tr[u].l+tr[u].r>>1)
#define Len(u) (tr[u].r-tr[u].l+1)
#define random(a,b) ((a)+rand()%((b)-(a)+1))
#define db puts("---")
using namespace std;
//void rd_cre() { freopen("d://dp//data.txt","w",stdout); srand(time(NULL)); }
//void rd_ac() { freopen("d://dp//data.txt","r",stdin); freopen("d://dp//AC.txt","w",stdout); }
//void rd_wa() { freopen("d://dp//data.txt","r",stdin); freopen("d://dp//WA.txt","w",stdout); }
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> PII;
const int N=100010,M=N*2,mod=1e9+7,INF=0x3f3f3f3f;
const double eps=1e-6;
int n,m;
int h[N],e[M],ne[M],w[M],idx;
LL p[N],q[N];
int d[N];
bool st[N];
void add(int a,int b,int c)
{
e[idx]=b,w[idx]=c,ne[idx]=h[a],h[a]=idx++;
}
int get_wc(int u,int f,int tot,int &wc,int &mi)
{
if(st[u]) return 0;
int sum=1,mx=0;
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(j==f) continue;
int t=get_wc(j,u,tot,wc,mi);
mx=max(mx,t); sum+=t;
}
mx=max(mx,tot-sum);
if(mx<mi) wc=u,mi=mx;
return sum;
}
int get_size(int u,int f)
{
if(st[u]) return 0;
int sum=1;
for(int i=h[u];~i;i=ne[i])
if(e[i]!=f)
sum+=get_size(e[i],u);
return sum;
}
void get_dis(int u,int f,int dis,int &qt)
{
if(st[u]) return ;
int cnt=0;
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(j==f) continue;
cnt++;
get_dis(j,u,dis+w[i],qt);
}
if(d[u]==1) q[qt++]=dis;
}
LL cal(int u)
{
if(st[u]) return 0;
LL ans=0; int tt=INF;
get_wc(u,-1,get_size(u,-1),u,tt);
st[u]=true;
LL pt=0,pre=0,cnt=0;
for(int i=h[u];~i;i=ne[i])
{
int j=e[i],qt=0;
get_dis(j,-1,w[i],qt);
for(int k=0;k<qt;k++) ans+=1ll*q[k]*q[k]*cnt,ans+=1ll*2*pt*q[k],ans+=pre;
for(int k=0;k<qt;k++) pt+=q[k],pre+=1ll*q[k]*q[k],cnt++;
}
for(int i=h[u];~i;i=ne[i]) ans+=cal(e[i]);
return ans;
}
int main()
{
// ios::sync_with_stdio(false);
// cin.tie(0);
scanf("%d",&n);
memset(h,-1,sizeof(h));
memset(st,false,sizeof(st));
for(int i=1;i<=n-1;i++)
{
int a,b,c; scanf("%d%d%d",&a,&b,&c);
add(a,b,c); add(b,a,c);
d[a]++; d[b]++;
}
printf("%lld\n",cal(1));
return 0;
}
/*
*/