题面
题解
首先题目的要求肯定要转化。
有多种转化方法,例如:(其中 s i s_i si 代表以 i i i 为根节点的子树大小)
E ( w ( T ) ) = E ( ∑ i = 1 n ∑ j = 1 n dist T ( i , j ) ) = E ( ∑ i = 2 n s i ( n − s i ) ) = E ( ∑ i = 2 n n s i − s i 2 ) = n ∑ i = 2 n E ( s i ) − ∑ i = 2 n E ( s i 2 ) \begin{aligned} \Epsilon(w(T))=&\Epsilon\left(\sum_{i=1}^{n}\sum_{j=1}^{n}\operatorname{dist}_T(i,j)\right)\\ =&\Epsilon\left(\sum_{i=2}^n s_i\left(n-s_i\right)\right)\\ =&\Epsilon\left(\sum_{i=2}^n ns_i-s_i^2\right)\\ =&n\sum_{i=2}^n\Epsilon(s_i)-\sum_{i=2}^n\Epsilon(s_i^2) \end{aligned} E(w(T))====E(i=1∑nj=1∑ndistT(i,j))E(i=2∑nsi(n−si))E(i=2∑nnsi−si2)ni=2∑nE(si)−i=2∑nE(si2)
这种做法实际上是对于每一条边统计这条边的贡献。
注意 E ( s i 2 ) ≠ E ( s i ) 2 \Epsilon(s_i^2)\neq \Epsilon(s_i)^2 E(si2)=E(si)2。而且除了 a , b a,b a,b 相互独立互不干扰的情况以外,都有 E ( a b ) ≠ E ( a ) E ( b ) \Epsilon(ab)\neq \Epsilon(a)\Epsilon(b) E(ab)=E(a)E(b)。
然后你发现 E ( s i 2 ) \Epsilon(s_i^2) E(si2) 并不好求。
(其实是可以求的,而且有人过了,但我不会)
所以我们考虑另一种转化方法:(其中 d i d_i di 代表 i i i 节点的深度)
E ( w ( T ) ) = E ( ∑ i = 1 n ∑ j = 1 n dist T ( i , j ) ) = E ( ∑ i = 1 n ∑ j = 1 n d i + d j − 2 × d lca ( i , j ) ) = ∑ i = 1 n ∑ j = 1 n E ( d i ) + E ( d j ) − 2 ∑ i = 1 n ∑ j = 1 n E ( d lca ( i , j ) ) = 2 ∑ i = 1 n E ( d i ) − 2 ∑ u = 1 n E ( d u ) ∑ i = 1 n ∑ j = 1 n P [ lca ( i , j ) = u ] \begin{aligned} \Epsilon(w(T))=&\Epsilon\left(\sum_{i=1}^{n}\sum_{j=1}^{n}\operatorname{dist}_T(i,j)\right)\\ =&\Epsilon\left(\sum_{i=1}^{n}\sum_{j=1}^{n}d_i+d_j-2\times d_{\operatorname{lca}(i,j)}\right)\\ =&\sum_{i=1}^n\sum_{j=1}^n\Epsilon(d_i)+\Epsilon(d_j)-2\sum_{i=1}^n\sum_{j=1}^n\Epsilon(d_{\operatorname{lca}(i,j)})\\ =&2\sum_{i=1}^n\Epsilon(d_i)-2\sum_{u=1}^n\Epsilon(d_u)\sum_{i=1}^n\sum_{j=1}^nP\big[\operatorname{lca}(i,j)=u\big] \end{aligned} E(w(T))====E(i=1∑nj=1∑ndistT(i,j))E(i=1∑nj=1∑ndi+dj−2×dlca(i,j))i=1∑nj=1∑nE(di)+E(dj)−2i=1∑nj=1∑nE(dlca(i,j))2i=1∑nE(di)−2u=1∑nE(du)i=1∑nj=1∑nP[lca(i,j)=u]
这种做法实际上是对于每一个点对统计贡献。
注意到 E ( d i ) \Epsilon(d_i) E(di) 是可以通过 DP 容易得到的,那么只需考虑 ∑ i = 1 n ∑ j = 1 n P [ lca ( i , j ) = u ] \sum\limits_{i=1}^n\sum\limits_{j=1}^nP\big[\operatorname{lca}(i,j)=u\big] i=1∑nj=1∑nP[lca(i,j)=u]。
设 u u u 的可能的后代的集合为 p o s u pos_u posu,设 v v v 成为 u u u 的后代的概率为 p u , v p_{u,v} pu,v。这些都可以预处理出来。
设 f u = ∑ i = 1 n ∑ j = 1 n P [ lca ( i , j ) = u ] f_u=\sum\limits_{i=1}^n\sum\limits_{j=1}^nP\big[\operatorname{lca}(i,j)=u\big] fu=i=1∑nj=1∑nP[lca(i,j)=u],那么最自然的想法是这样容斥:
f u = E ( s u 2 ) − ∑ v ∈ p o s u p u , v f v f_u=\Epsilon(s_u^2)-\sum_{v\in pos_u}p_{u,v}f_v fu=E(su2)−v∈posu∑pu,vfv
但你发现这样好像还是要把 E ( s u 2 ) \Epsilon(s_u^2) E(su2) 算出来(
怎么办?
有另一种奇怪的计算方式:
考虑按照题目的方法生成两棵树 T 1 T_1 T1 和 T 2 T_2 T2。 T 1 T_1 T1 和 T 2 T_2 T2 无需满足任何关系,它们是相互独立的。
考虑在 T 1 T_1 T1 中选出 u u u 子树内的一个点 x x x、在 T 2 T_2 T2 中选出 u u u 子树内的一个点 y y y 的 ( x , y ) (x,y) (x,y) 的方案数的期望,显然为 E ( s u ) 2 \Epsilon(s_u)^2 E(su)2。
设 g u g_u gu 表示:
在 T 1 T_1 T1 中选出 u u u 子树内的一个点 x x x、在 T 2 T_2 T2 中选出 u u u 子树内的一个点 y y y、
且需要满足 T 1 T_1 T1 中 x → u x\to u x→u 路径上的点和 T 2 T_2 T2 中 y → u y\to u y→u 路径上的点除了 u u u 之外互不重复
的 ( x , y ) (x,y) (x,y) 的方案数的期望。
那么有转移:
g u = E ( s u ) 2 − ∑ v ∈ p o s u p u , v 2 g v g_u=\Epsilon(s_u)^2-\sum_{v\in pos_u}p_{u,v}^2 g_v gu=E(su)2−v∈posu∑pu,v2gv
其中:
-
E ( s u ) 2 \Epsilon(s_u)^2 E(su)2 表示在 T 1 T_1 T1 中选出 u u u 子树内的一个点 x x x、在 T 2 T_2 T2 中选出 u u u 子树内的一个点 y y y 的 ( x , y ) (x,y) (x,y) 的方案数的期望。显然这样任意选取并不能保证上面说的两条路径上的点除了 u u u 之外不重复,所以要减去一些不合法的方案。
-
∑ v ∈ p o s u \sum\limits_{v\in pos_u} v∈posu∑ 枚举的是 T 1 T_1 T1 中 x → u x\to u x→u 路径上的点和 T 2 T_2 T2 中 y → u y\to u y→u 路径上的点从 v v v 开始重复了。(即 T 1 T_1 T1 中 x → v x\to v x→v 路径上的点和 T 2 T_2 T2 中 y → v y\to v y→v 路径上的点除了 v v v 之外都还没出现重复)
-
p u , v 2 p_{u,v}^2 pu,v2 代表 v v v 在 T 1 T_1 T1、 T 2 T_2 T2 中都是 u u u 的后代的概率,显然这样就能保证 x x x、 y y y 在 T 1 T_1 T1、 T 2 T_2 T2 中都分别是 u u u 的后代。
-
g v g_v gv 代表枚举的这种情况的方案数的期望。
然后你发现 f u f_u fu 和 g u g_u gu 其实是一个东西。
然后求出来即可。
#include<bits/stdc++.h>
#define N 300010
using namespace std;
namespace modular
{
int mod;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
}using namespace modular;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
inline int poww(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
int n,d[N],inv[N],p[N],f[N];
int ans;
vector<int>pre[N];
int main()
{
n=read();
mod=read();
for(int i=1;i<=n;i++)
for(int j=i+i;j<=n;j+=i)
pre[j].push_back(i);
for(int i=2;i<=n;i++)
{
int size=pre[i].size();
for(int j=0;j<size;j++)
d[i]=add(d[i],add(d[pre[i][j]],1));
inv[i]=poww(size,mod-2);
d[i]=mul(d[i],inv[i]);
ans=add(ans,mul(2,mul(n,d[i])));
}
for(int i=n;i>=1;i--)
{
int m=n/i;
p[1]=1;
int sum=p[1];
for(int j=2;j<=m;j++)
{
p[j]=0;
for(int k=0,size=pre[j].size();k<size;k++)
p[j]=add(p[j],p[pre[j][k]]);
p[j]=mul(p[j],inv[i*j]);
sum=add(sum,p[j]);
}
f[i]=mul(sum,sum);
for(int j=2;j<=m;j++)
f[i]=dec(f[i],mul(mul(p[j],p[j]),f[i*j]));
ans=dec(ans,mul(mul(2,d[i]),f[i]));
}
printf("%d\n",ans);
return 0;
}
/*
3 998244353
*/
/*
5 998244353
*/