题目描述
YJC最近在学习树的有关知识。今天,他遇到了这么一个概念:最近公共祖先。对于有根树T的两个结点u、v,最近公共祖先LCA(T,u,v)表示一个结点x,满足x是u、v的祖先且x的深度尽可能大。YJC很聪明,他很快就学会了如何求最近公共祖先。他现在想寻找最近公共祖先有什么性质,于是他提出了这样的一个问题:n层的满k叉树T,求对于每一对(i,j)(1≤i,j≤T的点数),LCA(T,i,j)的深度的和是多少。这个数字n层的满k叉树指一棵带标号的有根树,深度为i( 0≤i<n )的点有k^i个,所有深度≠n-1的点都有k个孩子。YJC发现他不会做了,于是他来问你这个问题的答案。这个答案可能很大,你只需要告诉他答案%998244353的值就可以了。
推式子
以下层数均从0开始编号。
设ans[n]表示n层满k叉树的答案。
设size[n]表示n层满k叉树的树大小。
0层有k^0个,1层有k^1……n层有k^n个。
s=∑ni=0ki
ks=∑n+1i=1ki
s=kn+1−1k−1
得到了size的一般公式。
ans[n+1]=(ans[n]+size[n]2)∗k
什么意思呢?原本你有一个n层的,首先复制k份,然后再添加一个父亲变成n+1层的。对于一颗n层的而言,其原本贡献为ans[n],加上一个父亲后,原来两两lca的深度都增加了1,总共增加了size[n]^2,那么对于k颗而言就是(ans[n]+size[n]^2)*k。而对应那些lca为根的,因为根的深度为0,所以可以不讨论。
这个式子好像还不能直接用矩阵乘法算?
我们设size2[n]=size[n]^2
size2[n+1]=size2[n]+size2[n+1]−size2[n]
size2[n+1]=size2[n]+k2n+4+k2n+2−2∗kn+2+2∗kn+1(k−1)2
那么我们再设
a[n]=k2n+4,b[n]=k2n+2,c[n]=2∗kn+2,d[n]=2∗kn+1
这样ans、size2、a、b、c、d都存在一次递推式,而且是可以矩阵乘法的,于是上吧!
#include<cstdio>
#include<algorithm>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
typedef long long ll;
const int mo=998244353;
int a[7][7],o[7][7],dis[7][7],b[7][7],c[7][7],sta[100];
int i,j,k,l,t,n,m,top;
int quicksortmi(int x,int y){
if (!y) return 1;
int t=quicksortmi(x,y/2);
t=(ll)t*t%mo;
if (y%2) t=(ll)t*x%mo;
return t;
}
int main(){
freopen("lca.in","r",stdin);freopen("lca.out","w",stdout);
scanf("%d%d",&n,&k);
n--;
t=quicksortmi((ll)(k-1)*(k-1)%mo,mo-2);
a[1][1]=0;a[1][2]=1;
a[1][3]=(ll)k*k%mo*k%mo*k%mo;a[1][4]=(ll)k*k%mo;
a[1][5]=(ll)2*k%mo*k%mo;a[1][6]=(ll)2*k%mo;
b[1][1]=b[2][1]=b[5][5]=b[6][6]=k;
b[3][3]=b[4][4]=(ll)k*k%mo;
b[3][2]=b[6][2]=t;
b[4][2]=b[5][2]=-t;
b[2][2]=1;
fo(i,1,6) dis[i][i]=1;
while (n){
sta[++top]=n%2;
n/=2;
}
while (top){
fo(i,1,6)
fo(j,1,6)
o[i][j]=0;
fo(l,1,6)
fo(i,1,6)
fo(j,1,6)
o[i][j]=(o[i][j]+(ll)dis[i][l]*dis[l][j]%mo)%mo;
fo(i,1,6)
fo(j,1,6)
dis[i][j]=o[i][j];
if (sta[top]){
fo(i,1,6)
fo(j,1,6)
o[i][j]=0;
fo(l,1,6)
fo(i,1,6)
fo(j,1,6)
o[i][j]=(o[i][j]+(ll)dis[i][l]*b[l][j]%mo)%mo;
fo(i,1,6)
fo(j,1,6)
dis[i][j]=o[i][j];
}
top--;
}
fo(i,1,6)
fo(j,1,6)
o[i][j]=0;
fo(l,1,6)
fo(i,1,6)
fo(j,1,6)
o[i][j]=(o[i][j]+(ll)a[i][l]*dis[l][j]%mo)%mo;
fo(i,1,6)
fo(j,1,6)
c[i][j]=o[i][j];
(c[1][1]+=mo)%=mo;
printf("%d\n",c[1][1]);
}