题目大意:
给定一棵有
n
n
个节点的树,每个节点有点权给出
p,A,B
p
,
A
,
B
,问有多少点对
(u,v)
(
u
,
v
)
满足:
1.
v
v
是的祖先。
2.
a2u+Aauav+Ba2v≡0(modp)
a
u
2
+
A
a
u
a
v
+
B
a
v
2
≡
0
(
mod
p
)
n≤100000,p∈P,3≤p≤1016,0≤A,B<p n ≤ 100000 , p ∈ P , 3 ≤ p ≤ 10 16 , 0 ≤ A , B < p
解题思路:
考场上不会二次剩余怎么解,还好暴力有55分……
回来再这里学习了一下解法,好像确实是板题。
用解一元二次方程的方法可得
au≡−A±A2−4B√2av(modp)
a
u
≡
−
A
±
A
2
−
4
B
2
a
v
(
mod
p
)
若
A2−4B
A
2
−
4
B
为二次剩余,则用Cipolla算法直接计算
A2−4B‾‾‾‾‾‾‾‾√
A
2
−
4
B
。
否则答案统计0的对数。
算出
A2−4B‾‾‾‾‾‾‾‾√
A
2
−
4
B
后,直接dfs时用map统计即可。
时间复杂度为
O(nlogp)
O
(
n
l
o
g
p
)
。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
ll getint()
{
ll i=0,f=1;char c;
for(c=getchar();(c!='-')&&(c<'0'||c>'9');c=getchar());
if(c=='-')c=getchar(),f=-1;
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
const int N=100005;
int n,num,tot,first[N],nxt[N],to[N];
ll p,A,B,ans,a1,a2,w,det,a[N];
map<ll,int>cnt;
ll mul(ll x,ll y)
{
ll res=0;
for(;y;y>>=1,x=(x+x)%p)
if(y&1)res=(res+x)%p;
return res;
}
ll Pow(ll x,ll y)
{
ll res=1;
for(;y;y>>=1,x=mul(x,x))
if(y&1)res=mul(res,x);
return res;
}
struct Complex
{
ll x,y;
Complex(){}
Complex(ll _x,ll _y):x(_x),y(_y){}
inline friend Complex operator * (const Complex &a,const Complex &b)
{return Complex((mul(a.x,b.x)+mul(mul(a.y,b.y),w))%p,(mul(a.x,b.y)+mul(a.y,b.x))%p);}
inline friend Complex Pow(Complex &a,ll b)
{
Complex res=Complex(1,0);
for(;b;b>>=1,a=a*a)
if(b&1)res=res*a;
return res;
}
};
ll find(ll n)
{
if(!n)return 0;
ll a;
while(1)
{
a=rand()%p;
w=(mul(a,a)-n+p)%p;
if(Pow(w,(p-1)/2)==p-1)break;
}
Complex res=Complex(a,1);
res=Pow(res,(p+1)/2);
return res.x;
}
void add(int x,int y)
{
nxt[++tot]=first[x],first[x]=tot,to[tot]=y;
}
void dfs1(int u)
{
ans+=cnt[a[u]];
ll v1=mul(a1,a[u]),v2=mul(a2,a[u]);
v1==v2?cnt[v1]++:(cnt[v1]++,cnt[v2]++);
for(int e=first[u];e;e=nxt[e])
dfs1(to[e]);
v1==v2?cnt[v1]--:(cnt[v1]--,cnt[v2]--);
}
void dfs2(int u)
{
if(a[u]==0)ans+=num,num++;
for(int e=first[u];e;e=nxt[e])
dfs2(to[e]);
if(a[u]==0)num--;
}
int main()
{
//freopen("lx.in","r",stdin);
//freopen("lx.out","w",stdout);
n=getint(),p=getint(),A=getint(),B=getint();
for(int i=1;i<=n;i++)a[i]=getint();
for(int i=2;i<=n;i++)add(getint(),i);
det=(mul(A,A)-B*4%p+p)%p;
if(Pow(det,(p-1)/2)!=p-1)
{
det=find(det);
a1=mul((-A+det+p)%p,Pow(2,p-2));
a2=mul((-A-det+p+p)%p,Pow(2,p-2));
dfs1(1);
}
else dfs2(1);
cout<<ans<<'\n';
return 0;
}