题目链接
点这里
思路来源
wls和dls视频
思路
首先,我们对于这个树上两个结点的关系分为两类,一类是有直接父子或祖先关系的结点对,这样的结点对在DFS序中的顺序是确定的,一定是父亲在前面,儿子在后面,那么这样的结点对,如果是逆序对,一定会出现在每一种DFS序中,所以,这样的逆序对的贡献就是这个树的DFS序的种类数。另外一类就是不具有直接父子或祖先关系的结点对,这样的结点对在每种DFS序中的顺序不固定,但是我们经过观察完全可以发现(我发现不了,杜爹说了我才发现了 ),这样的结点对要不然是逆序对,要不然不是逆序对,概率都是1/2,那么从期望的角度考虑,这样的结点对对于答案的贡献其实就是这个树的DFS序的种类数除以2(实际操作改成逆元,毕竟是在模)。
那么如何求解一颗树的DFS序种类数呢,其实很简单,对于一颗子树u,它的DFS序的种类数就是它的所有子树的方案数的乘积,再乘上u的子树个数的全排列(也就是阶乘),然后向上递推就行。第二个问题,如何求解两种点对的数目。首先我们知道,对于n个点能组合出来的总点对数是
n
∗
(
n
−
1
)
/
2
n*(n-1)/2
n∗(n−1)/2,那么只要求出其中一类的个数,另外一类就知道了,我们选择求解上述第一类的个数,令
d
[
i
]
d[i]
d[i] 表示以
i
i
i 为根节点的子树第一类结点对的数目,很容易观察得到,其实就是以
i
i
i 为根节点的子树中
i
i
i 有多少个孩子or孙子,那么这个自然也可以向上递推。第三个问题,如何计算第一类结点对中逆序对的个数,我们令
m
x
[
i
]
mx[i]
mx[i] 表示从结点
i
i
i 出发往上遍历直到根节点,有多少个比它大的结点数,这个在DFS的过程中在值域上建立树状数组就可以统计。这样这题就结束了,还有不懂的看代码吧。
#include <bits/stdc++.h>
#define ll long long
#define mod 1000000007
using namespace std;
const int N = 3e5 + 10;
int n, root;
vector<int> g[N];
ll fact[N], dp[N], mx[N], d[N], deg[N], c[N];
ll qmi(ll a, ll b, ll p)
{
ll res = 1;
a %= p;
while (b)
{
if (b & 1)
res = res * a % p;
b >>= 1;
a = a * a % p;
}
return res;
}
int lowbit(int x)
{
return x & (-x);
}
void add(int x, int val)
{
for (int i = x; i <= n; i += lowbit(i))
c[i] += val;
}
ll query(int x)
{
ll res = 0;
for (int i = x; i; i -= lowbit(i))
res += c[i];
return res;
}
void dfs(int u, int fa)
{
mx[u] = query(n) - query(u);
add(u, 1);
dp[u] = 1;
for (auto &x : g[u])
if (x != fa)
{
deg[u]++;
dfs(x, u);
dp[u] = dp[u] * dp[x] % mod;
d[u] += d[x];
}
d[u] += deg[u];
dp[u] = dp[u] * fact[deg[u]] % mod;
add(u, -1);
}
ll cal(ll x)
{
return x * (x - 1) % mod * qmi(2, mod - 2, mod) % mod;
}
void solve()
{
scanf("%d%d", &n, &root);
for (int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
fact[0] = 1;
for (ll i = 1; i <= n; i++)
fact[i] = fact[i - 1] * i % mod;
dfs(root, 0);
ll tmp = cal(n);
for (int i = 1; i <= n; i++)
tmp -= d[i], tmp %= mod;
// cout << tmp << '\n';
ll res = tmp * qmi(2, mod - 2, mod) % mod;
for (int i = 1; i <= n; i++)
res += mx[i], res %= mod;
res = res * dp[root] % mod;
cout << (res + mod) % mod;
}
int main()
{
int _ = 1;
// cin >> _;
while (_--)
{
solve();
}
return 0;
}