题意:在一棵树上找三个点,这三个点不构成一条链的有多少种取法。
解题思路:把题意转化一下,找三个点能够构成一条链的有多少种取法,然后用用总数减掉就可以了。那么怎么找构成链的取法呢?对于某一个点,我们认为这个点是必取的,然后再从其他地方取两个,把枚举所有点时得到的方案数累加起来,再除以3,就是答案。为什么除以3?因为一条链上的三个点,对于枚举任一个点的时候,这条链都被计算了一次。现在问题就是对于这个点,我们怎么算答案了。先预处理一下,dp[u]表示u所掌控的子树,取两个点,在一条链上的取法有几种,size[u]表示u掌控的子树的大小。那么dp[u]就是所有的dp[v](v表示u的儿子)累加起来,加上size[u]-1了。然后就从根进入,对于每个当前访问的节点,这个点是必取的方案数是多少呢?当前这个点必取,那么就看剩下的两个点了,如果剩下的两个点在同一颗子树下,那么肯定是链状的,也就是dp[v]了,如果不在同一颗子树下,枚举某一颗子树,在这颗子树下取一个点,方案数自然是size[v],在前面的子树里取一个点,方案数是cnt(这是一个累加器,算完v这个节点后把size[v]加进去)那么这两个方案数乘积就是能组合的方案数了,再把每次枚举某一子树的组合出的种数累加起来,在加上之前的dp[v]的累加值,就是对于u这个节点必取是的方案数了。但是我们还漏了一部分,就是u的父亲节点那部分,这个就是往下传就行了,跟以前的树形dp类似。
#pragma comment(linker, "/STACK:16777216")
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define ll __int64
using namespace std ;
const int maxn = 111111 ;
struct Edge {
int t , next ;
} edge[maxn<<1] ;
int head[maxn] , tot ;
void new_edge ( int a , int b ) {
edge[tot].t = b ;
edge[tot].next = head[a] ;
head[a] = tot ++ ;
}
int size[maxn] ;
ll dp[maxn] , ans ;
void cal ( int u , int fa ) {
int i ;
for ( i = head[u] ; i != -1 ; i = edge[i].next ) {
int v = edge[i].t ;
if ( v == fa ) continue ;
cal ( v , u ) ;
size[u] += size[v] ;
dp[u] += dp[v] + size[v] ;
}
}
int num[maxn] ;
void dfs ( int u , int fa , ll fv , int fsize ) {
int i , m = 0 ;
ll cnt = fsize , cnt2 = fv ;
ans += fv ;
for ( i = head[u] ; i != -1 ; i = edge[i].next ) {
int v = edge[i].t ;
if ( v == fa ) continue ;
ans += dp[v] ;
ans += (ll) size[v] * cnt ;
cnt += size[v] ;
cnt2 += dp[v] ;
}
// printf ( "u = %d , ans = %I64d\n" , u , ans ) ;
for ( i = head[u] ; i != -1 ; i = edge[i].next ) {
int v = edge[i].t ;
if ( v == fa ) continue ;
int ffsize = cnt - size[v] + 1 ;
ll ffv = cnt2 - dp[v] + ffsize - 1 ;
dfs ( v , u , ffv , ffsize ) ;
}
}
int main () {
int n , i , a , b ;
while ( scanf ( "%d" , &n ) != EOF ) {
tot = ans = 0 ;
for ( i = 0 ; i <= n ; i ++ ) head[i] = -1 , dp[i] = 0 , size[i] = 1 ;
for ( i = 1 ; i < n ; i ++ ) {
scanf ( "%d%d" , &a , &b ) ;
new_edge ( a , b ) ;
new_edge ( b , a ) ;
}
cal ( 1 , 0 ) ;
// for ( i = 1 ; i <= n ; i ++ ) printf ( "%I64d %d\n" , dp[i] , size[i] ) ; puts ( "" ) ;
dfs ( 1 , 0 , 0 , 0 ) ;
ll tot = (ll) n * ( n - 1 ) * ( n - 2 ) / 6 ;
// printf ( "%I64d\n" , ans ) ;
printf ( "%I64d\n" , tot - ans / 3 ) ;
}
}
/*
7
1 2
1 3
2 4
2 5
3 6
6 7
*/