题面
因为每个人只有一个要求,假如A要求在B左边,A就向B连一条边的话,就是一个带环树。
当然,有环直接无解。所以剩下的是森林。
我们把子树看成一个子问题,假如知道了子树内部的答案如何转移到父亲。设当前转移X,X的所有子树都应该安排在X左边,也就是首先有size[X]-1个空来安排,每安排一个子树剩下的空就减一些。安排第一个儿子C(size[X]-1,size[son1])* ans[son1],第二个(size[X]-size[son1]-1,size[son2])* ans[son2]等等。
为了处理方便,把所有没有要求的人的父亲设为一个虚根。直接计算改根的答案即可。
其实推出式子答案就是:n!/Πsize[i]{1<=i<=n}。
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
const int maxn=200010;
int n,m,mod;
ll jie[maxn],inv[maxn],sz[maxn];
bool vis[maxn],rt[maxn];
struct edge
{
int t;
edge *next;
}*con[maxn];
void ins(int x,int y)
{
edge *p=new edge;
p->t=y;
p->next=con[x];
con[x]=p;
}
int read()
{
int x=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') {x=x*10+ch-'0';ch=getchar();}
return x;
}
ll ksm(ll a,int b){ll r=1; for(;b;b>>=1){if(b&1) r=r*a%mod; a=a*a%mod;} return r;}
ll C(int n,int m){if(n<m) return 0; return jie[n]*inv[m]%mod*inv[n-m]%mod;}
void dfs(int v)
{
vis[v]=1;sz[v]=1;
for(edge *p=con[v];p;p=p->next)
{
dfs(p->t);
sz[v]+=sz[p->t];
}
}
ll dp(int v)
{
ll re=1;int rst=sz[v]-1;
for(edge *p=con[v];p;p=p->next)
{
re=re*C(rst,sz[p->t])%mod*dp(p->t)%mod;
rst-=sz[p->t];
}
return re;
}
int main()
{
int ca;
ca=read();
while(ca--)
{
for(int i=0;i<=n;i++)
con[i]=NULL;
memset(vis,0,sizeof(vis));
memset(rt,0,sizeof(rt));
n=read();m=read();mod=read();
jie[0]=inv[0]=1;
for(int i=1;i<=n;i++)
{
jie[i]=jie[i-1]*i%mod;
inv[i]=ksm(jie[i],mod-2);
}
for(int i=1;i<=m;i++)
{
int x=read(),y=read();
ins(y,x);
rt[x]=1;
}
for(int i=1;i<=n;i++)
if(!rt[i]) ins(0,i);
dfs(0);
bool pd=0;
for(int i=1;i<=n;i++) if(!vis[i]) {pd=1;break;}
if(pd) puts("0");
else printf("%lld\n",dp(0));
}
return 0;
}
推公式代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
const int maxn=200010;
int n,m,mod,tot;
ll jie[maxn],inv[maxn],sz[maxn];
bool vis[maxn],rt[maxn];
struct edge
{
int t;
edge *next;
}*con[maxn];
void ins(int x,int y)
{
edge *p=new edge;
p->t=y;
p->next=con[x];
con[x]=p;
}
int read()
{
int x=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') {x=x*10+ch-'0';ch=getchar();}
return x;
}
ll ksm(ll a,int b){ll r=1; for(;b;b>>=1){if(b&1) r=r*a%mod; a=a*a%mod;} return r;}
void dfs(int v)
{
tot+=1;
vis[v]=1;sz[v]=1;
for(edge *p=con[v];p;p=p->next)
{
dfs(p->t);
sz[v]+=sz[p->t];
}
}
int main()
{
int ca;
ca=read();
while(ca--)
{
for(int i=0;i<=n;i++)
con[i]=NULL;
memset(rt,0,sizeof(rt));
n=read();m=read();mod=read();
jie[0]=inv[0]=1;
for(int i=1;i<=n;i++)
jie[i]=jie[i-1]*i%mod;
for(int i=1;i<=m;i++)
{
int x=read(),y=read();
ins(y,x);
rt[x]=1;
}
for(int i=1;i<=n;i++)
if(!rt[i]) ins(0,i);
tot=0;
dfs(0);
if(tot<=n) puts("0");
else
{
ll ans=1;
for(int i=1;i<=n;i++) ans=ans*sz[i]%mod;
ans=ksm(ans,mod-2)*jie[n]%mod;
printf("%lld\n",ans);
}
}
return 0;
}