题面
题意
给出一棵树,多次询问,每次询问给出一个根节点和k个点,问要求将它们划分为至多m个集合,要求每个集合中不包含存在祖先关系的点,则一共有几种方法。
做法
考虑单次询问怎么做,可以发现一个非常重要的性质,一个点的两个不同祖先不可能在同一个集合中,因此可以先将所有点根据到根节点的距离排序,记
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j]表示前i个集合划分为j个集合的方案数,然后可得dp转移:
d
p
[
i
]
[
j
]
=
d
p
[
i
−
1
]
[
j
−
1
]
+
(
j
−
f
(
i
)
)
∗
d
p
[
i
−
1
]
[
j
]
,
f
(
i
)
dp[i][j]=dp[i-1][j-1]+(j-f(i))*dp[i-1][j],f(i)
dp[i][j]=dp[i−1][j−1]+(j−f(i))∗dp[i−1][j],f(i)表示前
i
−
1
i-1
i−1个点中是点
i
i
i的祖先的数量,显然
f
(
i
)
f(i)
f(i)可以用树状数组+dfs序求出,然后即可dp得出答案。
代码
#include<bits/stdc++.h>
#define ll long long
#define LG 17
#define N 100100
#define M 1000000007
using namespace std;
int n,m,rt,tt,in[N],out[N],num[N],fa[N][20],deep[N],d[N],dp[N][310];
struct Sz
{
int num[N];
vector<int>cle;
inline int lb(int u){return u&(-u);}
inline void add(int u,int v){for(;u<=n;u+=lb(u)) num[u]+=v,cle.push_back(u);}
inline int ask(int u){int res=0;for(;u;u-=lb(u)) res+=num[u];return res;}
inline void clear(){int i;for(i=0;i<cle.size();i++) num[cle[i]]=0;cle.clear();}
}sz;
vector<int>to[N],son[N];
inline bool cmp(int u,int v){return d[u]<d[v];}
void dfs(int now,int last)
{
int i,t;
in[now]=++tt;
for(i=0;i<to[now].size();i++)
{
t=to[now][i];
if(t==last) continue;
son[now].push_back(t);
fa[t][0]=now;
deep[t]=deep[now]+1;
dfs(t,now);
}
out[now]=tt;
}
inline int ts(int u,int v)
{
int l,r,mid;
for(l=0,r=son[u].size();l<r;)
{
mid=((l+r)>>1);
if(in[son[u][mid]]<=in[v]) l=mid+1;
else r=mid;
}
return son[u][l-1];
}
inline int lca(int u,int v)
{
int i,j;
if(deep[u]<deep[v]) swap(u,v);
for(i=LG;deep[u]!=deep[v];i--)
{
if(deep[fa[u][i]]>=deep[v])
u=fa[u][i];
}
if(u==v) return u;
for(i=LG;i>=0;i--)
{
if(fa[u][i]!=fa[v][i])
{
u=fa[u][i];
v=fa[v][i];
}
}
return fa[u][0];
}
inline int dis(int u,int v)
{
int l=lca(u,v);
return deep[u]+deep[v]-2*deep[l];
}
int main()
{
int i,j,p,q;
cin>>n>>m;
for(i=1;i<n;i++)
{
scanf("%d%d",&p,&q);
to[p].push_back(q);
to[q].push_back(p);
}
deep[1]=1;
dfs(1,-1);
for(i=1;i<=LG;i++)
{
for(j=1;j<=n;j++)
{
fa[j][i]=fa[fa[j][i-1]][i-1];
}
}
while(m--)
{
scanf("%d%d%d",&p,&q,&rt);
sz.clear(),dp[0][0]=1;
for(i=1;i<=p;i++) scanf("%d",&num[i]),d[num[i]]=dis(num[i],rt);
sort(num+1,num+p+1,cmp);
for(i=1;i<=p;i++)
{
int as=sz.ask(in[num[i]]);
for(j=as+1;j<=q;j++)
{
dp[i][j]=((ll)dp[i-1][j-1]+(ll)dp[i-1][j]*(j-as)%M)%M;
}
if(num[i]==rt) sz.add(1,1);
else if(in[rt]>=in[num[i]] && in[rt]<=out[num[i]])
{
int t=ts(num[i],rt);
sz.add(1,1),sz.add(in[t],-1);
sz.add(out[t]+1,1);
}
else
{
sz.add(in[num[i]],1);
sz.add(out[num[i]]+1,-1);
}
}
ll ans=0;
for(i=1;i<=q;i++) ans=(ans+dp[p][i])%M;
for(i=0;i<=p;i++) memset(dp[i],0,sizeof(dp[i]));
printf("%lld\n",ans);
}
}