传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=3697
阳视作1 阴视作-1 统计路径为0且祖先有和自己权值相同的个数
Code:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cctype>
#include<map>
#include<set>
using namespace std;
typedef long long LL;
const int maxn=1e5+5;
LL ans=0;
int n;
int getint(){
int res=0;char c=getchar();
while(!isdigit(c))c=getchar();
while(isdigit(c))res=res*10+c-'0',c=getchar();
return res;
}
struct edge{int u,v,w;};
vector<edge>G[maxn];
int siz[maxn],f[maxn],dep[maxn],cant[maxn],root,All,d[maxn];
map<int,int>M,mp,MA,data,data2;
void makert(int u,int fa){
siz[u]=1;f[u]=0;
for(int i=0;i<G[u].size();i++){
edge e=G[u][i];
if(e.v!=fa&&!cant[e.v]){
dep[e.v]=dep[u]+1;
makert(e.v,u);
siz[u]+=siz[e.v];
f[u]=max(f[u],siz[e.v]);
}
}f[u]=max(f[u],All-f[u]);
if(f[root]>f[u])root=u;
}
void dfs(int u,int fa){
data[d[u]]++;
for(int i=0;i<G[u].size();i++){
edge e=G[u][i];
if(e.v==fa||cant[e.v])continue;
d[e.v]=d[u]+e.w;
if(M.count(d[e.v])){
data2[d[e.v]]++;
// cerr<<d[e.v]<<endl;
}
M[d[e.v]]++;
dfs(e.v,u);
if(!--M[d[e.v]])M.erase(d[e.v]);
}
}
typedef map<int,int>::iterator iter;
void deb(map<int,int>M){
puts("");
for(iter it=M.begin();it!=M.end();it++)if(it->second)
cout<<it->first<<" "<<it->second<<endl;
}
void calc(int u){
LL res=0;MA.clear();mp.clear();d[u]=0;
for(int i=0;i<G[u].size();i++){
edge e=G[u][i];
if(cant[e.v])continue;
d[e.v]=e.w;M[e.w]++;data.clear();data2.clear();
dfs(e.v,u);M.erase(e.w);
// deb(MA);
// deb(mp);
// deb(data);
// deb(data2);
for(iter it=data.begin();it!=data.end();it++){
LL num=it->first,cant=it->second-data2[it->first],can=data2[it->first];
if(!num){
res+=it->second*MA[0];
}else{
res+=can*MA[-num];
res+=cant*mp[-num];
}
}
for(iter it=data.begin();it!=data.end();it++)MA[it->first]+=it->second;
for(iter it=data2.begin();it!=data2.end();it++)mp[it->first]+=it->second;
}ans+=res;ans+=mp[0];
}
void solve(int u){
calc(u);cant[u]=1;
for(int i=0;i<G[u].size();i++){
edge e=G[u][i];
if(cant[e.v])continue;
All=siz[e.v];
f[root=0]=n+1;
makert(e.v,0);
solve(root);
}
}
int main(){
n=getint();All=n;
for(int i=1;i<n;i++){
int u=getint(),v=getint(),w=getint();
G[u].push_back((edge){u,v,w?w:-1});
G[v].push_back((edge){v,u,w?w:-1});
}f[root=0]=n+1;
makert(1,1);
solve(root);
cout<<ans<<endl;
return 0;
}