Operating on the Tree
题目描述:
此问题是由问题G(Operating on a Graph )启发的。 因此,您需要阅读它的声明才能解决此问题。Operating on a Graph题目描述+题解
您将得到一棵具有
n
n
n个顶点的树。 假设
p
p
p是从
0
0
0到
n
−
1
n-1
n−1的排列。 我们定义函数
f
(
p
)
f(p)
f(p)如下:假设给定的树是问题G的输入图,而
p
p
p是输入运算符序列。
f
(
p
)
f(p)
f(p)是满足条件的操作数:执行第
i
i
i个操作时,至少有一个顶点属于
O
i
O_i
Oi组。令
S
S
S为从
0
0
0到
n
−
1
n-1
n−1的所有可能排列的集合。 请计算(
∑
\sum
∑ p
∈
\in
∈ S
f
(
p
)
f(p)
f(p))
m
o
d
\ mod
mod
998244353
998244353
998244353
输入描述:
第一行包含一个整数 t t t ( ( ( 1 1 1 ≤ \le ≤ t ≤ \le ≤ 500 500 500 ) ) ) 表示测试用例的数量,每个测试包含两行。 第一行包含一个整数n,代表给定树中的顶点数 ( ( ( 1 1 1 ≤ \le ≤ n n n ≤ \le ≤ 2000 2000 2000 ) ) )。 第二行包含 n − 1 n-1 n−1个非负整数 a 1 a_1 a1, a 2 a_2 a2, … \ldots …, a n − 1 a_ {n-1} an−1。 它表示树的第 i i i个边缘连接顶点 i i i和 a i a_i ai ( ( ( a i a_i ai < < < i i i ) ) )所有测试用例的 n n n之和不超过 2000 2000 2000
输出描述:
对于每个测试,输出一行,其中包含一个表示答案的整数,范围是 [ 0 , 998244352 ] [0,998244352] [0,998244352]。
样例输入:
3
4
0 1 2
4
0 1 1
2
0
样例输出:
48
60
2
思路:
树形DP
根据题意,我们可以知道
1.没有两个好点是相邻的
2.每个坏点都至少与一个比他大的好点相邻
看到这里,聪明的你一定已经想到用什么方法解着道题了吧
那就是树形DP
dp数组开三维:dp[MAXN][3][MAXN]:
第一维表示当前节点,第二维表示成功/失败/尚未失败,第三维表示子儿子中有几个成功数
这样我们就可以分三类讨论:
树根是好点,坏点但尚未有比它大的好点相邻,坏点已有比它大的好点相邻
在此基础上我们又可以分三类讨论:
树根是好点,坏点但尚未有比它大的好点相邻,坏点已有比它大的好点相邻
这样就要分九种情况分类讨论 啊啊啊啊好烦!
大体就是这样的,具体细节详见代码注释
AC Code:
#include<bits/stdc++.h>
using namespace std;
const int MAXN=2e3+5;
const int mod=998244353;
vector<int> e[MAXN];
int comb[MAXN][MAXN],sz[MAXN],dp[MAXN][3][MAXN],dp1[MAXN][3][MAXN],tmp[3][MAXN],tmp1[3][MAXN];
void add(int &u,int v)
{
u+=v;
u-=(u>=mod?mod:0);
}
void dfs(int u)
{
sz[u]=dp[u][0][0]=dp[u][2][0]=1;
for(int vec=0,v;vec<e[u].size();vec++)
{
v=e[u][vec];
dfs(v);
for(int i=0;i<3;i++)
for(int j=1;j<sz[v];j++)
{
add(dp[v][i][j],dp[v][i][j-1]);
add(dp1[v][i][j],dp1[v][i][j-1]);
}//更新dp的值为前缀和,便于后续计算。注:此时dp含义已经变化,dp[i][sta][j]变成了最多j个节点比i大时的情况
for(int i=0;i<sz[u];i++)//枚举v之前的子树中,比x大的方案数
for(int j=0;j<=sz[v];j++)
{//枚举v子树中,比x大的方案数
int coe=1ll*comb[i+j][i]*comb[sz[u]-1-i+sz[v]-j][sz[v]-j]%mod;
//恰好共有i+j个节点比u大,且其中j个节点属于v子树的方案数
//=比u大的i+j个点,有i个点是v之前的子树中的方案数 * 比x小的sz[u]-1-i+sz[v]-j个点中,有sz[v]-j个点在v子树中的方案。
for(int type1=0;type1<3;type1++)
for(int type2=0;type2<3;type2++)
{
// v节点在u之前的情况,即比u大的i+j个节点中,最多有j-1个节点属于v节点的情况
int cnt=j?dp[v][type2][j-1]:0;
//v子树中最多有j-1个点比v大的方案数
int cnt1=j?dp1[v][type2][j-1]:0;
//v子树中,最多有j-1个点比v大时的v子树的贡献
int coe1=1ll*coe*dp[u][type1][i]%mod*cnt%mod;
//v节点比u节大的方案数
int base=coe*(1ll*dp[u][type1][i]*cnt1%mod+1ll*dp1[u][type1][i]*cnt%mod)%mod;
//v节点比u节点大时,u节点和v子树的贡献
if(!type1)
{
if(type2==1)
{
add(tmp[0][i+j],coe1);
add(tmp1[0][i+j],base);
}//u好和v坏的状态更新到u好
}
else if(type1==1)
{
if(!type2||type2==1)
{
add(tmp[1][i+j],coe1);
add(tmp1[1][i+j],base);
}//u坏和v好/坏的状态,更新到u坏
}
else if(type1==2)
{
if(!type2)
{
add(tmp[1][i+j],coe1);
add(tmp1[1][i+j],base);//u半坏和v好的状态,更新到u坏
}
else if(type2==1)
{
add(tmp[2][i+j],coe1);
add(tmp1[2][i+j],base);
}//u半坏和v坏的状态,更新到u半坏
}//v节点在u之后的情况,即比u大的i+j个节点中,至少有j个节点属于v节点的情况,与上一种情况对立
cnt=dp[v][type2][sz[v]-1]-cnt;
cnt+=cnt<0? mod:0;
//v子树中至少有j个节点比v大的方案数
cnt1=dp1[v][type2][sz[v]-1]-cnt1;
cnt1+=cnt1<0? mod:0;
//v子树中,至少有j个点比v大时的v子树的贡献
coe1=1ll*coe*dp[u][type1][i]%mod*cnt%mod;
//v节点比u节小的方案数
base=coe*(1ll*dp[u][type1][i]*cnt1%mod+1ll*dp1[u][type1][i]*cnt%mod)%mod;
//v节点比u节点小时,u节点和v子树的贡献
if(!type1)
{
if(type2==1||type2==2)
{
add(tmp[0][i+j],coe1);
add(tmp1[0][i+j],base);
}//u好和v坏/半坏,更新到u好
}
else if(type1==1)
{
if(!type2||type2==1)
{
add(tmp[1][i+j],coe1);
add(tmp1[1][i+j],base);
}//u坏和v好/坏,更新到u坏
}
else if(type1==2)
{
if(!type2||type2==1)
{
//u半坏和v好/坏的状态,更新到u半坏。
//因为v是在u之后,所以在u半坏是因为u的父亲在u之前导致的,u之后是允许v好的
add(tmp[2][i+j],coe1);
add(tmp1[2][i+j],base);
}
}
}
}
sz[u]+=sz[v];
for(int i=0;i<sz[u];i++)
for(int j=0;j<3;j++)
{
dp[u][j][i]=tmp[j][i];
dp1[u][j][i]=tmp1[j][i];
tmp[j][i]=tmp1[j][i]=0;
}//拷贝到dp上
}
for(int i=0;i<sz[u];i++)
add(dp1[u][0][i],dp[u][0][i]);//加上u自己的贡献
}
int n,t,ans;
int main()
{
for(int i=0;i<MAXN;i++)
{
comb[i][0]=1;
for(int j=1;j<=i;j++)
comb[i][j]=(comb[i-1][j-1]+comb[i-1][j])%mod;//杨辉三角算组合数
}
scanf("%d",&t);
while(t--)
{
scanf("%d",&n);
for(int i=0;i<=n;i++)
{
e[i].clear();
memset(dp[i],0,sizeof(dp[i]));
memset(dp1[i],0,sizeof(dp1[i]));
}
for(int i=2,fa;i<=n;i++)
{
scanf("%d",&fa);
fa++;e[fa].push_back(i);
}
dfs(1);
ans=0;
for(int i=0;i<n;i++)
{
add(ans,dp1[1][0][i]);
add(ans,dp1[1][1][i]);
//取模加法,由于用减法替代除法取模,因此会算得更快
}
printf("%d\n",ans);
}
}