#505. 三角果计数
三角果计数 - 题目 - Daimayuan Online Judge
思路:我们能够发现,如果三个点在一条直线上是一定不满足的,那么如果三个点不在一条直线上呢?
假设三个点的状态为图示状态,那么三条边为a+b,a+c,b+c,我们能够发现任意两边之和都是大于第三边的,并且任意两边之差是小于第三边的,所以一定可以构成一个三角形,所以我们知道了只要不是三点共线,那么就可以形成三角果,所以我们可以用总数,即任意选出三个点的数量减去三点共线的数量,那么这个题就转化为了一个树形dp,对于一个节点,那么如果两个点都是它的儿子,只要它的儿子不在同一个子树上就满足,所以我们还是可以直接加上任意选出两个点的情况然后减去在同一个子树上的情况,另外,还有只选一个儿子,其余的从父节点那边选择,这个也能够直接算出来
看这个图就是如果我们子儿子选择两个,那么我们可以选择在1,2或者1,3等等,但是不能选择1,2(因为我们选择的是以当前点为中间点的,但是如果在同一个子树上,那么中间点可能不是当前这个,而且如果在同一个子树上可能会形成三角果,所以我们不能选择在同一个子树上的,但是这并不会使得方案数减少,因为我们还会去递归这个子树)同时我们还要加上父节点的部分
#include<iostream>
#include<cstring>
#include<string>
#include<sstream>
#include<cmath>
#include<cstdio>
#include<algorithm>
#include<queue>
#include<map>
#include<stack>
#include<vector>
#include<set>
#include<unordered_map>
#include<ctime>
#include<cstdlib>
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef double db;
typedef pair<int,int> PII;
typedef pair<int,pair<int,int> > PIII;
const double eps=1e-7;
const int N=5e5+7 ,M=5e5+7, INF=0x3f3f3f3f,mod=1e9+7;
const long long int llINF=0x3f3f3f3f3f3f3f3f;
inline ll read() {ll x=0,f=1;char c=getchar();while(c<'0'||c>'9') {if(c=='-') f=-1;c=getchar();}
while(c>='0'&&c<='9') {x=(ll)x*10+c-'0';c=getchar();} return x*f;}
inline void write(ll x) {if(x < 0) {putchar('-'); x = -x;}if(x >= 10) write(x / 10);putchar(x % 10 + '0');}
inline void write(ll x,char ch) {write(x);putchar(ch);}
void stin() {freopen("in_put.txt","r",stdin);freopen("my_out_put.txt","w",stdout);}
bool cmp0(int a,int b) {return a>b;}
template<typename T> T gcd(T a,T b) {return b==0?a:gcd(b,a%b);}
template<typename T> T lcm(T a,T b) {return a*b/gcd(a,b);}
void hack() {printf("\n----------------------------------\n");}
int T,hackT;
int n,m,k;
int h[N],e[M],ne[M],w[M],idx;
int sizes[N];
ll ans=0;
void add(int a,int b,int c) {
e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++;
}
void dfs1(int u,int fa) {
sizes[u]=1;
for(int i=h[u];i!=-1;i=ne[i]) {
int j=e[i];
if(j==fa) continue;
dfs1(j,u);
sizes[u]+=sizes[j];
}
}
ll get(int n,int m) {
if(m>n) return 0;
ll res=1;
for(int i=n,j=1;j<=m;j++,i--) res=(ll)res*i;
for(int j=1;j<=m;j++) res=res/j;
return res;
}
void dfs(int u,int fa) {
ans+=(ll)(sizes[u]-1)*(n-sizes[u]);
ans+=get(sizes[u]-1,2);
for(int i=h[u];i!=-1;i=ne[i]) {
int j=e[i];
if(j==fa) continue;
dfs(j,u);
ans-=get(sizes[j],2);
}
}
void solve() {
n=read();
memset(h,-1,sizeof h);
for(int i=1;i<n;i++) {
int a=read(),b=read(),c=read();
add(a,b,c),add(b,a,c);
}
dfs1(1,-1);
dfs(1,-1);
printf("%lld\n",get(n,3)-ans);
}
int main() {
// init();
// stin();
// scanf("%d",&T);
T=1;
while(T--) hackT++,solve();
return 0;
}