Input:
4
1 2 1
1 3 2018
1 4 1
4
1 2 0
1 3 0
1 4 0
3
1 2 1
2 3 1
Output:
2
6
0
解法: 裸的点分治（centroid decomposition）——统计树上路径边权和为 2019 的倍数的无序点对数量。
Code:
#include <bits/stdc++.h>
// Fills the whole array `a` with the byte value `b`.
// `a` must be a real array (not a pointer) so that sizeof(a) is its full size.
// The expansion deliberately has no trailing semicolon: the caller writes it,
// so `clr(x, 0);` is exactly one statement and remains valid in an unbraced
// if/else.
#define clr(a,b) memset(a,b,sizeof(a))
using namespace std;
// Capacity of every per-vertex array (vertex ids are used directly as indices).
const int MX = 20000 + 7;

// One directed edge stored in a linked adjacency list.
struct Edge {
    int v;     // vertex this edge leads to
    int w;     // weight of the edge
    int next;  // index in e[] of the next edge with the same source; 0 ends the list
};

// Edge pool. Slot 0 is never handed out, which lets index 0 act as the list
// terminator. Twice MX because each tree edge is inserted in both directions.
Edge e[MX * 2];

int ecnt;      // index of the most recently allocated slot in e[]
int head[MX];  // head[u] = index of the first edge leaving u, or 0 if u has none

// Prepends the directed edge u -> v with weight w to u's adjacency list.
void add(int u, int v, int w) {
    const int id = ++ecnt;
    Edge &slot = e[id];
    slot.v = v;
    slot.w = w;
    slot.next = head[u];
    head[u] = id;
}
// Shared state of the centroid decomposition (all zero-initialised).
int n;            // number of vertices in the current test case
int sum;          // size of the piece whose centroid getrt() is looking for
int cnt;          // number of entries currently stored in tmp[]
int res;          // answer accumulated so far for the current test case
int rt;           // best centroid candidate so far; 0 is the "none yet" sentinel
int siz[MX];      // siz[u]   = subtree size of u from the latest getrt() traversal
int maxp[MX];     // maxp[u]  = size of the largest piece left when u is removed
int tmp[MX];      // residues modulo 2019 gathered by getdis() for one subtree
int dis[MX];      // dis[u]   = path weight from the current centre to u
int judge[2022];  // judge[r] = vertices with residue r in subtrees already merged
bool vis[MX];     // vis[u]   = u has already been removed as a centre
// Traverses the not-yet-removed piece containing u (entered from fa), filling
// siz[] and maxp[] for every vertex it reaches, and leaves in the global rt the
// vertex with the smallest maxp seen. The caller must have stored the piece
// size in sum and a sentinel value in maxp[rt].
void getrt(int u, int fa) {
    siz[u] = 1;
    maxp[u] = 0;
    for (int id = head[u]; id != 0; id = e[id].next) {
        const int to = e[id].v;
        if (to != fa && !vis[to]) {
            getrt(to, u);
            siz[u] += siz[to];
            if (siz[to] > maxp[u]) maxp[u] = siz[to];
        }
    }
    // Everything outside u's subtree forms one more piece once u is removed.
    const int outside = sum - siz[u];
    if (outside > maxp[u]) maxp[u] = outside;
    if (maxp[u] < maxp[rt]) rt = u;
}
// Appends to tmp[] the residue modulo 2019 of the path weight of every
// not-yet-removed vertex reachable from u without passing back through fa
// (u itself included). The caller must set dis[u] and reset cnt beforehand.
//
// Only the residue of a distance is ever consumed, so distances are reduced
// while they are propagated: dis[] then stays below 2019 no matter how long the
// path or how large the weights are, instead of accumulating a raw sum in an int.
void getdis(int u, int fa) {
    const int MOD = 2019;
    // C++ '%' keeps the sign of the dividend; shift into [0, MOD) so the value
    // is always a valid residue for the counting table.
    const int here = (dis[u] % MOD + MOD) % MOD;
    tmp[cnt++] = here;
    for (int i = head[u]; i; i = e[i].next) {
        int v = e[i].v;
        if (v == fa || vis[v]) continue;
        // here is in [0, MOD) and w % MOD is in (-MOD, MOD), so the sum cannot overflow.
        dis[v] = (here + e[i].w % MOD + MOD) % MOD;
        getdis(v, u);
    }
}
// Adds to res the number of vertex pairs whose connecting path passes through
// the centre u (u may be an endpoint) and has total weight divisible by 2019.
// Only vertices that have not been removed yet are considered.
void solve(int u) {
    std::vector<int> touched;  // residues added to judge[], kept so they can be undone
    for (int id = head[u]; id != 0; id = e[id].next) {
        const int child = e[id].v;
        if (!vis[child]) {
            cnt = 0;
            dis[child] = e[id].w;
            getdis(child, u);
            // Match this subtree against the ones merged earlier *before*
            // publishing its own residues, so both ends never share a subtree.
            for (int k = 0; k < cnt; ++k) {
                const int r = tmp[k];
                if (r != 0) {
                    res += judge[2019 - r];
                } else {
                    res += 1 + judge[0];  // the extra 1 is the pair (u, this vertex)
                }
            }
            for (int k = 0; k < cnt; ++k) {
                touched.push_back(tmp[k]);
                ++judge[tmp[k]];
            }
        }
    }
    // Undo only the entries that were touched, leaving judge[] all zero again.
    for (std::size_t k = 0; k < touched.size(); ++k) {
        judge[touched[k]] = 0;
    }
}
// Recursive driver of the decomposition: removes u, counts the pairs whose path
// goes through it, then repeats on the centroid of every remaining piece.
void divide(int u) {
    vis[u] = true;
    solve(u);
    for (int id = head[u]; id != 0; id = e[id].next) {
        const int child = e[id].v;
        if (vis[child]) continue;
        // Restart the centroid search: no candidate yet, and the sentinel
        // maxp[0] set to the piece size so that any real vertex beats it.
        rt = 0;
        sum = siz[child];
        maxp[0] = sum;
        getrt(child, 0);  // find the centroid of this piece
        getrt(rt, 0);     // recompute siz[] with that centroid as the root
        divide(rt);
    }
}
// Clears the per-test global state: the scalar counters, the residue table
// and, for vertex indices 0..n, the adjacency heads and per-vertex arrays.
void init() {
    res = 0;
    ecnt = 0;
    rt = 0;
    sum = 0;
    cnt = 0;
    memset(judge, 0, sizeof(judge));
    for (int v = 0; v <= n; ++v) {
        head[v] = 0;
        maxp[v] = 0;
        dis[v] = 0;
        siz[v] = 0;
        vis[v] = false;
    }
}
// Reads test cases until the input ends. Each case is a vertex count n followed
// by n-1 lines "u v w" describing weighted tree edges; one line with the number
// of counted pairs is printed per case.
int main() {
    // Require exactly one converted item. scanf returns 0 on a matching failure
    // (e.g. a non-numeric token) and ~0 is non-zero, so testing ~scanf(...)
    // would keep looping on such input with a stale n.
    while (scanf("%d", &n) == 1) {
        for (int i = 1; i < n; ++i) {
            int u, v, w;
            // Stop on truncated or malformed edge data instead of inserting
            // edges built from uninitialised values.
            if (scanf("%d %d %d", &u, &v, &w) != 3) return 0;
            add(u, v, w);
            add(v, u, w);
        }
        maxp[0] = sum = n;  // sentinel for the first centroid search (rt is 0 here)
        getrt(1, 0);
        getrt(rt, 0);       // recompute siz[] with the centroid as the root
        divide(rt);
        printf("%d\n", res);
        init();             // reset global state before the next test case
    }
    return 0;
}