2019 ICPC 徐州网络赛 J.Random Access Iterator
题目大意:给你n个点和n-1条边(树形结构),保证1为根节点,通过以下方式dfs遍历:
询问dfs到最深节点的概率(有多个最深节点则任意一个即可),答案对1e9+7取模。
解法:比赛的时候最后一个小时开的这道概率题,最后10分钟AC了。看上去有些困难,其实就是一个dp的过程。先一遍dfs找到每个节点的深度和孩子数,然后令最深节点的dp值为1.(dp[i]表示dfs到i节点之后成功dfs到最深点的概率)接下来从最下面开始更新dp: 对于任意一个节点u,设它有x个孩子,分别的dp值是dp[v1],dp[v2],dp[v3]…dp[vx];有规则可知dfs到u节点后,会重复进行x次dfs,直接求x次后成功的概率比较困难,考虑反面,求x次全部失败的概率。x次事件独立,所以求一次失败概率再x次方即可。一次失败很好求,就是(1-dp[v1])*(1-dp[v2])…(1-dp[vx])/x。
下面是AC代码:
#include <bits/stdc++.h>
using namespace std;
using namespace chrono;
const int N = 2000005;
const int M = 1000000007;
const int INF = 0x3f3f3f3f;
const double PI = acos(-1);
const double eps = 1e-8;
#define ms(x, y) memset((x), (y), sizeof(x))
#define mc(x, y) memcpy((x), (y), sizeof(y))
typedef long long ll;
typedef unsigned long long ull;
#define fi first
#define se second
#define mp make_pair
typedef pair<int, int> pii;
typedef pair<ll, int> pli;
#define bg begin
#define ed end
#define pb push_back
#define al(x) (x).bg(), (x).ed()
#define st(x) sort(al(x))
#define un(x) (x).erase(unique(al(x)), (x).ed())
#define fd(x, y) (lower_bound(al(x), (y)) - (x).bg() + 1)
#define ls(x) ((x) << 1)
#define rs(x) (ls(x) | 1)
template <class T>
bool read(T & x) {
char c;
while (!isdigit(c = getchar()) && c != '-' && c != EOF);
if (c == EOF) return false;
T flag = 1;
if (c == '-') { flag = -1; x = 0; } else x = c - '0';
while (isdigit(c = getchar())) x = x * 10 + c - '0';
x *= flag;
return true;
}
template <class T, class ...R>
bool read(T & a, R & ...b) {
if (!read(a)) return false;
return read(b...);
}
mt19937 gen(steady_clock::now().time_since_epoch().count());
struct edge { int to, next; } e[N];
int head[N], cnt = 0, sz[N], dep[N], leaf[N];
ll dp[N];
ll qpow(ll a, ll n, ll p) {
ll r = 1;
for (a %= p; n; n >>= 1, (a *= a) %= p)
if (n & 1) (r *= a) %= p;
return r;
}
void add(int u, int v) {
e[cnt] = {v, head[u]};
head[u] = cnt++;
}
void dfs(int d, int u, int p) {
dep[u] = d;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to;
if (v == p) continue;
sz[u]++;
dfs(d + 1, v, u);
}
if (sz[u] == 0) leaf[u] = 1;
}
void dfs2(int u, int p) {
if (leaf[u]) return;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to;
if (v == p) continue;
dfs2(v, u);
dp[u] = (dp[u] + (1 - dp[v] + M) % M) % M;
}
dp[u] = dp[u] * qpow(sz[u], M - 2, M) % M;
dp[u] = qpow(dp[u], sz[u], M);
// if (u == 1)
// cout << "check: " << dp[u] << ' ' << sz[u] << endl;
dp[u] = (1 - dp[u] + M) % M;
// if (u == 1)
// cout << "check: " << dp[u] << ' ' << sz[u] << endl;
}
int main()
{
time_point<steady_clock> start = steady_clock::now();
int size = 128 << 20;
char * p = (char *)malloc(size) + size;
#if (defined _WIN64) or (defined __unix)
__asm__("movq %0, %%rsp\n" :: "r"(p));
#else
__asm__("movl %0, %%esp\n" :: "r"(p));
#endif
// cout << 39 * qpow(64, M - 2, M) % M << endl;
// cout << 25 * qpow(64, M - 2, M) % M << endl;
int n, u, v;
read(n);
ms(head, -1);
for (int i = 1; i < n; i++) {
read(u, v);
add(u, v);
add(v, u);
}
dfs(1, 1, 0);
int mx = 0;
for (int i = 1; i <= n; i++)
if (leaf[i]) mx = max(mx, dep[i]);
for (int i = 1; i <= n; i++) {
if (leaf[i]) {
if (dep[i] == mx) {
dp[i] = 1;
} else dp[i] = 0;
}
}
dfs2(1, 0);
// cout << "------------------" << endl;
// for (int i = 1; i <= n; i++) cout << i << ' ' << dp[i] << endl;
// cout << "------------------" << endl;
printf("%lld\n", dp[1]);
cerr << endl << "------------------------------" << endl << "Time: "
<< duration<double, milli>(steady_clock::now() - start).count()
<< " ms." << endl;
exit(0);
}