记录每个点可能取到的所有权值及其概率(种数显然不超过子树中的叶子数)。
那么如果动态记录,同一时刻可能用到的空间也不超过子树叶子数。
线段树再多带一个 $\log$,所以就算不做垃圾回收,空间也是可以接受的。
对于一个点,如果取最大值:
对左儿子的每个权值,累加右儿子取值比它小的概率;
对右儿子的每个权值,累加左儿子取值比它小的概率。
取最小值时同理(改为累加另一侧取值比它大的概率),两种情况分别乘以 $p$ 和 $1-p$ 后相加即可。
树高并不一定是 $\log n$ 级别的(我发现我现在看见树,经常下意识地觉得树高是 $\log n$,没救了)。
所以还是要用线段树的(当然不用应该也可以,不过太麻烦了一点)。
然后这道题线段树合并的复杂度是稳定的 $O(n\log n)$。
很显然,因为权值保证互不相同,所以被合并的两棵线段树在叶子上没有重叠部分。
50pts
$O(n^2)$(如果保证数据随机,则是 $O(n\log n)$ 的)
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<cmath>
#include<ctime>
#include<cctype>
using namespace std;
const int MAXN = 300005;          // capacity bound for the per-node arrays
const long long MOD = 998244353;  // modulus for all probability arithmetic
// 4 KiB input buffer backing the fread-based getchar() replacement below;
// frS is the read cursor, frT the end of the valid data.
char frBB[1<<12], *frS=frBB, *frT=frBB;
// Refills the buffer when it is exhausted and yields EOF once fread returns nothing.
#define getchar() (frS==frT&&(frT=(frS=frBB)+fread(frBB,1,1<<12,stdin),frS==frT)?EOF:*frS++)
// Reads one decimal integer from the input stream into x.
// Characters before the first digit are skipped; if a '-' appears among them
// the value is negated. Digits are recognised with an explicit range check
// rather than isdigit(), which is undefined for negative char values. The
// skip loop also stops at EOF instead of spinning forever on truncated input
// (x is left as 0 in that case).
inline void read(int& x)
{
    bool negative = false;
    x = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}
int n;       // number of tree nodes
int tot;     // not referenced by this program
int LstTot;  // last used slot of the Val/Pro pool
int fa[MAXN];         // fa[i]: parent of node i (0 = no parent)
long long val[MAXN];  // leaves: raw weight; internal nodes: t * inv(10000) mod MOD
// Shared pool of (value, probability) entries. Node i's candidates occupy
// [Lef[i], Rig[i]] in ascending Val order.
// NOTE(review): every two-child node appends a fresh range, so the total pool
// use can exceed MAXN*20 on large deep trees - only safe for small inputs.
long long Val[MAXN*20];
long long Pro[MAXN*20];
int Lef[MAXN];
int Rig[MAXN];
int Child[MAXN][2];  // up to two children per node, filled in input order
long long Ans = 0;   // accumulated answer (mod MOD)
// Square-and-multiply exponentiation: returns bas^ind reduced modulo MOD,
// consuming the bits of ind from least to most significant.
long long qpow(long long bas, long long ind)
{
    long long result = 1;
    for (bas %= MOD; ind != 0; ind >>= 1)
    {
        if (ind & 1)
            result = result * bas % MOD;
        bas = bas * bas % MOD;
    }
    return result;
}
// Modular inverse of 10000 (Fermat, MOD is prime): converts an integer input t
// into the probability t / 10000 represented modulo MOD.
const long long inv = qpow(10000, MOD - 2);
// inv^2 mod MOD; not referenced anywhere in this program.
const long long invinv = inv * inv % MOD;
#define LChild(x) Child[x][0]
#define RChild(x) Child[x][1]
// Normalises x into [0, MOD). The argument is parenthesised so that an
// expression such as Adjust(1 - v) reduces the whole difference, not just v.
#define Adjust(x) ((((x) % MOD) + MOD) % MOD)
// Builds, for the subtree rooted at Pos, the ascending list of values the node
// can take together with their probabilities (mod MOD). The list is stored in
// the pool Val/Pro at indices [Lef[Pos], Rig[Pos]].
//  - Leaf: one entry (its weight, probability 1).
//  - One child: reuses the child's range unchanged.
//  - Two children: merges both sorted child ranges into a freshly appended range.
// NOTE(review): recursion depth equals the tree height, and the pool grows by the
// subtree's leaf count at every two-child node - only safe for small inputs.
void dfs(const int& Pos)
{
if (!LChild(Pos))
{
// Leaf: its own weight with probability 1.
Lef[Pos] = Rig[Pos] = ++LstTot;
Val[LstTot] = val[Pos];
Pro[LstTot] = 1;
return;
}
else dfs(LChild(Pos));
if (RChild(Pos)) dfs(RChild(Pos));
else
{
// Single child: the node takes exactly the child's distribution.
Lef[Pos] = Lef[LChild(Pos)];
Rig[Pos] = Rig[LChild(Pos)];
return;
}
// lpt / rpt walk the left / right child's ranges.
// sumL / sumR: total probability of the values already consumed from the
// left / right child, i.e. P(left < v) and P(right < v) for the current value v.
int lpt = Lef[LChild(Pos)], rpt = Lef[RChild(Pos)];
long long sumR = 0, sumL = 0;
Lef[Pos] = LstTot + 1;
// Two-pointer merge in ascending Val order. Equal values would make neither
// branch advance; values are assumed distinct (TODO confirm input guarantee).
while (lpt <= Rig[LChild(Pos)] && rpt <= Rig[RChild(Pos)])
{
if (Val[lpt] < Val[rpt])
{
// Left value v survives with probability
//   p * P(right < v) + (1 - p) * P(right > v),  with p = val[Pos].
Val[Rig[Pos]=++LstTot] = Val[lpt];
Pro[Rig[Pos]] = (val[Pos] * sumR % MOD + Adjust(1 - val[Pos]) * Adjust(1 - sumR) % MOD) % MOD * Pro[lpt] % MOD;
sumL += Pro[lpt];
sumL %= MOD;
++lpt;
}
else if (Val[rpt] < Val[lpt])
{
// Symmetric case for a value coming from the right child.
Val[Rig[Pos]=++LstTot] = Val[rpt];
Pro[Rig[Pos]] = (val[Pos] * sumL % MOD + Adjust(1 - val[Pos]) * Adjust(1 - sumL) % MOD) % MOD * Pro[rpt] % MOD;
sumR += Pro[rpt];
sumR %= MOD;
++rpt;
}
}
// Right child exhausted: sumR now holds the right child's whole probability
// mass (which should total 1), so the "min" term (1 - sumR) is dropped.
while (lpt <= Rig[LChild(Pos)])
{
Val[Rig[Pos]=++LstTot] = Val[lpt];
Pro[Rig[Pos]] = val[Pos] * sumR % MOD * Pro[lpt] % MOD;
++lpt;
}
// Left child exhausted: same reasoning with sumL.
while (rpt <= Rig[RChild(Pos)])
{
Val[Rig[Pos]=++LstTot] = Val[rpt];
Pro[Rig[Pos]] = val[Pos] * sumL % MOD * Pro[rpt] % MOD;
++rpt;
}
}
// Reads the tree (parent of each node, then one integer per node), runs the
// merge DP from node 1 and prints the sum, over the root's ascending candidate
// values, of rank * value * probability^2 (mod MOD).
// The deprecated `register` specifier was dropped: it is ill-formed since C++17.
int main()
{
    read(n);
    // Children are attached in input order: first seen -> left, second -> right.
    for (int i = 1; i <= n; ++i)
    {
        read(fa[i]);
        if (!fa[i]) continue;  // fa[i] == 0: no parent
        if (!LChild(fa[i])) LChild(fa[i]) = i;
        else RChild(fa[i]) = i;
    }
    // Leaves keep their raw weight; internal nodes store t / 10000 mod MOD.
    for (int i = 1; i <= n; ++i)
    {
        int t;
        read(t);
        val[i] = 1ll * t;
        if (LChild(i)) val[i] = val[i] * inv % MOD;
    }
    dfs(1);
    // The root's range is in ascending order, so j is the rank of Val[i].
    long long j = 1;
    for (int i = Lef[1]; i <= Rig[1]; ++i)
    {
        Ans = (Ans + (Val[i] * j % MOD * Pro[i] % MOD * Pro[i] % MOD)) % MOD;
        ++j;
    }
    printf("%lld", Ans);
    return 0;
}
100pts
稳定的 $O(n\log n)$
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<cmath>
#include<ctime>
#include<cctype>
using namespace std;
const int MAXN = 300005;          // capacity bound for the per-node arrays
const long long MOD = 998244353;  // modulus for all probability arithmetic
// Input buffer for the fread-based getchar() below: frS = cursor, frT = end of data.
char frBB[1<<12], *frS=frBB, *frT=frBB;
// Refills the 4 KiB buffer on demand; evaluates to EOF when fread returns nothing.
#define getchar() (frS==frT&&(frT=(frS=frBB)+fread(frBB,1,1<<12,stdin),frS==frT)?EOF:*frS++)
// Reads one decimal integer from the input stream into x.
// Characters before the first digit are skipped; if a '-' appears among them
// the value is negated. Digits are recognised with an explicit range check
// rather than isdigit(), which is undefined for negative char values. The
// skip loop also stops at EOF instead of spinning forever on truncated input
// (x is left as 0 in that case).
inline void read(int& x)
{
    bool negative = false;
    x = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}
int n;       // number of tree nodes
int tot;     // not referenced by this program
int SegTot;  // number of segment-tree nodes allocated so far
int m;       // number of distinct leaf weights (size of the value domain)
int fa[MAXN];         // fa[i]: parent of node i (0 = no parent)
long long vls[MAXN];  // sorted distinct leaf weights, 1-indexed
long long a[MAXN];    // leaves: rank of the weight in vls; internal: t * inv(10000) mod MOD
long long Val[MAXN*20];   // not referenced by this program
long long Pro[MAXN*20];   // segment node: probability mass of its range, excluding its own pending tag
long long Mark[MAXN*20];  // segment node: pending multiplicative tag
int Child[MAXN][2];        // tree children, filled in input order
int SegChild[MAXN*20][2];  // segment-tree children
int Rt[MAXN];              // Rt[i]: root of the segment tree describing node i
#define LC(x) SegChild[x][0]
#define RC(x) SegChild[x][1]
long long Ans = 0;  // accumulated answer (mod MOD)
// Binary exponentiation: returns bas^ind reduced modulo MOD.
inline long long qpow(long long bas, long long ind)
{
    bas %= MOD;
    long long acc = 1;
    while (ind != 0)
    {
        acc = (ind & 1) ? acc * bas % MOD : acc;
        bas = bas * bas % MOD;
        ind >>= 1;
    }
    return acc;
}
// Modular inverse of 10000 (Fermat, MOD is prime): converts an integer input t
// into the probability t / 10000 represented modulo MOD.
const long long inv = qpow(10000, MOD - 2);
#define LChild(x) Child[x][0]
#define RChild(x) Child[x][1]
// Normalises x into [0, MOD). The argument is parenthesised so that an
// expression such as Adjust(1 - v) reduces the whole difference, not just v.
#define Adjust(x) ((((x) % MOD) + MOD) % MOD)
// Creates a fresh root-to-leaf path for value index Vl inside [L, R] and
// stores its root in Pos. Every node on the path starts with probability 1
// and the identity tag 1.
void Insert(int& Pos, const int& L, const int& R, const int& Vl)
{
    int* slot = &Pos;  // where the next allocated node index must be written
    int lo = L, hi = R;
    while (true)
    {
        const int node = ++SegTot;
        *slot = node;
        Mark[node] = Pro[node] = 1;
        if (lo == hi) break;
        const int mid = (lo + hi) >> 1;
        if (Vl <= mid)
        {
            slot = &LC(node);
            hi = mid;
        }
        else
        {
            slot = &RC(node);
            lo = mid + 1;
        }
    }
}
// Running sums used by Merge: probability mass (mod MOD) of the values already
// passed in the left / right tree being merged. dfs resets both before each merge.
long long LeftSm = 0;
long long RightSm = 0;
// Folds node x's pending multiplicative tag into Pro[x] and forwards it to x's
// children, leaving x with the identity tag 1. Null nodes (x == 0) are ignored,
// and absent children are not touched.
// Every tag other than 1 is applied, including 0, which arises whenever a merge
// multiplier is 0 modulo MOD. A "Mark[x] > 1" test would skip that case and
// leave Pro[x] (and the children) stale.
void PushDown(const int& x)
{
    if (!x || Mark[x] == 1) return;
    Pro[x] = Pro[x] * Mark[x] % MOD;
    if (LC(x)) Mark[LC(x)] = Mark[LC(x)] * Mark[x] % MOD;
    if (RC(x)) Mark[RC(x)] = Mark[RC(x)] * Mark[x] % MOD;
    Mark[x] = 1;
}
// Merges segment tree y into segment tree x over the same value domain and
// returns the root of the result (x is reused; y's nodes are dropped where both
// exist). Values are visited in ascending order: left subtrees before right.
// LeftSm / RightSm accumulate the original probability mass of values already
// passed in x / y, i.e. P(x-side < v) and P(y-side < v). The caller resets them.
// Temp is the probability of taking the max at the node being built. A value v
// found only on one side is scaled by
//   Temp * P(other < v) + (1 - Temp) * P(other > v).
int Merge(const int& x, const int& y, const long long& Temp)
{
if (!x && !y) return 0;
// Bring both nodes' own Pro up to date before reading it.
PushDown(x); PushDown(y);
if (!x)
{
// Range only present in y: record its mass, then tag the whole subtree with
// the survival factor; 1 - LeftSm is P(x-side > v).
RightSm += Pro[y];
RightSm %= MOD;
Mark[y] = Mark[y] * ((Temp*(LeftSm)%MOD+Adjust(1-Temp)*Adjust(1-LeftSm)%MOD)%MOD) % MOD;
Mark[y] %= MOD;
// PushDown is expected to fold the updated tag into Pro[y].
PushDown(y);
return y;
}
if (!y)
{
// Symmetric: range only present in x.
LeftSm += Pro[x];
LeftSm %= MOD;
Mark[x] = Mark[x] * ((Temp*(RightSm)%MOD+Adjust(1-Temp)*Adjust(1-RightSm)%MOD)%MOD) % MOD;
Mark[x] %= MOD;
PushDown(x);
return x;
}
// Both present: recurse left half first so the running sums stay in value order.
LC(x) = Merge(LC(x), LC(y), Temp);
RC(x) = Merge(RC(x), RC(y), Temp);
Pro[x] = (Pro[LC(x)] + Pro[RC(x)]) % MOD;
return x;
}
// Builds Rt[Pos], the value-indexed segment tree holding the probability of
// every value node Pos can take.
// A leaf creates a single path for its weight rank, a node with one child
// shares that child's tree, and a node with two children merges both trees.
// NOTE(review): recursion depth equals the tree height, which may approach n;
// confirm the judge's stack size tolerates a path-shaped tree.
void dfs(const int& Pos)
{
    const int lson = LChild(Pos);
    const int rson = RChild(Pos);
    if (lson == 0)
    {
        Insert(Rt[Pos], 1, m, a[Pos]);
        return;
    }
    dfs(lson);
    if (rson == 0)
    {
        Rt[Pos] = Rt[lson];
        return;
    }
    dfs(rson);
    // Merge's running prefix sums must start from zero for every merge.
    LeftSm = 0;
    RightSm = 0;
    Rt[Pos] = Merge(Rt[lson], Rt[rson], a[Pos]);
}
long long cnt;  // rank of the next leaf visited by Calc (in-order counter)
// In-order walk of the segment tree rooted at Pos over [L, R]. The k-th
// existing leaf holds the k-th smallest candidate value vls[L] with probability
// Pro, and adds k * vls[L] * Pro^2 to Ans.
// Absence is tested with Pos itself rather than Pro[Pos], for two reasons:
// Pro[Pos] is read before the node's own pending tag is applied, and a real
// leaf's probability can be 0 modulo MOD by coincidence, which would shift the
// rank of every later value.
// Assumes every leaf weight is a possible root value (nonzero real probability);
// TODO confirm against the problem's constraint on p_i.
void Calc(const int& Pos, const int& L, const int& R)
{
    if (!Pos) return;
    PushDown(Pos);
    if (L == R)
    {
        ++cnt;
        Ans = (Ans + cnt * vls[L] % MOD * Pro[Pos] % MOD * Pro[Pos] % MOD) % MOD;
        return;
    }
    int Mid = (L + R) >> 1;
    Calc(LC(Pos), L, Mid);
    Calc(RC(Pos), Mid + 1, R);
}
// Reads the tree (parent of each node, then one integer per node), compresses
// leaf weights to ranks 1..m, builds the root's segment tree and prints the
// accumulated answer. The deprecated `register` specifier was dropped: it is
// ill-formed since C++17.
int main()
{
    read(n);
    // Children are attached in input order: first seen -> left, second -> right.
    for (int i = 1; i <= n; ++i)
    {
        read(fa[i]);
        if (!fa[i]) continue;  // fa[i] == 0: no parent
        if (!LChild(fa[i])) LChild(fa[i]) = i;
        else RChild(fa[i]) = i;
    }
    // Internal nodes store t / 10000 mod MOD; leaf weights are collected in vls.
    for (int i = 1; i <= n; ++i)
    {
        int t;
        read(t);
        a[i] = 1ll * t;
        if (LChild(i)) a[i] = a[i] * inv % MOD;
        else vls[++m] = a[i];
    }
    // Coordinate-compress leaf weights: a[leaf] becomes its 1-based rank in vls.
    sort(vls + 1, vls + 1 + m);
    m = unique(vls + 1, vls + 1 + m) - vls - 1;
    for (int i = 1; i <= n; ++i)
        if (!LChild(i)) a[i] = lower_bound(vls + 1, vls + 1 + m, a[i]) - vls;
    dfs(1);
    Calc(Rt[1], 1, m);
    printf("%lld", Ans);
    return 0;
}