题目大意: 给出一棵树,在上面选定 m m m 条直上直下的链,要求每条链上都有至少一条边的权值为 1 1 1,边权可以为 0 , 1 0,1 0,1,问有多少种给边集选定权值的方案。
题解
一开始容易往容斥的方向想,套上树剖大概就有 32 32 32 分。然后就发现不太能优化……
考虑 dp \text{dp} dp,令 f i , j f_{i,j} fi,j 表示 i i i 往上第一条权值为 1 1 1 的边的深度为 j j j 时,子树的方案数。那么对于每条限制 x , y x,y x,y,相当于令 y y y 的 f y , p = 0 ( p ∈ [ d e p x , n ] ) f_{y,p}=0~(p\in[dep_x,n]) fy,p=0 (p∈[depx,n]),不妨记 l i m i lim_i limi 表示 i i i 往上最远的权值为 1 1 1 的边至多深度为多浅。
转移时,考虑自己与儿子之间的边为
1
1
1 还是
0
0
0,可以得到:
(
j
∈
[
l
i
m
i
,
d
e
p
i
]
)
f
i
,
j
=
∏
y
∈
s
o
n
(
f
y
,
j
+
f
y
,
d
e
p
y
)
(j\in[lim_i,dep_i])~f_{i,j}=\prod_{y\in son}(f_{y,j}+f_{y,dep_y})
(j∈[limi,depi]) fi,j=y∈son∏(fy,j+fy,depy)
这个可以用线段树合并来优化,贡献前令儿子的所有 f y , j f_{y,j} fy,j 加上 f y , d e p y f_{y,dep_y} fy,depy,贡献时恰好是让自己的第 j j j 位贡献父亲的第 j j j 位,这是完全对应的,用线段树维护第二维,维护区间加,区间乘即可。
代码如下:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define maxn 500010
#define mod 998244353
int n,m;
struct edge{int y,next;}e[maxn<<1];
int first[maxn],len=0;
void buildroad(int x,int y){e[++len]=(edge){y,first[x]};first[x]=len;}
int fa[maxn],dep[maxn];
void dfs(int x){
for(int i=first[x],y;i;i=e[i].next)
if((y=e[i].y)!=fa[x])fa[y]=x,dep[y]=dep[x]+1,dfs(y);
}
int lim[maxn];
struct node{
int lazy_add,lazy_mul;node *zuo,*you;
node():lazy_add(0),lazy_mul(1),zuo(NULL),you(NULL){}
void update(int Add,int Mul){
lazy_add=(1ll*lazy_add*Mul%mod+Add)%mod;
lazy_mul=1ll*lazy_mul*Mul%mod;
}
void pushdown(){
if(!zuo)zuo=new node();
if(!you)you=new node();
zuo->update(lazy_add,lazy_mul);
you->update(lazy_add,lazy_mul);
lazy_add=0,lazy_mul=1;
}
}*rt[maxn];
void change_add(node *now,int l,int r,int x,int y,int z){
if(l==x&&r==y)return now->update(z,1);
int mid=l+r>>1;now->pushdown();
if(y<=mid)change_add(now->zuo,l,mid,x,y,z);
else if(x>=mid+1)change_add(now->you,mid+1,r,x,y,z);
else change_add(now->zuo,l,mid,x,mid,z),change_add(now->you,mid+1,r,mid+1,y,z);
}
node *merge(node *x,node *y){
if(!x->zuo&&!x->you)swap(x,y);
if(!y->zuo&&!y->you){
x->update(0,y->lazy_add);
return x;
}
x->pushdown();y->pushdown();
x->zuo=merge(x->zuo,y->zuo);
x->you=merge(x->you,y->you);
return x;
}
int ask(node *now,int l,int r,int x){
if(l==r)return now->lazy_add;
int mid=l+r>>1;now->pushdown();
if(x<=mid)return ask(now->zuo,l,mid,x);
else return ask(now->you,mid+1,r,x);
}
void dp(int x){
rt[x]=new node();
change_add(rt[x],0,n,lim[x],dep[x],1);
for(int i=first[x];i;i=e[i].next){
int y=e[i].y;if(y==fa[x])continue;
dp(y);rt[x]=merge(rt[x],rt[y]);
}
if(x>1)change_add(rt[x],0,n,0,n,ask(rt[x],0,n,dep[x]));
}
int main()
{
scanf("%d",&n);
for(int i=1,x,y;i<n;i++){
scanf("%d %d",&x,&y);
buildroad(x,y);buildroad(y,x);
}
dfs(1);
scanf("%d",&m);
for(int i=1,x,y;i<=m;i++){
scanf("%d %d",&x,&y);
lim[y]=max(lim[y],dep[x]+1);
}
dp(1);printf("%d",ask(rt[1],0,n,0));
}