题目描述
圣诞节到了,小可可送给小薰一棵圣诞树。这棵圣诞树很奇怪,它是一棵多叉树,有n个点,n-1条边。它的每个结点都有一个权值。小可可和小薰想用这棵树玩一个游戏。
定义(s,e)为树上从s到e的简单路径,我们可以记下在这条路径上经过的结点,定义这个结点序列为S(s,e)。
我们按照如下方法定义这个序列S(s,e)的权值G(S(s,e)):假设这个序列中结点的权值为Z0,Z1,…,Z(L-1),其中L为序列的长度,我们定义G(S(s,e))=Z0 × k0 + Z1 × k1 + … + Z(L-1) × k(L-1)。
如果路径(s,e)满足G(S(s,e)) ≡ x (mod y) ,那么这条路径属于小可可,否则这条路径属于小薰。小可可和小薰很显然不希望这个游戏变得那么简单。小薰认为如果路径(p1,p2)和(p2,p3)都属于他,那么路径(p1,p3)也属于他,反之如果路径(p1,p2)和(p2,p3)都属于小可可,那么路径(p1,p3)也属于小可可。然而这个性质并不总是正确的。所以小薰想知道到底有多少三元组(p1,p2,p3)满足这个性质。
小薰表示她看一眼就知道这道题怎么做了。你会吗?
输入
第一行包含四个整数n,y,k和x,其中n为圣诞树的结点数,y,k和x的含义如题目所示,题目保证y是一个质数。
第二行包含n个整数,第i个整数vi表示第i个结点的权值。
接下来n-1行,每行包含2个整数,表示树上的一条边。树的结点从1到n编号。
输出
包含一个整数,表示有多少整数组(p1,p2,p3)满足题目描述的性质。
样例输入
1 2 1 0
1
样例输出
1
提示
【样例2】
tree.in
tree.out
3 5 2 1
4 3 1
1 2
2 3
14
【样例3】
tree.in
tree.out
8 13 8 12
0 12 7 4 12 0 8 12
1 8
8 4
4 6
6 2
2 3
8 5
2 7
341
【数据规模】
对于20%的数据,n ≤ 200;
对于50%的数据,n ≤ 104;
对于100%的数据,1 ≤ n ≤ 105,2 ≤ y ≤ 109,1 ≤ k ≤ y,0 ≤ x < y。
题解
代码(能开long long 的就开long long 吧)
#include<bits/stdc++.h>
#define ll long long
#define N 100005
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int cnt,rt,tot,sum;
ll x,y,k,n;
int Head[N],ret[2*N],Next[2*N];
int val[N];
int size[N],mx[N],bl[N];
ll in[N],out[N],fac[N],inv[N];
struct node{ll v;int id;}qg[N],qh[N];
bool flag[N];
bool cmp(node a,node b){return a.v<b.v;}
ll gpow(ll x,ll k)
{
if (k==0) return 1;
if (k==1) return x;
ll t=gpow(x,k>>1);
t=t*t%y;
if (k&1) t=t*x%y;
return t;
}
void pre(){
fac[0]=1;
for (int i=1;i<=n;i++)fac[i]=(ll)fac[i-1]*k%y;
inv[n]=gpow(fac[n],y-2);
for (int i=n-1;i;i--)inv[i]=(ll)inv[i+1]*k%y;
}
inline void ins(int u,int v){ret[++tot]=v;Next[tot]=Head[u];Head[u]=tot;}
void getroot(int u,int f)//找重心
{
size[u]=1;mx[u]=0;
for (int i=Head[u];i;i=Next[i])
{
int v=ret[i];
if (!flag[v]&&v!=f)
{
getroot(v,u);
size[u]+=size[v];
mx[u]=max(mx[u],size[v]);
}
}
mx[u]=max(mx[u],sum-size[u]);
if (mx[u]<mx[rt]) rt=u;
}
void dfs(int u,int f,ll h,ll g,int l)//搜索子数,h表示h(x,y),g表示g(x,y),l表示深度
{
qh[++cnt].v=h;qg[cnt].v=(((ll)(x-g)%y+y)%y)*inv[l]%y;
qh[cnt].id=u;qg[cnt].id=u;if (f==rt||u==rt) bl[u]=u;else bl[u]=bl[f];
for (int i=Head[u];i;i=Next[i])
{
int v=ret[i];
if (v!=f&&!flag[v])
dfs(v,u,(h+val[v]*fac[l-1])%y,(g*k+val[v])%y,l+1);
}
}
int queryr(int v)//二分相等的最右端
{
int l=1,r=cnt;
while (l!=r)
{
int mid=(l+r+1)>>1;
if (qg[mid].v>v) r=mid-1;else l=mid;
}
return l;
}
int queryl(int v)//二分相等的最左端
{
int l=1,r=cnt;
while (l!=r)
{
int mid=(l+r)>>1;
if (qg[mid].v<v) l=mid+1;else r=mid;
}
return l;
}
void solve(int u)
{
cnt=0;flag[u]=1;
dfs(u,0,0,val[u],1);
sort(qh+1,qh+cnt+1,cmp);
sort(qg+1,qg+cnt+1,cmp);
for (int i=1;i<=cnt;i++)
{
int l=queryl(qh[i].v);
int r=queryr(qh[i].v);
if (qg[l].v!=qh[i].v) continue;
for (int j=l;j<=r;j++)
if (bl[qh[i].id]!=bl[qg[j].id]||(bl[qh[i].id]==rt&&bl[qg[j].id]==rt))
{
in[qh[i].id]++;
out[qg[j].id]++;
}
}//更新in[]和out[]
for (int i=Head[u];i;i=Next[i])
{
int v=ret[i];
if (!flag[v])
{
rt=0;sum=size[v];
getroot(v,0);
solve(rt);
}
}
}
int main()
{
n=read(),y=read(),k=read(),x=read();
pre();
for (int i=1;i<=n;i++) val[i]=read();
for (int i=1;i<n;i++)
{
int u=read(),v=read();
ins(u,v);ins(v,u);
}
rt=0;sum=n;mx[0]=n;
getroot(1,0);
solve(rt);
ll ans=0;
for (int i=1; i<=n; i++)
ans+=((ll)in[i]*(n-out[i])+(ll)out[i]*(n-in[i])+(ll)2*out[i]*(n-out[i])+(ll)2*in[i]*(n-in[i]));//,cout<<i<<" "<<in[i]<<" "<<out[i]<<endl;
cout<<(ll)n*n*n-ans/2<<endl;
}