原题面
圣诞节到了,小可可送给小薰一棵圣诞树。这棵圣诞树很奇怪,它是一棵多叉树,有n个点,n-1条边。它的每个结点都有一个权值。小可可和小薰想用这棵树玩一个游戏。
定义(s,e)为树上从s到e的简单路径,我们可以记下在这条路径上经过的结点,定义这个结点序列为S(s,e)。
我们按照如下方法定义这个序列S(s,e)的权值G(S(s,e)):假设这个序列中结点的权值为Z0,Z1,…,Z(L-1),其中L为序列的长度,我们定义G(S(s,e))=Z0 × k^0 + Z1 × k^1 + … + Z(L-1) × k^(L-1)。
如果路径(s,e)满足G(S(s,e)) ≡ x (mod y) ,那么这条路径属于小可可,否则这条路径属于小薰。小可可和小薰很显然不希望这个游戏变得那么简单。小薰认为如果路径(p1,p2)和(p2,p3)都属于他,那么路径(p1,p3)也属于他,反之如果路径(p1,p2)和(p2,p3)都属于小可可,那么路径(p1,p3)也属于小可可。然而这个性质并不总是正确的。所以小薰想知道到底有多少有序三元组(p1,p2,p3)满足这个性质。
小薰表示她看一眼就知道这道题怎么做了。你会吗?
对于100%的数据, 1≤n≤105,2≤y≤109,1≤k≤y,0≤x<y 。
大致题意
给定一棵树与k,求树中(i,j)与(j,k)与(i,k)都满足或都不满足 g(l,r)=x(mody)
g(l,r)=Z0×k0+Z1×k1+…+Z(L−1)×k(L−1))
每个点都有一个权值V[x],Z[i]表示的是L到R的简单路径上的点
思路
首先,先抽象问题。
我们对每个无序点对(x,y)连一条边,当这个g(x,y)满足条件时,边权为1,否则为0。
那么问题实际上就变为了找由有向边(i,j) (j,k) (i,k)组成的”有向三角形”,且满足(i,j)+(j,k)+(i,k)=3或0
这样的话我们要确定三条边的权值,
n2
做不了,所以我们正难则反,求不满足三角形的个数。 只需要确定两条不等的边就可以确定一个三角形 (但是对于每一个不满足三角形来说会算两次)
设一个点的出边为0的个数为out0…等等….
考虑每一个不满足三角形的三点 就有不满足三角形个数为
(注意是有序三元组,所以要*2),
然后
Ans=n3−p/2
这样我们就得到了一个 n2 的算法。 我们再考虑点分治求in与out.
我们考虑两个点i,j,如果满足下列式子则out[i]++,in[j]++
g(i,j)=g(i,k)+h(k,j)∗KLen(i,k)=x(mody)
其中h(x,y)表示x,y的路径(不包括x)的权值
将式子变形,不难得到
h(k,j)=x−g(i,k)KLen(i,k)
发现如果有一个确定的分治根作为k,那么子树内节点可以直接dfs一遍求这两个值,然后排序后O(n)求相等对数。 由于是点分治,所以实际复杂度应该是O(n (log n)^2)
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#define mxn 100010
#define max(a,b) ((a)>(b)?(a):(b))
typedef long long ll;
using namespace std;
ll tot,head[mxn],to[mxn*2],next[mxn*2];
ll dep[mxn],size[mxn];
ll n,y,k,xx,mi;
ll v[mxn],pw[mxn],rp[mxn],ans,curFr;
ll vis[mxn],cur,h[2][mxn];
ll in[mxn],out[mxn];
struct node{
ll v,no;
node() {}
node (int vv,int noo) {v=vv; no=noo;}
} dat[2][mxn];
bool cmp(node a,node b) {return a.v<b.v;}
ll ksm(ll a,ll b) {
if (b==0) return 1;
if (b==1) return a;
ll tmp=ksm(a,b>>1);
return tmp*tmp%y*ksm(a,b&1)%y;
}
void link(int x,int y){
to[++tot]=y;
next[tot]=head[x];
head[x]=tot;
}
void preWork() {
pw[0]=1;
for (int i=1; i<=n; i++) pw[i]=pw[i-1]*k%y;
rp[n]=ksm(pw[n],y-2);
for (int j=n-1; j; j--) rp[j]=rp[j+1]*k%y;
}
int dfs(int nw,int fa){
size[nw]=1;
for (int k=head[nw]; k; k=next[k]) {
if (vis[to[k]]==0 && to[k]!=fa) {
dfs(to[k],nw);
size[nw]+=size[to[k]];
}
}
}
void core(int nw,int fa,int &cr){
int tmp=size[curFr]-size[nw];
for (int k=head[nw]; k; k=next[k]) {
if (vis[to[k]]==0 && to[k]!=fa) {
core(to[k],nw,cr);
tmp=max(tmp,size[to[k]]);
}
}
if (tmp<mi) {mi=tmp; cr=nw;}
}
void putTag(int x,int fa,int de) {
dep[x]=de;
h[0][x]=(h[0][fa]*k+v[x])%y;
h[1][x]=(h[1][fa]+v[x]*pw[dep[x]])%y;
for (int k=head[x]; k; k=next[k])
if (vis[to[k]]==0 && to[k]!=fa) putTag(to[k],x,de+1);
dat[0][++cur]=node((xx-h[0][x]+y)*rp[dep[x]]%y,x);
dat[1][cur]=node(h[1][x],x);
}
void cnt(int f,int st,int ed) {
sort(dat[0]+st,dat[0]+1+ed,cmp);
sort(dat[1]+st,dat[1]+1+ed,cmp);
int st1=st,st2=st-1,ed1=st,ed2=st-1;
while (st1<=ed) {
while (ed1<ed && dat[0][ed1+1].v==dat[0][st1].v) ed1++;
while (st2<=ed && dat[1][st2].v<dat[0][st1].v || st2==st-1) st2++;
ed2=st2;
if (st2>ed) break;
while (ed2<ed && dat[1][st2].v==dat[1][ed2+1].v) ed2++;
if (dat[0][st1].v==dat[1][st2].v) {
for (int p=st1; p<=ed1; p++) out[dat[0][p].no]+=f*(ed2-st2+1);
for (int p=st2; p<=ed2; p++) in[dat[1][p].no]+=f*(ed1-st1+1);
}
st1=ed1+1;
}
}
void divide(int fr) {
curFr=fr; mi=99999999;
int c,last;
dfs(fr,0); core(fr,0,c);
cur=0;//memset(dat,0,sizeof dat);
h[0][c]=v[c]; h[1][c]=0;
for (int k=head[c]; k; k=next[k]) {
if (vis[to[k]]==0) {
last=cur+1;
putTag(to[k],c,1);
cnt(-1,last,cur);
}
}
dat[0][++cur]=node((xx-h[0][c]+y)%y,c);
dat[1][cur]=node(h[1][c],c);
cnt(1,1,cur);
vis[c]=1;
for (int k=head[c]; k; k=next[k]) {
if(vis[to[k]]==0) divide(to[k]);
}
}
int main() {
cin>>n>>y>>k>>xx;
for (int i=1; i<=n; i++) scanf("%lld",&v[i]);
ll a,b;
for (int i=1; i<n; i++) {
scanf("%lld %lld",&a,&b);
link(a,b); link(b,a);
}
preWork();
divide(1);
ll cnt=0;
for (int i=1; i<=n; i++)
cnt+= (in[i]*(n-out[i])+out[i]*(n-in[i])+2*out[i]*(n-out[i])+2*in[i]*(n-in[i]));
cout<<n*n*n-cnt/2<<endl;
}