题意
后两个求和符号代表的是有多少异或值为0的路径。
前两个符号代表有多少路径包含异或值为0的路径。
即每个权值为0的路径对答案的贡献为 有多少路径包含当前路径 ,所有的贡献加起来就是答案。
思路
点分治,权值太大,桶只能用map了。
x与y之间的路径(x与y间权值异或为0)对答案的贡献是
(x可以扩展的点数)**(y可以扩展的点数)。
如上图,设2与3之间的权值异或为0,那么这条边对答案的贡献为2*3(2可以扩展到2、4两个点,3可以扩展到3、5、6这三个点)。
桶里面加上的不是个数了,而是对于 当前重心 到 x 的 这一路径 ,x可以扩展到几个点。如果知道对于每条路径上面的端点可以扩展到多少点这题就可做了。
预处理:让a作为根,遍历一遍,统计每个点子树的大小siz[x],与每个点的上一个点init_pre[x]。
现在要求rt->…..->pre->x这一路径x可以扩展多少:
1、init_pre[x]==pre,那么x可以扩展的点就是siz[x].
2、init_pre[x]!=pre,那么x可以扩展的点是n-siz[pre]
可以o(1)求出来。
解题思路:
树上路径计数问题,想当然就是点分治了,但是这道题并没有那么简单。
以这个图为例:
首先是路径两端点数的计算问题:
假设我们现在点分治的过程中以3为重心,3-4和2-3都为2,因此2-4这个路径的异或和为0,我们期望用2号节点及其上面的点的数量乘以4号节点以及其下面的点的数量来得到2-4这条路径的贡献。但是点分治里的sz数组是不固定且不正确的,会随着重心的不同发生变化,所以我们用最初第一次以1为根获得的siz进行处理,在第一次getrt的过程中记录每个节点的前驱节点,如果在这次遍历过程中前驱节点与第一次getrt过程中的前驱节点相同的话,代表当前方向与第一次getrt时的方向相同,就可以直接用第一次getrt处理到的siz,否则说明方向相反,就要用总的点数 n 减去当前节点前驱节点的siz获得结果,比如2号节点及其上方的点数就可以用总点数5减去2的前驱3及其3下方点的数量得到。
第二个问题是当前重心到某一点的异或和直接为零的情况,需要特殊处理:
假设图中4-5的权值为2,我们依然假设3为当前重心,这样3-5的异或和为0,我们要单独去判断一下重心另一侧的点的数量再相乘(因为这一部分再遍历map的时候是没有处理的)
#include<bits/stdc++.h>
#include<>
using namespace std;
#define ll long long
#define maxn 100100
#define inf 0x3f3f3f3f
#define mod 1000000007
struct node
{
ll v,w,to;
} edge[maxn*2];
struct data
{
ll w,sum;
}temp[maxn];
bool vis[maxn];
int head[maxn],cnt,n,rt,a,tot;
ll sum,sz[maxn],maxx[maxn],ans,b,sz_rt;
ll dis[maxn],siz[maxn],init_pre[maxn];
unordered_map<ll,ll>mp;
void init()
{
memset(head,-1,sizeof(head));
memset(vis,0,sizeof(vis));
cnt=ans=0;
}
void add(int u,int v,ll w)
{
edge[cnt]={v,w,head[u]};
head[u]=cnt++;
edge[cnt]={u,w,head[v]};
head[v]=cnt++;
}
void dfs(int x,int pre)
{
siz[x]=1;
init_pre[x]=pre;
for(int i=head[x];i!=-1;i=edge[i].to)
{
int v=edge[i].v;
if(v!=pre)
{
dfs(v,x);
siz[x]+=siz[v];
}
}
}
void getrt(int x,int pre)
{
sz[x]=1;
maxx[x]=0;
for(int i=head[x];i!=-1;i=edge[i].to)
{
int v=edge[i].v;
if(v!=pre&&!vis[v])
{
getrt(v,x);
sz[x]+=sz[v];
maxx[x]=max(maxx[x],sz[v]);
}
}
maxx[x]=max(maxx[x],sum-sz[x]);
if(maxx[x]<maxx[rt])rt=x;
}
void getdis(int x,int pre)
{
if(pre==init_pre[x])
{
ans=(ans+siz[x]*mp[dis[x]])%mod;
temp[++tot]= {dis[x],siz[x]};
if(dis[x]==0)
ans=(ans+siz[x]*sz_rt)%mod;
}
else
{
ans=(ans+1ll*(n-siz[pre])*mp[dis[x]])%mod;
temp[++tot]= {dis[x],n-siz[pre]};
if(dis[x]==0)
ans=(ans+1ll*(n-siz[pre])*sz_rt)%mod;
}
for(int i=head[x];i!=-1;i=edge[i].to)
{
int v=edge[i].v;
ll w=edge[i].w;
if(v==pre||vis[v])continue;
dis[v]=dis[x]^w;
getdis(v,x);
}
}
void cal(int x)
{
dis[x]=0;
for(int i=head[x];i!=-1;i=edge[i].to)
{
int v=edge[i].v;
ll w=edge[i].w;
if(vis[v])continue;
if(init_pre[v]==x)sz_rt=n-siz[v];
else sz_rt=siz[rt];
dis[v]=w;
tot=0;
getdis(v,x);
for(int j=1;j<=tot;j++)
{
mp[temp[j].w]+=temp[j].sum;
mp[temp[j].w]%=mod;
}
}
mp.clear();
}
void solve(int x)
{
vis[x]=1;
cal(x);
for(int i=head[x];i!=-1;i=edge[i].to)
{
int v=edge[i].v;
if(vis[v])continue;
maxx[rt=0]=inf;
sum=sz[v];
getrt(v,0);
solve(rt);
}
}
int main()
{
init();
scanf("%d",&n);
for(int i=2;i<=n;i++)
{
scanf("%d%lld",&a,&b);
add(i,a,b);
}
dfs(1,0);
maxx[rt=0]=inf;
sum=n;
getrt(1,0);
solve(rt);
printf("%lld\n",ans%mod);
return 0;
}