参考博客
http://blog.csdn.net/firenet1/article/details/47445921
http://blog.csdn.net/pibaixinghei/article/details/52783432
有两种方法,一种是计数DP,另一种是概率DP。
计数DP:
应该都能想到dp[i][j]表示以i为根的子树,有j个领导。接下来考虑状态转移。
自己一开始考虑枚举分配方案,就是对dp[i][j],枚举j个领导如何分配给儿子节点,但是这样时间复杂度肯定是不能接受的。
事实上这样的枚举浪费了组合数公式,我们可以考虑将其中组合方面的枚举提取成公式计算出来。
降低时间复杂度的方法有很多,可以考虑交换枚举顺序,改变枚举量,动态改变枚举范围,维护一些值等技巧来降低时间复杂度。
由于各个子树的方案相互独立,因此我们可以用乘法原理,逐个考虑每个儿子,用组合数公式对每一个儿子的所有可能组合计算完,然后抛弃之。
时间复杂度O(n^2),稍加改造就可以O(nk)
概率DP:
我们也可以计算出概率,然后乘以总次数。
概率一般都是小数或者分数,如果用double肯定会有精度问题,如果用分数一定会爆long long。
事实上我们可以用逆元来解决这个问题。
逆元可以解决过程是小数,但是初始和结果都是整数的问题。
时间复杂度O(nk)
计数DP代码
#include<stdio.h>
#include<vector>
using namespace std;
const int maxn = 1010;
const int mod = 1000000007;
int n,k;
vector<int>G[maxn];
int dp[maxn][maxn];
int siz[maxn];
int C[maxn][maxn];
int tp[maxn];
void read()
{
scanf("%d %d",&n,&k);
for(int i=1;i<=n;i++) G[i].clear();
int u,v;
for(int i=1;i<n;i++)
{
scanf("%d %d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
}
void dfs1(int u,int f)
{
siz[u]=1;
for(int i=0;i<(int)G[u].size();i++)
{
int v = G[u][i];
if(v==f) continue;
dfs1(v,u);
siz[u]+=siz[v];
}
}
void dfs2(int u,int f)
{
int e = 1;
dp[u][1]=1;
dp[u][0]=siz[u]-1;
for(int ii=0;ii<(int)G[u].size();ii++)
{
int v = G[u][ii];
if(v==f) continue;
dfs2(v,u);
for(int i=0;i<=e+siz[v];i++)
tp[i]=0;
for(int i=0;i<=siz[v];i++)
dp[v][i]=1ll*dp[v][i]*C[siz[u]-e][siz[v]]%mod;
for(int i=0;i<=siz[v];i++)
for(int j=0;j<=e;j++)
tp[i+j]=(tp[i+j]+1ll*dp[u][j]*dp[v][i])%mod;
e+=siz[v];
for(int i=0;i<=e;i++)
dp[u][i]=tp[i];
}
}
void solve()
{
read();
dfs1(1,0);
dfs2(1,0);
printf("%d\n",dp[1][k]);
}
void init()
{
C[0][0]=1;
for(int i=1;i<maxn;i++)
{
C[i][0]=1;
for(int j=1;j<=i;j++)
C[i][j]=(1ll*C[i-1][j]+C[i-1][j-1])%mod;
}
}
int main()
{
init();
int T;
scanf("%d",&T);
for(int t=1;t<=T;t++)
{
printf("Case #%d: ",t);
solve();
}
return 0;
}
概率DP代码
#include<stdio.h>
#include<vector>
using namespace std;
const int mod = 1000000007;
const int maxn = 1010;
vector<int>G[maxn];
int dp[maxn][maxn];
int mp(int x,int n)
{
int ret=1;
while(n)
{
if(n&1) ret=1ll*ret*x%mod;
x=1ll*x*x%mod;
n>>=1;
}
return ret;
}
int inv[maxn];
int siz[maxn];
int fac[maxn];
int n,k;
void read()
{
scanf("%d %d",&n,&k);
for(int i=1;i<=n;i++) G[i].clear();
int u,v;
for(int i=1;i<n;i++)
{
scanf("%d %d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
}
void dfs(int u,int f)
{
siz[u]=1;
for(int i=0;i<(int)G[u].size();i++)
{
int v = G[u][i];
if(v==f) continue;
dfs(v,u);
siz[u]+=siz[v];
}
}
void solve()
{
read();
dfs(1,0);
dp[0][0]=1;
for(int i=1;i<=n;i++)
{
dp[i][0]=1ll*dp[i-1][0]*(siz[i]-1)%mod*inv[siz[i]]%mod;
for(int j=1;j<=k;j++)
dp[i][j]=(1ll*dp[i-1][j]*(siz[i]-1)%mod*inv[siz[i]]%mod+1ll*dp[i-1][j-1]*inv[siz[i]]%mod)%mod;
}
printf("%d\n",int(1ll*dp[n][k]*fac[n]%mod));
}
void init()
{
fac[0]=1;
for(int i=1;i<maxn;i++)
{
inv[i]=mp(i,mod-2);
fac[i]=1ll*fac[i-1]*i%mod;
}
}
int main()
{
init();
int T;
scanf("%d",&T);
for(int t=1;t<=T;t++)
{
printf("Case #%d: ",t);
solve();
}
return 0;
}