题目出自philipsweng
题目大意
一棵树有 n 个节点,根是 1。每条边有个长度。
对于每个点 i,给出 L[i] 和 R[i],求以点 i 为根的子树中,边数在 [ L[i], R[i] ] 内的路径的最长长度。若不存在则为-1。
输出
∑ni=123333n−iAns[i] mod 998244353
n<=10^6, 边权<=10^9
首先的思路
枚举点对的其中一个点,然后用数据结构找到另一个点,使得答案最大。这个数据结构可以用线段树什么的。
然后这个就可以启发式合并。
第一种打法
正常的启发式合并,自底向上合并。
log^2,并且常数挺大。
第二种打法
像树链剖分一样的那种启发式合并。
先划分轻重链。对于每个点,先递归轻儿子,再递归重儿子,然后枚举轻儿子合并一下。若当前节点是它父亲的重儿子,则信息保留,否则信息清空。(即信息保留的原则是:轻儿子清空,重儿子保留)
log^2,常数小。
优化
注意答案只跟深度有关。
所以划分轻重链时的依据应是深度(即长链剖分)。枚举轻儿子的时候,不必要枚举每一个点,而是枚举每一个深度。
合并是 n 次的,再加上数据结构就是一个log了。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
using namespace std;
typedef long long LL;
const int maxn=1e6+5;
const LL mo=998244353;
int n,l[maxn],r[maxn];
int tot,go[maxn],next[maxn],f1[maxn];
LL val[maxn];
void ins(int x,int y,LL z)
{
go[++tot]=y;
val[tot]=z;
next[tot]=f1[x];
f1[x]=tot;
}
int deep[maxn],size[maxn],Hson[maxn],maxr;
LL a[maxn];
void dfs_size(int k,int last,LL s) //这里的size是以深度为标准的
{
size[k]=deep[k]=deep[last]+1;
maxr=max(maxr,deep[k]);
a[k]=s;
for(int p=f1[k]; p; p=next[p])
{
dfs_size(go[p],k,s+val[p]);
size[k]=max(size[k],size[go[p]]);
if (size[go[p]]>size[Hson[k]]) Hson[k]=go[p];
}
}
LL tr[4*maxn];
bool bz[4*maxn];
void update(int k,int t)
{
if (bz[k]==0) return;
tr[t]=tr[t+1]=-1;
bz[t]=bz[t+1]=1;
bz[k]=0;
}
void tr_xg(int k,int l,int r,int x,LL z)
{
if (l==r)
{
tr[k]=max(tr[k],z);
return;
}
int t=k<<1, t1=(l+r)>>1;
update(k,t);
if (x<=t1) tr_xg(t,l,t1,x,z); else tr_xg(t+1,t1+1,r,x,z);
tr[k]=max(tr[t],tr[t+1]);
}
LL tr_cx(int k,int l,int r,int x,int y)
{
if (x>y) return -1;
if (l==x && r==y) return tr[k];
int t=k<<1, t1=(l+r)>>1;
update(k,t);
if (y<=t1) return tr_cx(t,l,t1,x,y);
else if (x>t1) return tr_cx(t+1,t1+1,r,x,y);
else return max(tr_cx(t,l,t1,x,t1), tr_cx(t+1,t1+1,r,t1+1,y));
}
LL ans;
LL mi(LL x,LL y)
{
LL re=1;
for(; y; y>>=1, x=x*x%mo) if (y&1) re=re*x%mo;
return re;
}
LL sum,nowbh[maxn],depmax[maxn]; //为了不至于每次都清空depmax,于是用时间标记。
void dfs1(int k)
{
depmax[deep[k]]=(nowbh[deep[k]]<sum) ?a[k] :max(depmax[deep[k]],a[k]) ;
nowbh[deep[k]]=sum;
for(int p=f1[k]; p; p=next[p]) dfs1(go[p]);
}
void dfs(int k)
{
for(int p=f1[k]; p; p=next[p]) if (go[p]!=Hson[k])
{
dfs(go[p]);
bz[1]=1;
}
if (Hson[k])
{
dfs(Hson[k]);
LL ans1=-1;
LL t=tr_cx(1,1,maxr,l[k]+deep[k],min(r[k]+deep[k],maxr));
if (t>-1) ans1=t-a[k];
tr_xg(1,1,maxr,deep[k],a[k]);
for(int p=f1[k]; p; p=next[p]) if (go[p]!=Hson[k])
{
++sum;
dfs1(go[p]);
fo(i,deep[go[p]],size[go[p]])
{
LL t=tr_cx(1,1,maxr,max(1,l[k]+2*deep[k]-i),min(r[k]+2*deep[k]-i,maxr));
if (t>-1) ans1=max(ans1,depmax[i]+t-a[k]*2);
}
fo(i,deep[go[p]],size[go[p]]) tr_xg(1,1,maxr,i,depmax[i]);
}
ans1%=mo;
ans=(ans+mi(23333,n-k)*ans1%mo+mo)%mo;
} else
{
tr_xg(1,1,maxr,deep[k],a[k]);
ans=(ans-mi(23333,n-k)%mo+mo)%mo;
}
}
int main()
{
scanf("%d",&n);
fo(i,1,n) scanf("%d %d",&l[i],&r[i]);
fo(i,2,n)
{
int u; LL c;
scanf("%d %lld",&u,&c);
ins(u,i,c);
}
dfs_size(1,0,0);
memset(tr,255,sizeof(tr));
memset(bz,255,sizeof(bz));
dfs(1);
printf("%lld\n",ans);
}