Description
圣诞节到了,小可可送给小薰一棵圣诞树。这棵圣诞树很奇怪,它是一棵多叉树,有n个点,n-1条边。它的每个结点都有一个权值。小可可和小薰想用这棵树玩一个游戏。
定义(s,e)为树上从s到e的简单路径,我们可以记下在这条路径上经过的结点,定义这个结点序列为S(s,e)。
我们按照如下方法定义这个序列S(s,e)的权值G(S(s,e)):假设这个序列中结点的权值为Z0,Z1,…,Z(L-1),其中L为序列的长度,我们定义G(S(s,e))=Z0 × k^0 + Z1 × k^1 + … + Z(L-1) × k^(L-1)。
如果路径(s,e)满足G(S(s,e)) ≡ x (mod y) ,那么这条路径属于小可可,否则这条路径属于小薰。小可可和小薰很显然不希望这个游戏变得那么简单。如果路径(p1,p2)和(p2,p3)都属于小薰,那么路径(p1,p3)也属于他 或 如果路径(p1,p2)和(p2,p3)都属于小可可,那么路径(p1,p3)也属于小可可。然而这个性质并不总是正确的。所以小薰想知道到底有多少三元组(p1,p2,p3)满足这个性质。
对于100%的数据,1 ≤ n ≤ 10^5,2 ≤ y ≤ 10^9,1 ≤ k ≤ y,0 ≤ x < y。
Analysis
对于图中的路径,如果G(s(i,j))mod y=x,则构造边权为1的边(i,j),否则边权为0
那么现在求的是所有满足(i,j) (j,k) (i,k) 权值相等的(i,j,k)三元组个数
但是我们发现,直接求很难(从j入手,点分的时候好像还要倍增判定,又慢又难打)
正难则反
定义in0[i]表示进入i 的边中权值为 0 的个数。类似地定义 in1[i],out0[i],out1[i]
画个图构造出所有的非法三元组,发现共有6种情况
令
p=∑ni=12∗out0[i]∗out1[i]+2∗in0[i]∗in1[i]+out0[i]∗in1[i]+out1[i]∗in0[i]
则我们知道三条边权值不全相同三元组个数被计算了两遍
ans=n3−p/2
则问题变成快速求出in,out,因为in0[i]+in1[i]=n,所以只需求出in1和out1就行了
剩下的思路就是用点分治,求满足等式的点对可以通过哈希记录查询实现。具体细节请读者自行思考
ps:点分治常犯错误:在判断 b[i].S!=b[i−1].S ,且i循环到num+1时,记得清空b[num+1].S
Code
#include<cstdio>
#include<algorithm>
#define fo(i,a,b) for(ll i=a;i<=b;i++)
#define efo(i,v) for(int i=last[v];i;i=next[i])
using namespace std;
typedef long long ll;
const int N=100010,M=N*2;
const ll hx=2000000;
ll n,K,X,mo,_k[N],ny[N],a[N],in[N],out[N],hf[hx][2],hg[hx][2];
ll num,rt,tot,to[M],next[M],last[N],size[N];
bool bz[N];
struct node
{
ll v,f,g,S,l;
}b[N];
bool cmp(node a,node b)
{
return a.S<b.S;
}
void link(int u,int v)
{
to[++tot]=v,next[tot]=last[u],last[u]=tot;
}
ll qmi(ll x,ll n)
{
ll t=1;
for(;n;n>>=1)
{
if(n&1) t=t*x%mo;
x=x*x%mo;
}
return t;
}
void getnum(int v,int fr)
{
num++;
efo(i,v)
{
int u=to[i];
if(u==fr || bz[u]) continue;
getnum(u,v);
}
}
void getrt(int v,int fr)
{
size[v]=1;
efo(i,v)
{
int u=to[i];
if(u==fr || bz[u]) continue;
getrt(u,v);
size[v]+=size[u];
}
if(size[v]>num-size[v] && !rt) rt=v;
}
void dfs(int v,int fr,ll d,ll f,ll g,ll S)
{
b[++num].v=v,b[num].S=S,b[num].l=d,b[num].f=f,b[num].g=g;
efo(i,v)
{
int u=to[i];
if(u==fr || bz[u]) continue;
dfs(u,v,d+1,(f+a[u]*_k[d+1])%mo,(g*K+a[u])%mo,S);
}
}
ll hashf(ll x)
{
ll pos=x%hx;
while(hf[pos][0] && hf[pos][0]!=x) pos=(pos+1)%hx;
return pos;
}
ll hashg(ll x)
{
ll pos=x%hx;
while(hg[pos][0] && hg[pos][0]!=x) pos=(pos+1)%hx;
return pos;
}
void divide(int v,int fr)
{
num=0;getnum(v,fr);
rt=0;getrt(v,fr);
if(a[rt]%mo==X) in[rt]++,out[rt]++;
num=0;
efo(i,rt)
{
int u=to[i];
if(u==fr || bz[u]) continue;
dfs(u,rt,1,a[u]*K%mo,(a[u]+a[rt]*K)%mo,u);
}
fo(i,1,num)
{
if(b[i].g==X) in[rt]++,out[b[i].v]++;
if((a[rt]+b[i].f)%mo==X) out[rt]++,in[b[i].v]++;
b[i].g=(X-b[i].g+mo)%mo*ny[b[i].l]%mo;
}
sort(b+1,b+num+1,cmp);
fo(i,1,num)
{
ll pos=hashf(b[i].f);
hf[pos][0]=b[i].f,hf[pos][1]++;
pos=hashg(b[i].g);
hg[pos][0]=b[i].g,hg[pos][1]++;
}
ll st=1;
b[num+1].S=0;
fo(i,1,num+1)
if(i>1 && b[i].S!=b[i-1].S)
{
fo(j,st,i-1)
{
ll pos=hashf(b[j].f);hf[pos][1]--;
pos=hashg(b[j].g);hg[pos][1]--;
}
fo(j,st,i-1)
{
ll pos=hashf(b[j].g);
out[b[j].v]+=hf[pos][1];
pos=hashg(b[j].f);
in[b[j].v]+=hg[pos][1];
}
fo(j,st,i-1)
{
ll pos=hashf(b[j].f);hf[pos][1]++;
pos=hashg(b[j].g);hg[pos][1]++;
}
st=i;
}
fo(i,1,num)
{
ll pos=hashf(b[i].f);
hf[pos][0]=hf[pos][1]=0;
pos=hashg(b[i].g);
hg[pos][0]=hg[pos][1]=0;
}
bz[rt]=1;
efo(i,rt)
{
int u=to[i];
if(u==fr || bz[u]) continue;
divide(u,rt);
}
}
int main()
{
int u,v;
scanf("%lld %d %lld %lld",&n,&mo,&K,&X);
_k[0]=1;
fo(i,1,n) _k[i]=_k[i-1]*K%mo;
fo(i,0,n) ny[i]=qmi(_k[i],mo-2);
fo(i,1,n) scanf("%lld",&a[i]),a[i]%=mo;
fo(i,1,n-1)
{
scanf("%d %d",&u,&v);
link(u,v),link(v,u);
}
divide(1,0);
ll sum=0;
fo(i,1,n) sum+=2*in[i]*(n-in[i])+in[i]*(n-out[i])+out[i]*(n-in[i])+2*out[i]*(n-out[i]);
printf("%lld\n",n*n*n-sum/2);
return 0;
}