题目
Description
提督们惊奇地发现,2019 年实装的改造非常少。
经调查,原来是改造厂的厂长于八日克扣了其他舰娘改造的图纸,并且在 2020 年的第一个月利用这些图纸进行了华丽的改造,一共有三种形态,于八日改二,于八日改二特,于八日改二丁,对空、对陆、
对潜、开幕雷、五装备格,无所不能。镇守府雪菜八万钢惨遭退役。
舰娘的结构可以看成一棵 n 个点的树,点编号为 0 ∼ n−1。使用一张图纸可以把树中的某一条边去掉,再加上一条边,使得它依然是一棵树。
现在,于八日想在 2020 年继续拿走别的舰娘的图纸对自己进行改造,她一共拿走了 k 张图纸。她想知道,自己经过接下来的改造之后,总共会有多少种形态。两个形态不同,表示有一条边 (x,y) 在第一棵
树中出现,在另一棵树中不出现。
Input
第一行两个整数 n,k,表示树的结点数和图纸数。
第二行 n − 1 个整数 fi 描述树的形态,表示编号为 i 的结点父亲为 fi 。
Output
一行一个整数表示答案,%998244353 输出。
Sample Input
样例输入1:
3 1
0 0
样例输入2:
4 1
0 1 2
样例输入3:
6 1
0 1 2 2 0
Sample Output
样例输出1:
3
样例输出2:
8
样例输出3:
28
Data Constraint
对于所有数据,满足 1 ≤ n ≤ 50 , 0 ≤ k ≤ n。
对于 20% 的数据,k = 0。其中测试点 1,k = 0。
对于 20% 的数据 (测试点 1 ∼ 4),fi =0。
对于 30% 的数据 (测试点 1 ∼ 6),n ≤ 5。
思路
生成树计数?上矩阵树定理
定义一条边 (u,v) 的权值:若 (u,v) 在原树中,权值为 1,否则为 x。
我们要求的就是边权之积 ≤xk 的生成树个数。
显然,x0……xk的系数之和即为答案
我们可以用拉格朗日插值法解决
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=155,mod=998244353;
bool bz[N][N];
int a[N][N],inv[N],ans,n,m,c[N],d[N],h[N];
void add(int &x,int y)
{
(x+=y)>=mod&&(x-=mod);
}
void del(int &x,int y)
{
(x-=y)<0&&(x+=mod);
}
int mul(int x,int y)
{
return (ll)x*y%mod;
}
int power(int x,int y)
{
int yjy = 1;
while(y)
{
if(y&1) yjy=mul(yjy,x);
y>>=1; x=mul(x,x);
}
return yjy;
}
int gauss(int n)
{
int yjy=1;
for(int i=1; i<n; i++)
{
int x=0;
for(int j=i; j<=n; j++)
if(a[j][i])
{
x=j; break;
}
if(!x) continue;
if(x!=i) yjy=mod-yjy,swap(a[x],a[i]);
int inv=power(a[i][i],mod-2);
for(int j=i+1; j<=n; j++)
{
int v=mul(a[j][i],inv);
for(int k=i; k<=n; k++) del(a[j][k],mul(a[i][k],v));
}
}
for(int i=1; i<=n; i++) yjy=mul(yjy,a[i][i]);
return yjy;
}
int calc(int bas)
{
memset(a,0,sizeof(a));
for(int i=1; i<n; i++) for(int j=i+1; j<=n; j++)
{
int v=1;
if(!bz[i][j]) v=bas;
add(a[i][i],v); add(a[j][j],v);
del(a[i][j],v); del(a[j][i],v);
}
return gauss(n-1);
}
int main()
{
freopen("kaisou.in","r",stdin); freopen("kaisou.out","w",stdout);
scanf("%d%d",&n,&m);
inv[1]=1;
for(int i=2; i<=n; i++)
{
int x;
scanf("%d",&x); x++;
bz[x][i]=bz[i][x]=1;
inv[i]=mul(mod-mod/i,inv[mod%i]);
}
for(int i=1; i<=n; i++)
{
int y=calc(i);
for(int j=1; j<=n; j++) c[j]=0;
c[0]=1;
for(int j=1; j<=n; j++)
if(j!=i)
{
if(i<j) y=mul(y,mod-inv[j-i]);
else y=mul(y,inv[i-j]);
for(int k=0; k<=n; k++)
{
h[k]=mul(mod-j,c[k]);
if(k) add(h[k],c[k-1]);
}
for(int k=0; k<=n; k++) c[k]=h[k];
}
for(int j=0; j<=n; j++) add(d[j],mul(c[j],y));
}
for(int i=0; i<=m; i++) add(ans,d[i]);
printf("%d\n",ans);
}