We are given a rooted tree of n vertices. The vertices are to be labeled with the numbers 1, 2, ..., n so that each label is unique and the heap condition holds, i.e. the label of any vertex is less than the label of its parent. How many such labellings exist? Since this number may be quite large, calculate only its remainder modulo m.
Input
The input contains several tree descriptions. The first line contains the number of input trees t (t ≤ 250). Each tree description begins with a line containing the size of the tree n (1 ≤ n ≤ 500000) and an integer m (2 ≤ m ≤ 10^9). n - 1 lines follow, the i-th of which contains p(i + 1), the number of the parent of the (i + 1)-th vertex (1 ≤ p(i + 1) ≤ i). Vertex number 1 will be the root in each tree, so its parent will not be given. Total size of the input will not exceed 50MB.
Output
For each tree output the number of its valid labellings modulo the given m.
Explanation for sample: The 8 possible labellings from the last example test case are as follows:
Sample Input
4 3 1000000 1 1 4 1000000 1 1 1 5 1000000 1 2 3 4 5 1000000 1 1 3 3
Sample Output
2 6 1 8
思路:明显的树形dp,我们用dp[u]表示在以u为根的子树上放(1,2,...,size[u])的方案数。怎么转移?考虑一下假设各个子树的方案数都求出来了,那么为它们加一个根结点会怎样?按题意根结点的权值必须是最大的;由对称性(把每个标号x换成size+1-x),这与根结点取最小值1的方案数相同,下面按根为1来计数。各个子树在剩下的数中选一些数按照之前求出来的方案放就行了。所以只需要将C(rest,son[v]) * dp[v]乘起来就行了,其中rest是还没有分配出去的数的个数,son[v]是v的子树大小。这题麻烦的只是模的数是会变的,而且不一定是素数,所以不能直接用逆元;但是我们可以根据m的素因子分开来处理,即把运算的数分成与m互素的部分和不互素的部分,互素的部分可以直接(用逆元)计算,而不互素的部分我们直接带入指数来运算就行了。应该有更好的办法:别人都是1s以内过的,我的跑了很久。
代码:
#include <iostream>
#include <vector>
#include <algorithm>
#include <string.h>
#include <cstring>
#include <stdio.h>
#include <cmath>
#include <math.h>
// Loop helpers: rep runs i upward over the half-open range [a, b);
// rrep runs i downward from b to a, both ends included.
#define rep(i,a,b) for(int i=(a);i<(b);++i)
#define rrep(i,b,a) for(int i = (b); i >= (a); --i)
// memset over the whole object a; sizeof(a) only covers the full storage when a is a real array, not a pointer.
#define clr(a,x) memset(a,(x),sizeof(a))
#define LL long long
// NOTE(review): eps is not referenced anywhere in the visible code.
#define eps 1e-10
using namespace std;
// Capacity of every per-vertex array below (vertices are numbered from 1).
const int maxn = 500000 + 5;
// Current test case: number of vertices and the modulus; both are read in main.
int n,m;
// pe[0..tot) holds the primes found by pre_init, in increasing order.
int pe[maxn],tot;
// Sieve flags filled by pre_init: isp[i] is true when i is prime.
bool isp[maxn];
// Reads the next non-negative decimal integer from stdin into x.
//
// Every character that is not a digit is skipped before the number starts.
// If the input ends before any digit is found, x is set to 0 and the
// function returns instead of spinning forever on EOF.
void read_int(int & x)
{
    x = 0;
    // getchar() returns an int so that EOF can be told apart from every
    // valid character; keeping it in a char loses that distinction.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while ('0' <= ch && ch <= '9') {
        x = 10 * x + (ch - '0');
        ch = getchar();
    }
}
// Modular exponentiation by repeated squaring: returns base^p mod `mod`.
//
// base  any 64-bit value; it is reduced into [0, mod) first so the
//       squarings below cannot overflow for large or negative bases.
// p     exponent; values <= 0 are treated as 0 (the result is 1 % mod).
// mod   positive modulus, assumed small enough that (mod-1)^2 fits in a
//       signed 64-bit integer (about 3e9).
long long qpow(long long base,long long p,long long mod)
{
    long long ret = 1 % mod;   // yields 0 when mod == 1
    base %= mod;
    if (base < 0) base += mod;
    while (p > 0) {
        if (p & 1) ret = ret * base % mod;
        base = base * base % mod;
        p >>= 1;
    }
    return ret;
}
// One entry of a singly linked child list: the child vertex and the next
// entry belonging to the same parent.
struct Node
{
    int v;
    Node * next;
};
// first[u] is the head of u's child list (null when u has no children).
Node * first[maxn];
// Pool that list entries are taken from; add() hands them out in order.
Node edges[maxn<<1];
// Index of the most recently used pool slot; slot 0 is never used.
int ptr;
void add(int u,int v)
{
edges[++ptr].v = v;
edges[ptr].next = first[u];
first[u] = &edges[ptr];
}
void input()
{
ptr = 0; clr(first,0);
rep(v,2,n+1) {
int u; //scanf("%d",&u);
read_int(u);
add(u,v);
}
}
void pre_init()
{
rep(i,2,maxn) isp[i] = true;
rep(i,2,maxn) if (isp[i]) {
pe[tot++] = i;
for(int j = i + i; j < maxn; j += i)
isp[j] = false;
}
}
// Distinct prime factors of the current modulus m; filled by solve().
vector<int> pfac;

// A number kept in factored form relative to the modulus m:
//   value = a * product over i of pfac[i]^sum[i]
// where a is the part coprime to m, stored reduced modulo m. Exponents may
// be negative while a quotient is being built; solve() only expands the
// final value, using the exponents stored in sum[].
struct DiyVal
{
    int sum[10];    // exponent of pfac[i]; 10 slots covers every m <= 1e9 (at most 9 distinct primes)
    long long a;    // coprime part, modulo m
    DiyVal()
    {
        memset(sum, 0, sizeof(sum));
        a = 1;
    }
    // In-place product. Returns *this: the original was declared to return a
    // value but had no return statement, which is undefined behaviour.
    DiyVal & operator *= (const DiyVal& dv)
    {
        const int cnt = (int)pfac.size();
        for (int i = 0; i < cnt; ++i) sum[i] += dv.sum[i];
        a = a * dv.a % m;
        return *this;
    }
    // Product as a new value; neither operand is modified.
    DiyVal operator * (const DiyVal & dv) const
    {
        DiyVal ret = *this;
        ret *= dv;
        return ret;
    }
}A[maxn],rev[maxn],dp[maxn];
// A[i]   : i! in factored form (filled by solve()).
// rev[i] : inverse of A[i], i.e. negated exponents and inverse coprime part.
// dp[u]  : number of labellings of the subtree rooted at u.
// Binomial coefficient "n choose m" in factored form, built from the
// factorial table A and its inverse table rev: n! * (m!)^-1 * ((n-m)!)^-1.
// The parameters shadow the globals n and m inside this function only.
DiyVal C(int n,int m)
{
    DiyVal unit;
    return unit * A[n] * rev[m] * rev[n - m];
}
int q[maxn];
int son[maxn];
void bfs()
{
int l = 0 ,r = 0;
q[r++] = 1;
while (l < r) {
int u = q[l++];
for(Node * p = first[u]; p ; p = p->next) {
int v = p->v;
q[r++] = v;
}
}
}
void solve()
{
pfac.clear();
LL x = m;
int phi = m;
rep(i,0,tot) {
LL v = pe[i];
if (v * v > x) break;
if (x % v == 0) {
phi = phi / v * (v - 1);
pfac.push_back(v);
do x /= v;
while (x % v == 0);
}
}
if (x > 1) {
pfac.push_back(x);
phi = phi / x * (x - 1);
}
rep(i,1,n+1) {
int x = i;
A[i] = A[i-1];
rep(j,0,pfac.size()) {
int y = pfac[j];
while (x % y == 0) {
++A[i].sum[j];
x /= y;
}
}
A[i].a = A[i].a * x % m;
rev[i].a = qpow(A[i].a,phi-1,m);
rep(j,0,pfac.size()) rev[i].sum[j] = -A[i].sum[j];
}
bfs();
int S;
rrep(i,n-1,0) {
int u = q[i];
dp[u] = DiyVal();
S = 0;
for(Node * p = first[u]; p ; p = p->next) {
int v = p->v;
S += son[v];
}
son[u] = S+1;
for(Node * p = first[u] ; p ; p = p->next) {
int v = p->v;
dp[u] = dp[u] * dp[v] * C(S,son[v]);
S -= son[v];
}
}
LL ans = dp[1].a;
rep(i,0,pfac.size()) {
ans = ans * qpow(pfac[i],dp[1].sum[i],m) % m;
}
printf("%lld\n",ans);
}
// Test-data generator: redirects stdout to in.txt and writes one test case,
// a star with 500000 vertices (every vertex 2..n has parent 1) and modulus n + 7.
void Getinput()
{
    freopen("in.txt","w",stdout);
    n = 500000;
    m = n + 7;
    printf("1\n%d %d\n",n,m);
    int remaining = n - 1;
    while (remaining-- > 0) printf("1\n");
}
int main()
{
//Getinput();return 0;
#ifdef ACM
freopen("in.txt", "r", stdin);
// freopen("out.txt","w",stdout);
#endif // ACM
pre_init();
int T; cin >> T;
while (T--) {
read_int(n); read_int(m);
input();
solve();
}
}