Description
给出一棵 n n n个节点的树,再给出 m m m只僵尸的位置 x i x_i xi以及能力 h i h_i hi,第 i i i条边可以等概率建立起高度为 [ l i , r i ] [l_i,r_i] [li,ri]的围墙,第 i i i只僵尸可以越过任何小于 h i h_i hi的围墙,问至少存在一个位置安全(即没有僵尸可以到达这个位置)的概率
Input
第一行一整数 T T T表示用例组数,每组用例首先输入两个整数 n , m n,m n,m表示点数和僵尸数量,之后 n − 1 n-1 n−1行每行输入四个整数 u , v , l i , r i u,v,l_i,r_i u,v,li,ri表示第 i i i条树边为 u ↔ v u\leftrightarrow v u↔v,可以建立围墙高度范围为 [ l i , r i ] [l_i,r_i] [li,ri],最后 m m m行每行两个整数 x i , h i x_i,h_i xi,hi表示第 i i i个点处有一个能力为 h i h_i hi的僵尸
( 1 ≤ T ≤ 5 , 1 ≤ n , m ≤ 2000 , 1 ≤ l i ≤ r i ≤ 1 0 9 , 1 ≤ h i ≤ 1 0 9 ) (1\le T\le 5,1\le n,m\le 2000,1\le l_i\le r_i\le 10^9,1\le h_i\le 10^9) (1≤T≤5,1≤n,m≤2000,1≤li≤ri≤109,1≤hi≤109)
Output
输出至少存在一个位置安全的概率,结果模 998244353 998244353 998244353
Sample Input
2
4 2
1 2 1 2
2 3 1 2
1 4 1 2
1 2
3 2
5 2
1 2 1 10
2 3 2 9
1 4 3 12
2 5 4 6
1 7
5 5
Sample Output
374341633
888437475
Solution
考虑所有位置都不安全的方案数,把僵尸按能力升序排,以 d p [ u ] [ i ] dp[u][i] dp[u][i]表示以 u u u为根的子树全部不安全,且 u u u子树中可以到达 u u u的最强僵尸编号为 i i i,以此考虑 u u u的子树,对于当前考虑的儿子 v v v,假设 u , v u,v u,v边被 i i i僵尸通过的方案数为 x i x_i xi,不被通过的方案数为 y i y_i yi,那么有三种情况:
1. v v v被 i i i干掉了, d p [ u ] [ i ] + = d p [ u ] [ i ] ⋅ d p [ v ] [ i ] ⋅ x i dp[u][i]+=dp[u][i]\cdot dp[v][i]\cdot x_i dp[u][i]+=dp[u][i]⋅dp[v][i]⋅xi
2. v v v被其子树内弱于 i i i的僵尸干掉,此时显然 i i i僵尸不会在 v v v子树中,为使干掉 v v v的最强僵尸不超过 i i i, u , v u,v u,v之间需要阻碍 i i i僵尸通过,故有 d p [ u ] [ i ] + = d p [ u ] [ i ] ⋅ y i ⋅ ∑ j < i d p [ v ] [ j ] dp[u][i]+=dp[u][i]\cdot y_i\cdot \sum\limits_{j<i}dp[v][j] dp[u][i]+=dp[u][i]⋅yi⋅j<i∑dp[v][j]
3. v v v被其子树内强于 i i i的僵尸干掉,此时不能让这些强于 i i i的僵尸通过 u , v u,v u,v之间的边去干掉 u u u,故 u , v u,v u,v之间需要阻碍这些更强的僵尸通过,故有 d p [ u ] [ i ] + = d p [ u ] [ i ] ⋅ ∑ j ≥ i d p [ v ] [ j ] ⋅ y j dp[u][i]+=dp[u][i]\cdot \sum\limits_{j\ge i}dp[v][j]\cdot y_j dp[u][i]+=dp[u][i]⋅j≥i∑dp[v][j]⋅yj
前缀和优化一下,第二部分从弱僵尸到强僵尸考虑,第三部分从强僵尸到弱僵尸考虑即可,答案即为
1
−
∑
i
=
1
m
d
p
[
1
]
[
i
]
∏
i
=
1
n
−
1
(
r
i
−
l
i
+
1
)
1-\frac{\sum\limits_{i=1}^mdp[1][i]}{\prod\limits_{i=1}^{n-1}(r_i-l_i+1)}
1−i=1∏n−1(ri−li+1)i=1∑mdp[1][i]
时间复杂度
O
(
n
m
)
O(nm)
O(nm)
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
typedef pair<int,int>P;
#define maxn 2005
#define mod 998244353
int mul(int x,int y)
{
ll z=1ll*x*y;
return z-z/mod*mod;
}
int add(int x,int y)
{
x+=y;
if(x>=mod)x-=mod;
return x;
}
int Pow(int x,int y)
{
int ans=1;
while(y)
{
if(y&1)ans=mul(ans,x);
x=mul(x,x);
y>>=1;
}
return ans;
}
#define id second
#define val first
P a[maxn];
vector<P>g[maxn];
int T,n,m,l[maxn],r[maxn],dp[maxn][maxn],vis[maxn][maxn],temp[maxn];
void dfs(int u,int fa)
{
int pos=1;
for(int i=1;i<=m;i++)
if(a[i].id==u)
{
vis[u][i]=1;
pos=i;
}
else vis[u][i]=0;
for(int i=pos;i<=m;i++)dp[u][i]=1;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i].first,t=g[u][i].second;
if(v==fa)continue;
dfs(v,u);
for(int j=1;j<=m;j++)temp[j]=dp[u][j],dp[u][j]=0;
int res=0;
for(int j=1;j<=m;j++)
{
int x=max(0,min(a[j].val-1,r[t])-l[t]+1),y=r[t]-l[t]+1-x;
dp[u][j]=add(dp[u][j],mul(x,mul(temp[j],dp[v][j])));
if(vis[v][j])res=add(res,dp[v][j]);
else dp[u][j]=add(dp[u][j],mul(mul(y,res),temp[j]));
}
res=0;
for(int j=m;j>=1;j--)
{
int x=max(0,min(a[j].val-1,r[t])-l[t]+1),y=r[t]-l[t]+1-x;
if(vis[v][j])res=add(res,mul(dp[v][j],y));
else dp[u][j]=add(dp[u][j],mul(res,temp[j]));
}
for(int j=1;j<=m;j++)vis[u][j]|=vis[v][j];
}
}
int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)g[i].clear();
memset(dp,0,sizeof(dp));
int ans=1;
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d%d%d",&u,&v,&l[i],&r[i]);
ans=mul(ans,r[i]-l[i]+1);
g[u].push_back(P(v,i)),g[v].push_back(P(u,i));
}
for(int i=1;i<=m;i++)scanf("%d%d",&a[i].id,&a[i].val);
sort(a+1,a+m+1);
dfs(1,0);
int res=0;
for(int i=1;i<=m;i++)res=add(res,dp[1][i]);
res=add(ans,mod-res);
printf("%d\n",mul(res,Pow(ans,mod-2)));
}
return 0;
}