Description
给定一颗n个结点的无根树,树上的每个点有一个非负整数点权,定义一条路径的价值为路径上的点权和-路径的点权最大值。
给定参数p,我们想知道,有多少不同的树上简单路径,满足它的价值恰好是p的倍数。
注意:单点算作一个路径;u ≠ v时,(u,v)和(v,u)只算一次。
Data Constraint
对所有测试点,我们有:
n≤10^5,p≤10^7,val_i≤10^9
Solution
这是道树分治的题。我们找出重心的位置,每次从重心往四周遍历,找出每条到重心的路径的点权和%p和路径的点权最大值,然后将路径按点权最大值从小大大排序,用个桶维护当前的路径的点权和,每次在桶中查找路径的点权和-路径的点权最大值的数量。由于可能会算重,所以要先重心的每颗子树自己先搞一下,减去重复。
Code
#include<iostream>
#include<cmath>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=2e5+5,maxn1=1e7+5;
struct code{
int mx,sum;
}b[maxn];
int first[maxn],last[maxn],next[maxn],a[maxn],size[maxn],mx[maxn];
int n,i,t,j,k,l,m,x,y,z,num,p,ans,ln,s,cnt[maxn1][2],bz[maxn],fa[maxn];
void lian(int x,int y){
last[++num]=y;next[num]=first[x];first[x]=num;
}
bool cmp(code x,code y){
return x.mx<y.mx;
}
void dg1(int x,int y){
int t,p=num;size[x]=1;mx[x]=0;
for (t=first[x];t;t=next[t]){
if (last[t]==y || bz[last[t]])continue;
b[++num].sum=(b[p].sum+a[last[t]])%m;
b[num].mx=max(a[last[t]],b[p].mx);
dg1(last[t],x);size[x]+=size[last[t]];mx[x]=max(mx[x],size[last[t]]);
}
}
int find(int x,int y){
int t,k;mx[x]=max(mx[x],p-size[x]);
if (mx[x]*2<=p||p==1) return x;
for(t=first[x];t;t=next[t]){
if (last[t]==y || bz[last[t]]) continue;
k=find(last[t],x);
if (k) return k;
}
return 0;
}
void dg(int x){
int t,k;
bz[x]=1;num=0;
for (t=first[x];t;t=next[t]){
if (bz[last[t]]) continue;
k=num+1;
b[++num].mx=max(a[x],a[last[t]]);b[num].sum=(a[x]+a[last[t]])%m;
dg1(last[t],0);
sort(b+k,b+num+1,cmp);
for (i=k;i<=num;i++){
k=((b[i].sum-b[i].mx)%m+m)%m;
if (k) k=m-k;
if (cnt[k][0]==last[t]) ans-=cnt[k][1];
l=((b[i].sum-a[x])%m+m)%m;
if (cnt[l][0]!=last[t]) cnt[l][0]=last[t],cnt[l][1]=0;
cnt[l][1]++;
}
}
sort(b+1,b+num+1,cmp);
for (i=1;i<=num;i++){
k=((b[i].sum-b[i].mx)%m+m)%m;
if (k) k=m-k;else ans++;
if (cnt[k][0]==x) ans+=cnt[k][1];
l=((b[i].sum-a[x])%m+m)%m;
if (cnt[l][0]!=x) cnt[l][0]=x,cnt[l][1]=0;
cnt[l][1]++;
}
for (t=first[x];t;t=next[t]){
if (bz[last[t]])continue;p=size[last[t]];
k=find(last[t],x);
dg(k);
}
}
int main(){
freopen("path.in","r",stdin);freopen("path.out","w",stdout);
scanf("%d%d",&n,&m);
for (i=1;i<n;i++)
scanf("%d%d",&x,&y),lian(x,y),lian(y,x);
for (i=1;i<=n;i++)
scanf("%d",&a[i]);num=0;
dg1(1,0);p=n;
k=find(1,0);
dg(k);
ans+=n;
printf("%d\n",ans);
}