链接
先考虑一条链上的特殊情形
如果是一条链,而且边的顺序是从后往前,那么我很容易写出 d p dp dp:
定义 f i , 0 / 1 f_{i,0/1} fi,0/1为前 i i i个位置已经决策完了,并且最后一个点有没有被选,的方案数。
转移就是:
f
i
,
0
=
f
i
−
1
,
0
+
f
i
−
1
,
1
f_{i,0}=f_{i-1,0}+f_{i-1,1}
fi,0=fi−1,0+fi−1,1
f
i
,
1
=
f
i
−
1
,
0
f_{i,1}=f_{i-1,0}
fi,1=fi−1,0
再考虑一条直线上的一般情形
如果我给出边的顺序不是按照从前往后,而是乱序的,那么第 i i i条边所面临的决策也会受到后面的影响
看似有后效性,但其实会发现这个题目中决策只会影响到相邻的位置,因此有一个套路就是,多开一维记录后面对前面的影响
f i , 0 / 1 , 0 / 1 f_{i,0/1,0/1} fi,0/1,0/1表示决策第 i i i个位置前, i i i线段的右端点是 0 / 1 0/1 0/1,第 i i i个位置决策后, i i i线段的右端点是 0 / 1 0/1 0/1,的方案数。
由于 i , 1 , 0 i,1,0 i,1,0这个状态无效因此我把它丢掉
f i , 0 , f i , 1 , f i , 2 f_{i,0},f_{i,1},f_{i,2} fi,0,fi,1,fi,2依次表示 ( 0 → 0 ) , ( 0 → 1 ) , ( 1 → 1 ) (0 \rightarrow 0),(0 \rightarrow 1),(1 \rightarrow 1) (0→0),(0→1),(1→1)的方案数
假设第 i i i条边的时间戳是 t i t_i ti
那么:
t
i
>
t
i
+
1
t_i>t_{i+1}
ti>ti+1时,
f
i
,
0
=
f
i
+
1
,
0
+
f
i
+
1
,
1
f_{i,0} = f_{i+1,0} + f_{i+1,1}
fi,0=fi+1,0+fi+1,1
f
i
,
1
=
f
i
+
1
,
0
f_{i,1} = f_{i+1,0}
fi,1=fi+1,0
f
i
,
2
=
f
i
+
1
,
0
f_{i,2} = f_{i+1,0}
fi,2=fi+1,0
t
i
<
t
i
+
1
t_i<t_{i+1}
ti<ti+1时,
f
i
,
0
=
f
i
+
1
,
2
f_{i,0} = f_{i+1,2}
fi,0=fi+1,2
f
i
,
1
=
f
i
+
1
,
0
+
f
i
+
1
,
1
f_{i,1} = f_{i+1,0} + f_{i+1,1}
fi,1=fi+1,0+fi+1,1
f
i
,
2
=
f
i
+
1
,
0
+
f
i
+
1
,
1
f_{i,2} = f_{i+1,0} + f_{i+1,1}
fi,2=fi+1,0+fi+1,1
最后推广到树上
推广到树上之后,就变成自下而上的 d p dp dp,叶子节点的初始状态直接就是 ( 1 , 1 , 1 ) (1,1,1) (1,1,1),而 d p i , 0 / 1 / 2 dp_{i,0/1/2} dpi,0/1/2的含义也变成了“子树中边的决策的方案数”
做法也有一点变化,就是把与 u u u直接相连的边按照时间戳排个序,转移的时候基本思想还是不变,请读者自行瞎搞搞应该就能出来了
代码
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 200010
#define maxe 400010
#define cl(x) memset(x,0,sizeof(x))
#define rep(_,__) for(_=1;_<=(__);_++)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define mod 998244353ll
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
ll read(ll x=0)
{
ll c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
vector<ll> to[maxn];
ll n, dp[maxn][3], p0[maxn], p1[maxn], p2[maxn];
void dfs(ll fa, ll u)
{
ll i, j, sz=to[u].size()-1, cnt;
if(sz==0)
{
dp[u][0]=dp[u][1]=dp[u][2]=1;
return;
}
for(auto v:to[u])if(v!=fa)dfs(u,v);
p0[0]=1;
j=0;
for(auto v:to[u])if(v!=fa)
{
p0[j+1]=p0[j]*dp[v][0]%mod;
p1[j+1]=dp[v][1];
j++;
}
p2[sz+1]=1;
j=sz;
for(auto it=to[u].rbegin();it!=to[u].rend();it++)
{
auto v=*it;
if(v!=fa)
{
p2[j]=p2[j+1]*dp[v][2]%mod;
j--;
}
}
cnt=0;
for(auto v:to[u])if(v!=fa)cnt++;else break;
rep(i,cnt)
(dp[u][0]+=p0[i-1]*p1[i]%mod*p2[i+1])%=mod;
(dp[u][0]+=p0[cnt]*p2[cnt+1])%=mod;
for(i=cnt+1;i<=sz;i++)
(dp[u][1]+=p0[i-1]*p1[i]%mod*p2[i+1])%=mod;
(dp[u][1]+=p0[sz])%=mod;
rep(i,sz)
(dp[u][2]+=p0[i-1]*p1[i]%mod*p2[i+1])%=mod;
(dp[u][2]+=p0[sz])%=mod;
}
int main()
{
ll i;
n=read();
to[1].emb(n+1);
to[n+1].emb(1);
rep(i,n-1)
{
auto u=read(), v=read();
to[u].emb(v);
to[v].emb(u);
}
dfs(n+1,1);
printf("%lld",dp[1][2]);
return 0;
}