LOJ #2537. 「PKUWC2018」Minimax 线段树合并

版权声明:想转就转吧,注明出处就行 括弧笑 https://blog.csdn.net/BlackJack_/article/details/80779287



这个题的 n^2 dp 是很显然的 线段树优化dp 也是很显然的

这个题的价值在于增加线段树合并技能熟练度


#include<cmath>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<iomanip>
#include<vector>
#include<string>
#include<bitset>
#include<queue>
#include<map>
#include<set>
using namespace std;

typedef double db;
typedef long long ll;
typedef unsigned int uint;

inline int read()
{
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch<='9'&&ch>='0'){x=10*x+ch-'0';ch=getchar();}
	return x*f;
}
void print(int x)
{if(x<0)putchar('-'),x=-x;if(x>=10)print(x/10);putchar(x%10+'0');}

const int N=300100,mod=998244353;

inline int qpow(int x,int y)
{
	int res(1);
	while(y)
	{
		if(y&1) res=(ll)x*res%mod;
		x=(ll)x*x%mod;
		y>>=1;
	}
	return res;
}

int ch[N][2];

int tot;
int V[N],P[N];

struct president_tree{int ls,rs,sum,tag;}tr[N*30];
int root[N],sz;

inline void pushdown(int k)
{
	if(tr[k].tag)
	{
		int tag=tr[k].tag,ls=tr[k].ls,rs=tr[k].rs;
		tr[ls].sum=(ll)tr[ls].sum*tag%mod,
		tr[rs].sum=(ll)tr[rs].sum*tag%mod,
		tr[ls].tag=(ll)tr[ls].tag*tag%mod,
		tr[rs].tag=(ll)tr[rs].tag*tag%mod,
		tr[k].tag=1;
	}
}

void insert(int &k,int l,int r,int x,int val)
{
	k=++sz;
	tr[k].tag=1;
	tr[k].sum+=val;
	if(l==r) return ;
	int mid=(l+r)>>1;
	x<=mid ? insert(tr[k].ls,l,mid,x,val) : insert(tr[k].rs,mid+1,r,x,val);
}

int merger(int x,int y,int sum_x,int sum_y,int u)
{
	if(!x)
	{
		tr[y].sum=(ll)tr[y].sum*sum_x%mod,
		tr[y].tag=(ll)tr[y].tag*sum_x%mod;
		return y;
	}
	if(!y)
	{
		tr[x].sum=(ll)tr[x].sum*sum_y%mod,
		tr[x].tag=(ll)tr[x].tag*sum_y%mod;
		return x;
	}
	int val_x[2],val_y[2];
	pushdown(x),pushdown(y);
	val_x[0]=tr[tr[x].ls].sum,
	val_x[1]=tr[tr[x].rs].sum,
	val_y[0]=tr[tr[y].ls].sum,
	val_y[1]=tr[tr[y].rs].sum;
	tr[x].ls=merger(tr[x].ls,tr[y].ls,(sum_x+(ll)val_x[1]*(1+mod-P[u]))%mod,(sum_y+(ll)val_y[1]*(1+mod-P[u]))%mod,u);
	tr[x].rs=merger(tr[x].rs,tr[y].rs,(sum_x+(ll)val_x[0]*P[u])%mod,(sum_y+(ll)val_y[0]*P[u])%mod,u);
	tr[x].sum=(tr[tr[x].ls].sum+tr[tr[x].rs].sum)%mod;
	return x;
}

void dfs(int u)
{
	if(!u) return ;
	if(!ch[u][0])
	{
		insert(root[u],1,tot,lower_bound(V+1,V+1+tot,P[u])-V,1);
		return ;
	}
	dfs(ch[u][0]),dfs(ch[u][1]);
	if(!ch[u][1]) root[u]=root[ch[u][0]];
	else root[u]=merger(root[ch[u][0]],root[ch[u][1]],0,0,u);
}

int ans(0);

void cal(int k,int l,int r)
{
	if(l==r)
	{
		(ans+=(ll)l*V[l]%mod*tr[k].sum%mod*tr[k].sum%mod)%=mod;
		return ;
	}
	pushdown(k);
	int mid=(l+r)>>1;
	cal(tr[k].ls,l,mid),cal(tr[k].rs,mid+1,r);
}

int main()
{
	int n=read();
	register int i,x;
	for(i=1;i<=n;++i)
		x=read(),ch[x][ch[x][0] ? 1 : 0]=i;
	int inv_w=qpow(10000,mod-2);
	for(i=1;i<=n;++i)
		P[i]=read(),
		ch[i][0] ? P[i]=(ll)P[i]*inv_w%mod : V[++tot]=P[i];
	sort(V+1,V+1+tot);
	dfs(1);
	cal(root[1],1,tot);
	cout<<ans<<endl;
	return 0;
}

展开阅读全文

没有更多推荐了,返回首页