雅礼集训DAY5T3

16 篇文章 0 订阅

题意 : n个点的树,每个点有一个为0或1的权值,等概率选择一个点作为起点,然后等概率选择点v,走到点v,将v的权值异或1,当所有点的权值相等时停止,求路径长度的期望值。

 

根据期望的线性性, 我们考虑每一个点对答案的贡献.
每次选择了一个点之后, 如果没有结束, 那么下一步期望的移动距离就是这个点到其他所有点
的距离和除以 n.
容易发现树的形态并不影响点的期望被选择次数. 只要 0 和 1 的个数一定, 所有权值为相同的
点的期望选择次数是相同的.
设 val i,j 表示 i 个 1 的时候权值为 j 的点的期望选择次数, 有下列方程
val[i,0] = 1/n + i / n * val[i-1,0] + (n - i - 1)/n * val [i+1,0] + 1/n*val[i+1,1] (1)
val[i,1] = 1/n + (i - 1)/n * val[i-1,1] + 1 / n * val[i-1,0] + (n - i)/n * val[i+1,1] (2)
其中 val[0,∗] , val[n,∗] val[1,1] val[n-1,0] 是边界情况.
移项之后可以发现根据 val[i,∗] 可以推出 val[i+1,1], 根据 val[i,∗] 和 val[i+1,1] 可以推出 val[i+1,0] 因此
可以将每一个 valk;∗(1 ≤ k < n) 表示成 aval1;0 + bval1;1 + c 的形式, 最后根据在 n - 1 的时候的
(1), (2) 两式列出一个二元一次方程, 解之即可.
最后答案就是每个点的期望选择次数乘上它每一步的期望移动步数. 时间复杂度 O(n).

存方程系数可以用结构体,方便重载运算符。

std代码:

#include <bits/stdc++.h>

#define REP(i, a, b) for (int i = (a), _end_ = (b); i < _end_; ++i)
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define mp make_pair
#define x first
#define y second
#define pb push_back
#define SZ(x) (int((x).size()))
#define ALL(x) (x).begin(), (x).end()

template<typename T> inline bool chkmin(T &a, const T &b) { return a > b ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, const T &b) { return a < b ? a = b, 1 : 0; }

typedef long long LL;

const int MOD = 1e9 + 7;
const int oo = 0x3f3f3f3f;
const int maxn = 100000;

int fpm(LL b, int e, int m)
{
	b %= m;
	LL t = 1;
	for ( ; e; e >>= 1, (b *= b) %= m)
		if (e & 1) (t *= b) %= m;
	return t;
}

struct data {
	int x, y, z;

	data() { }
	data(int _z): x(0), y(0), z(_z) { }
	data(int _x, int _y, int _z): x(_x), y(_y), z(_z) { }

	friend data operator+(const data &x, const data &y) { return data((x.x + y.x) % MOD, (x.y + y.y) % MOD, (x.z + y.z) % MOD); }
	friend data operator-(const data &x, const data &y) { return data((x.x - y.x) % MOD, (x.y - y.y) % MOD, (x.z - y.z) % MOD); }
	friend data operator*(const data &x, const int &y) { return data((LL)x.x * y % MOD, (LL)x.y * y % MOD, (LL)x.z * y % MOD); }
};

int n;
int w[maxn + 5];
int fa[maxn + 5];
int val[maxn + 5][2];
int sum_down[maxn + 5], sum[maxn + 5];
int sz[maxn + 5];
std::vector<int> children[maxn + 5];
int inv[maxn + 5];

void calc()
{
	REP(i, 1, n + 1) inv[i] = fpm(i, MOD - 2, MOD);
	static data coe[maxn + 5][2];
	coe[0][0] = coe[0][1] = coe[n][0] = coe[n][1] = data(0, 0, 0);
	coe[1][0] = data(1, 0, 0), coe[1][1] = data(0, 1, 0);
	REP(i, 1, n - 1) {
		coe[i + 1][1] = (coe[i][1] - data((i != 1) * inv[n]) - coe[i - 1][1] * ((LL)(i - 1) * inv[n] % MOD) - coe[i - 1][0] * inv[n]) * ((LL)n * inv[n - i] % MOD);
		coe[i + 1][0] = (coe[i][0] - data(inv[n]) - coe[i + 1][1] * inv[n] - coe[i - 1][0] * ((LL)i * inv[n] % MOD)) * ((LL)n * inv[n - i - 1] % MOD);
	}
	data val0 = coe[n - 1][0] - coe[n - 2][0] * ((LL)(n - 1) * inv[n] % MOD);
	data val1 = coe[n - 1][1] - inv[n] - coe[n - 2][1] * ((LL)(n - 2) * inv[n] % MOD) - coe[n - 2][0] * inv[n];
	int det = ((LL)val0.x * val1.y - (LL)val0.y * val1.x) % MOD;
	assert(det);
	det = fpm(det, MOD - 2, MOD);
	int valx = ((LL)-val0.z * val1.y - (LL)val0.y * -val1.z) % MOD * det % MOD;
	int valy = ((LL)val0.x * -val1.z - (LL)-val0.z * val1.x) % MOD * det % MOD;
	REP(i, 0, n + 1) REP(j, 0, 2) val[i][j] = ((LL)valx * coe[i][j].x + (LL)valy * coe[i][j].y + coe[i][j].z) % MOD;
	REP(i, 0, n + 1) REP(j, 0, 2) (val[i][j] += MOD) %= MOD;
}

int main()
{
	freopen("c.in", "r", stdin);
	freopen("c.out", "w", stdout);

	scanf("%d", &n);

    assert(n >= 3);

	calc();

	static char s[maxn + 5];
	scanf("%s", s);
	int cnt = 0;
	REP(i, 0, n) {
        w[i] = s[i] == '1';
        cnt += w[i];
    }
	fa[0] = -1;
	REP(i, 1, n) {
        scanf("%d", fa + i);
        --fa[i];
        children[fa[i]].pb(i);
    }

	for (int i = n - 1; i >= 0; --i) {
		++sz[i];
        if (fa[i] >= 0)
            sz[fa[i]] += sz[i];
		REP(k, 0, SZ(children[i])) {
			int j = children[i][k];
			(sum_down[i] += sum_down[j] + sz[j]) %= MOD;
		}
	}
	memcpy(sum, sum_down, sizeof sum);
	REP(i, 0, n)
		REP(k, 0, SZ(children[i])) {
			int j = children[i][k];
			(sum[j] += (sum[i] - sum_down[j] + n - (sz[j] << 1)) % MOD) %= MOD;
		}

	int ans = 0;
	REP(i, 0, n) (ans += sum[i]) %= MOD;
	ans = (LL)ans * fpm(n, MOD - 2, MOD) % MOD;
	REP(i, 0, n) (ans += (LL)val[cnt][w[i]] * sum[i] % MOD) %= MOD;
	ans = (LL)ans * fpm(n, MOD - 2, MOD) % MOD;
	(ans += MOD) %= MOD;

	printf("%d\n", ans);
}

AC Code:

#include<bits/stdc++.h>
#define maxn 100005
#define LL long long
#define mod 1000000007
using namespace std;

int Pow(int b,int k){ int r=1;for(;k;k>>=1,b=1ll*b*b%mod) if(k&1) r=1ll*r*b%mod; return r;}

int n,inv[maxn];
char s[maxn];
vector<int>G[maxn];
int sm[maxn],sz[maxn],rsm[maxn],g[maxn][2];

struct data{
	LL a,b,c;
	data(LL a=0,LL b=0,LL c=0):a(a),b(b),c(c){}
	data operator +(const data &B)const{ return data(a+B.a,b+B.b,c+B.c); }
	data operator -(const data &B)const{ return data(a-B.a,b-B.b,c-B.c); }
	data operator -(const int &B)const{ return data(a,b,c-B); }
	data operator +(const int &B)const{ return data(a,b,c+B); }
	data operator *(const int &B)const{ return data(a*B%mod,b*B%mod,c*B%mod); }
	void mt(){ a %= mod , b %= mod , c %= mod; }
}f[maxn][2];

void dfs(int u,int ff){
	sz[u] = 1;
	for(int v:G[u]) if(v!=ff)
		dfs(v,u),sz[u]+=sz[v],sm[u]=(1ll*sm[u]+sm[v]+sz[v]) % mod;
}

void dfs2(int u,int ff){
	for(int v:G[u]) if(v!=ff)
		rsm[v] = (1ll * rsm[u] + sm[u] - sm[v] - sz[v] + n - sz[v]) % mod,
		dfs2(v,u);
}

int main(){
	freopen("C.in","r",stdin);
	
	scanf("%d",&n);
	scanf("%s",s+1);
	inv[0] = inv[1] = 1;
	for(int i=2,f;i<=n;i++){
		scanf("%d",&f);
		G[i].push_back(f),G[f].push_back(i);
		inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
	}
	dfs(1,0);dfs2(1,0);
	f[0][0] = f[0][1] = f[n][0] = f[n][1] = data(0,0,0);
	f[1][0] = data(1,0,0) , f[1][1] = data(0,1,0);
	for(int i=1;i<n-1;i++){
		f[i+1][1] = (f[i][1] * n - data(0,0,(i!=1)) - f[i-1][1] * (i-1) - f[i-1][0]) * inv[n-i];
		f[i+1][1].mt();
		f[i+1][0] = (f[i][0] * n - data(0,0,1) - f[i-1][0] * i - f[i+1][1]) * inv[n-i-1];
		f[i+1][0].mt();
	}
	data A = f[n-1][0] - f[n-2][0] * ((n-1ll) * inv[n] % mod) ,
		 B = f[n-1][1] - inv[n] - f[n-2][1] * (n-2) * inv[n] - f[n-2][0] * inv[n];
	A.mt(),B.mt();
	int iv = Pow((1ll * A.a * B.b - 1ll * B.a * A.b) % mod , mod - 2) , b = (1ll * B.a * A.c - 1ll * A.a * B.c) % mod * iv % mod,
		a = (1ll * A.b * B.c - 1ll * B.b * A.c) % mod * iv % mod;
	for(int i=1;i<n;i++)
		g[i][0] = (1ll * f[i][0].a * a + 1ll * f[i][0].b * b + f[i][0].c) % mod,
		g[i][1] = (1ll * f[i][1].a * a + 1ll * f[i][1].b * b + f[i][1].c) % mod;
	int ct = 0;
	for(int i=1;i<=n;i++) if(s[i] == '1') ct++;
	int ans = 0;
	for(int i=1;i<=n;i++) 
		ans = (ans + 1ll * sm[i] + rsm[i]) % mod;
	ans = 1ll * ans * inv[n] % mod;
	for(int i=1;i<=n;i++)
		ans = (ans + 1ll * g[ct][s[i] - '0'] * (sm[i] + rsm[i])) % mod;
	ans = 1ll * ans * inv[n] % mod;
 	printf("%d\n",(ans+mod)%mod);
}

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值