题意:boss和employee的关系可以构成一棵树,树上每个节点是其子树的boss,和其ascent的employee。现在发送invitation,要求boss必须比其employee先收到,问有多少种发送invitaion的方式。
结果要对1e7取模,一看就是排列组合推公式。
推公式也是从树的递推关系入手,假设三个子树a,b,c有同一个parent,那么a,b,c之间是不会影响的。令其孩子节点个数分别是n1,n2,n3,如果总共有n个位置去放这三个子树(n=n1+n2+n3),那么就是n中选n1个位置先放a,剩下的n-n1个位置中选n2个放b,剩下的n-n1-n2个位子中放c。如果子树a是叶子,返回1。
递推公式是f(parent)=f(a)*f(b)*f(c)*C(n,n1)*C(n-n1,n2)*C(n-n1-n2,n3)。
如果输入是dis-connected,相当于加一个virtual parent,将输入中的几棵树作为并列的子树处理。
这一题必须O(NlogN)才能过。所以需要快速求组合数。按照杨氏三角形预处理会超时,正解是把阶乘先存起来,最后用逆元和快速幂求除法。
#include <bits/stdc++.h>
using namespace std;
string ltrim(const string &);
string rtrim(const string &);
vector<string> split(const string &);
/*
* Complete the 'invitations' function below.
*
* The function is expected to return an INTEGER.
* The function accepts following parameters:
* 1. INTEGER n
* 2. 2D_INTEGER_ARRAY pairs
*/
const int maxn=200010;
const int mod=1e9+7;
vector<vector<int> >mp;
long long save[maxn];
//map<pair<int,int>,long long >combmp;
map<pair<int,int>,int >indexmp;
int parent[maxn];//# of parents
int children[maxn];//# of children
long long Pow(long long a,long long b)
{
long long s=1;
long long t=1;
while(b)
{
if(b&t)
{
s=(s*a)%mod;
}
a=(a*a)%mod;
b=b>>1;
}
return s;
}
long long fac[maxn];
void getfac()
{
fac[0]=1;
for (int i=1; i<maxn; i++)
fac[i]=(fac[i-1]*i)%mod;
}
long long inv(long long a) {
return Pow(a, mod - 2);
}
long long Comb(int n,int m)
{
if(n==m)
{
return 1;
}
if(m==0)
{
return 1;
}
if (n<m)
{
return 0;
}
return ((fac[n]*inv(fac[m]))%mod*inv(fac[n-m])%mod)%mod;
}
void init()
{
for(int i=0;i<maxn;i++)
{
mp.push_back(vector<int>());
}
getfac();
// memset(arr,false,sizeof(arr));
// prim=produce_prim_number();
// combmp.clear();
// indexmp.clear();
// memset(comb,0,sizeof(comb));
}
int calc_child(int root)
{
if(mp[root].size()==0)
{
children[root]=0;
return children[root];
}
for(int i=0;i<mp[root].size();i++)
{
children[root]+=1+calc_child(mp[root][i]);
// cout<<"root "<<root<<" child "<<mp[root][i]<<" num "<<children[mp[root][i]]<<endl;
}
return children[root];
}
// long long save[maxn][maxn];
long long dfs(int root)
{
if(save[root]!=0)
{
return save[root];
}
if(mp[root].size()==0)
{
save[root]=1;
return 1;
}
long long ret=0;
long long tmp=1;
vector<int>presum=vector<int>();
for(int j=mp[root].size()-1;j>=0;j--)
{
if(j==mp[root].size()-1)
{
presum.push_back(children[mp[root][j]]+1);
}
else
{
// cout<<"add "<<children[mp[root][j]]<<" "<<presum[mp[root].size()-1-j-1]<<endl;
presum.push_back(children[mp[root][j]]+1+presum[mp[root].size()-1-j-1]);
}
}
// for(int j=0;j<presum.size();j++)
// {
// cout<<"root "<<root<<" presum "<<presum[j]<<endl;
// }
for(int j=0;j<mp[root].size();j++)
{
// cout<<"root "<<root<<" "<<mp[root][j]<<" "<<presum[mp[root].size()-1-j]<<endl;
// tmp*=Combination(presum[mp[root].size()-1-j],1+children[mp[root][j]])%mod;
tmp*=Comb(presum[mp[root].size()-1-j],1+children[mp[root][j]])%mod;
tmp%=mod;
}
for(int j=0;j<mp[root].size();j++)
{
tmp*=dfs(mp[root][j])%mod;
tmp%=mod;
// cout<<"root "<<root<<" remain slot "<<i<<" child "<<mp[root][j]<<" prod "<<<<endl;
}
ret+=tmp;
ret%=mod;
save[root]=ret;
// cout<<"root "<<root<<" ret "<<ret<<endl;
return ret;
}
int invitations(int n, vector<vector<int>> pairs) {
long long ans=1;
memset(parent,0,sizeof(parent));
memset(children,0,sizeof(children));
memset(save,0,sizeof(save));
for(int i=0;i<mp.size();i++)
{
mp[i].clear();
}
// mp.clear();
// for(int i=0;i<=n;i++)
// {
// mp.push_back(vector<int>());
// }
for(int i=0;i<pairs.size();i++)
{
parent[pairs[i][1]]=1;
// children[vector[i][0]]++;
mp[pairs[i][0]].push_back(pairs[i][1]);
}
for(int i=1;i<=n;i++)
{
if(parent[i]==1)
{
continue;
}
calc_child(i);
}
int tot=n;
for(int i=1;i<=n;i++)
{
// cout<<"children "<<i<<" "<<children[i]<<endl;
if(parent[i]==1)//may be multiple roots
{
continue;
}
//connected_num++;
ans=((ans%mod)*dfs(i))%mod;
// cout<<"add comb "<<tot<<" "<<children[i]+1<<endl;
ans*=Comb(tot,children[i]+1)%mod;
ans%=mod;
tot-=children[i]+1;
}
ans%=mod;
// cout<<"ans "<<ans<<endl;
return ans;
}
int main()
{
init();
ofstream fout(getenv("OUTPUT_PATH"));
string tc_temp;
getline(cin, tc_temp);
int tc = stoi(ltrim(rtrim(tc_temp)));
for (int tc_itr = 0; tc_itr < tc; tc_itr++) {
string first_multiple_input_temp;
getline(cin, first_multiple_input_temp);
vector<string> first_multiple_input = split(rtrim(first_multiple_input_temp));
int n = stoi(first_multiple_input[0]);
int m = stoi(first_multiple_input[1]);
vector<vector<int>> pairs(m);
for (int i = 0; i < m; i++) {
pairs[i].resize(2);
string pairs_row_temp_temp;
getline(cin, pairs_row_temp_temp);
vector<string> pairs_row_temp = split(rtrim(pairs_row_temp_temp));
for (int j = 0; j < 2; j++) {
int pairs_row_item = stoi(pairs_row_temp[j]);
pairs[i][j] = pairs_row_item;
}
}
int result = invitations(n, pairs);
fout << result << "\n";
}
fout.close();
return 0;
}
string ltrim(const string &str) {
string s(str);
s.erase(
s.begin(),
find_if(s.begin(), s.end(), not1(ptr_fun<int, int>(isspace)))
);
return s;
}
string rtrim(const string &str) {
string s(str);
s.erase(
find_if(s.rbegin(), s.rend(), not1(ptr_fun<int, int>(isspace))).base(),
s.end()
);
return s;
}
vector<string> split(const string &str) {
vector<string> tokens;
string::size_type start = 0;
string::size_type end = 0;
while ((end = str.find(" ", start)) != string::npos) {
tokens.push_back(str.substr(start, end - start));
start = end + 1;
}
tokens.push_back(str.substr(start));
return tokens;
}