题目
Description
Input
从文件 tree.in 中读入数据。
第一行一个正整数 n 表示树的大小。
第二行 n 个正整数表示 vi。
接下来一行 n-1 个正整数,依次表示 2 号结点到 n 号结点,每个结点的父亲编号pi。
Output
输出到文件 tree.out 中。
仅一行一个整数表示答案。
Sample Input
5
5 4 1 2 3
1 1 2 2
Sample Output
12
【样例 1 解释】
value(1) = (5 + 0) ⊕ (4 + 1) ⊕ (1 + 1) ⊕ (2 + 2) ⊕ (3 + 2) = 3。
value(2) = (4 + 0) ⊕ (2 + 1) ⊕ (3 + 1) = 3。
value(3) = (1 + 0) = 1。
value(4) = (2 + 0) = 2。
value(5) = (3 + 0) = 3。
和为 12。
Data Constraint
【数据范围】
10% 的数据:1 ≤ n ≤ 2501。
40% 的数据:1 ≤ n ≤ 152501。
另有 20% 的数据:所有 pi = i-1(2 ≤ i ≤ n)。
另有 20% 的数据:所有 vi = 1(1 ≤ i ≤ n)。
100% 的数据:1 ≤ n, vi ≤ 525010,1 ≤ pi ≤ n。
思路
大体思路是先推式子然后再矩阵树
这里直接推荐一篇写得比较好的博文:https://www.luogu.com.cn/blog/1445353309froggy/solution-p6624
代码
#include<bits/stdc++.h>
#define int long long
#define mp make_pair
#define pii pair<int,int>
using namespace std;
const int N=47,M=152517,mod=998244353;
int n,m,cnt,ans,mx,x[N*N/2],y[N*N/2],w[N*N/2],sum[M],phi[M],vis[M],prime[M];
pii f[N][N];
int power(int x,int t)
{
int b=1;
while(t)
{
if(t&1) b=b*x%mod; x=x*x%mod; t>>=1;
}
return b;
}
pii operator+(const pii x,const pii y) { return mp((x.first+y.first)%mod,(x.second+y.second)%mod); }
pii operator-(const pii x,const pii y) { return mp((x.first-y.first+mod)%mod,(x.second-y.second+mod)%mod); }
pii operator*(const pii x,const pii y) { return mp(x.first*y.first%mod,(x.first*y.second%mod+x.second*y.first%mod)%mod);}
pii operator/(const pii x,const pii y)
{
int inv=power(y.first,mod-2);
return mp(x.first*inv%mod,(x.second*y.first%mod-x.first*y.second%mod+mod)%mod*inv%mod*inv%mod);
}
void pre(int MX)
{
phi[1]=1;
for(int i=2; i<=MX; i++)
{
if(!vis[i])
prime[++cnt]=i,phi[i]=i-1;
for(int j=1; (j<=cnt && i*prime[j]<=MX); j++) {
vis[i*prime[j]]=1;
if(i%prime[j] == 0)
{
phi[i*prime[j]]=phi[i]*prime[j];
break;
} else
phi[i*prime[j]]=phi[i]*(prime[j]-1);
}
}
}
void chk(int val) {
for(int i=1; i<=sqrt(val); i++) {
if(val%i == 0)
sum[i]++,sum[val/i]++;
if(i*i == val)
sum[i]--;
}
}
int calc()
{
pii res=mp(1,0);
for(int i=1; i<=n; i++) {
pii inv=mp(1,0)/f[i][i];
for(int j=1; j<=n; j++) {
if(i == j)
continue;
pii div=f[j][i]*inv;
for(int k=1; k<=n; k++) f[j][k]=f[j][k]-div*f[i][k];
}
}
for(int i=1; i<=n-1; i++) res=res*f[i][i];
return res.second;
}
int solve(int tp) {
for(int i=1; i<=n; i++)
for(int j=1; j<=n; j++) f[i][j]=mp(0,0);
for(int i=1; i<=m; i++) {
if(w[i]%tp)
continue;
f[x[i]][x[i]]=f[x[i]][x[i]]+mp(1,w[i]);
f[y[i]][y[i]]=f[y[i]][y[i]]+mp(1,w[i]);
f[x[i]][y[i]]=f[x[i]][y[i]]-mp(1,w[i]);
f[y[i]][x[i]]=f[y[i]][x[i]]-mp(1,w[i]);
}
return calc();
}
signed main()
{
scanf("%lld%lld",&n,&m);
for(int i=1; i<=m; i++) scanf("%lld%lld%lld",&x[i],&y[i],&w[i]),chk(w[i]),mx=max(mx,w[i]);
pre(mx);
for(int i=1; i<=mx; i++)
{
if(sum[i]<n-1)
continue;
ans=(ans+phi[i]*solve(i)%mod)%mod;
}
printf("%lld",ans);
}