P5305 [GXOI/GZOI2019]旧词
题解:
50000 个 询 问 ( x , y ) 50000个询问(x,y) 50000个询问(x,y)
∑ i ≤ x d e p t h ( l c a ( i , y ) ) k \sum_{i≤x} depth(lca(i,y))^k ∑i≤xdepth(lca(i,y))k
若
k
=
1
k=1
k=1
虽然询问区间可以按照
x
x
x 排序成升序,但是
y
y
y 不一样
x
,
y
x,y
x,y 都是50000,同时枚举一定会炸
\text{}
\text{}
有个提高组难度 的思想:
d
e
p
[
l
c
a
]
=
dep[lca]=
dep[lca]= 二者到根的路径 的 公共节点数
那么 ,
对于每一个点 x 的贡献,就是把 1->x 的路径上的点加1
对于每一个询问 y ,就是求出 1->y 的路径点权和
这个时候再把询问区间按照 x 排序就好了
核心思想就是一个差分:将答案
d
e
p
[
l
c
a
]
dep[lca]
dep[lca] 分成 给
d
e
p
[
l
c
a
]
dep[lca]
dep[lca] 个点每个+1,路径点权和其实就是前缀和
就是给深度为
d
e
p
[
i
]
dep[i]
dep[i] 的点 加上权值
d
e
p
[
i
]
1
−
(
d
e
p
[
i
]
−
1
)
1
=
1
dep[i]^1-(dep[i]-1)^1=1
dep[i]1−(dep[i]−1)1=1
最后用一个树剖维护就好了
\text{}
\text{}
考虑
k
≠
1
k≠1
k=1
同样的,我们再考虑一个差分:将答案
d
e
p
[
l
c
a
]
k
dep[lca]^k
dep[lca]k 分成
给
x
x
x 到根路径上每个点
i
i
i 加值,即给深度为
d
e
p
[
i
]
dep[i]
dep[i] 的点 加上权值
d
e
p
[
i
]
k
−
(
d
e
p
[
i
]
−
1
)
k
dep[i]^k-(dep[i]-1)^k
dep[i]k−(dep[i]−1)k 显然
y
y
y 到根路径点权和就是
d
e
p
[
l
c
a
]
k
dep[lca]^k
dep[lca]k
解释:
d
e
p
[
i
]
k
−
(
d
e
p
[
i
]
−
1
)
k
+
(
d
e
p
[
i
]
−
1
)
k
−
(
d
e
p
[
i
]
−
2
)
k
+
.
.
.
.
.
.
.
.
dep[i]^k-(dep[i]-1)^k+(dep[i]-1)^k-(dep[i]-2)^k+........
dep[i]k−(dep[i]−1)k+(dep[i]−1)k−(dep[i]−2)k+........
\text{}
但是我们不能直接给 不同的点加不同的值
但我们发现每个点被加的值是固定的,一定为
c
=
d
e
p
[
x
]
k
−
(
d
e
p
[
x
]
−
1
)
k
c=dep[x]^k-(dep[x]-1)^k
c=dep[x]k−(dep[x]−1)k
\text{}
给这条链上的点各自加值,相当于给每个点的覆盖次数
c
n
t
+
1
cnt+1
cnt+1
那么用线段树维护 每个点的贡献就是
c
n
t
×
c
cnt \times c
cnt×c
总感觉我像是写给小朋友看的一样。。。对差分的理解不够深刻啊
差分与前缀和互为逆运算
#include<cstdio>
#include<algorithm>
#define LL long long
const int N=50100;
const LL mod=998244353;
using namespace std;
struct bian{int y,gg;}b[N];
int first[N],len=0,n,Q,k;
void ins(int x,int y){
b[++len].y=y;
b[len].gg=first[x];
first[x]=len;
}
int read(){
int x=0;char c=getchar();
while(c<48) c=getchar();
while(c>47) x=x*10+c-'0',c=getchar();
return x;
}
//================================================================================
int tot[N],son[N],fa[N];LL dep[N];
void dfs(int x)
{
dep[x]=dep[fa[x]]+1; tot[x]=1;
for(int i=first[x];i>0;i=b[i].gg)
{
int y=b[i].y;
if(y!=fa[x])
{
dfs(y);
if(!son[x] || tot[y]>tot[son[x]]) son[x]=y;
tot[x]+=tot[y];
}
}
}
int fre[N],rev[N],top[N],trlen=0;
void dfs2(int x,int tp)
{
fre[x]=++trlen; top[x]=tp;
if(son[x]) dfs2(son[x],tp);
for(int i=first[x];i>0;i=b[i].gg){
int y=b[i].y;
if(y!=fa[x] && y!=son[x])
dfs2(y,y);
}
}
//================================================================================
struct tree{LL c,cnt,tr;}tr[N<<2];
void chan(int now,int l,int r,int x,LL k)
{
if(l==r){
tr[now].c+=k;
tr[now].c%=mod;
return;
}
int mid=(l+r)>>1;
if(x<=mid) chan(now<<1,l,mid,x,k);
else chan(now<<1|1,mid+1,r,x,k);
tr[now].c=tr[now<<1].c + tr[now<<1|1].c;
tr[now].c=(tr[now].c+mod)%mod;
}
void pushdown(int now){
if(!tr[now].cnt) return ;
LL cnt1=tr[now].cnt;
tr[now<<1].tr+=cnt1*tr[now<<1].c;
tr[now<<1|1].tr+=cnt1*tr[now<<1|1].c;
tr[now<<1].cnt+=cnt1; tr[now<<1|1].cnt+=cnt1;
tr[now].cnt=0;
}
void change(int now,int l,int r,int L,int R,LL k)
{
if(l>R || r<L) return ;
if(L<=l && r<=R){
tr[now].cnt+=k;
tr[now].tr+=tr[now].c*k;
tr[now].cnt=(tr[now].cnt+mod)%mod;
tr[now].tr=(tr[now].tr+mod)%mod;
return;
}
pushdown(now);
int mid=(l+r)>>1;
change(now<<1,l,mid,L,R,k);
change(now<<1|1,mid+1,r,L,R,k);
tr[now].tr=(tr[now<<1].tr + tr[now<<1|1].tr+mod)%mod;
}
LL find(int now,int l,int r,int L,int R)
{
if(l>R || r<L) return 0;
if(L<=l && r<=R) return tr[now].tr;
pushdown(now);
int mid=(l+r)>>1;
return (find(now<<1,l,mid,L,R) + find(now<<1|1,mid+1,r,L,R)+mod)%mod;
}
//================================================================================
LL fsolve(int x,int y)
{
int tx=top[x],ty=top[y];LL ans=0;
while(tx!=ty)
{
if(dep[tx]>dep[ty]) swap(tx,ty);
ans=(ans + find(1,1,trlen,fre[ty],fre[y]))%mod;
y=fa[ty];ty=top[y];
}
if(dep[x]>dep[y]) swap(x,y);
return (ans+find(1,1,trlen,fre[x],fre[y]))%mod;
}
void csolve(int x,int y)
{
int tx=top[x],ty=top[y];
while(tx!=ty)
{
if(dep[tx]>dep[ty]) swap(tx,ty);
change(1,1,trlen,fre[ty],fre[y],1);
y=fa[ty];ty=top[y];
}
if(dep[x]>dep[y]) swap(x,y);
change(1,1,trlen,fre[x],fre[y],1);
}
//==================================================================================
struct node{int i,x,y;}st[N];
bool cmp(node a,node b) { return a.x<b.x; }
LL ksm(LL d,int p){
LL ans=1;
while(p>0)
{
if(p%2==1) ans=ans*d%mod;
d=d*d%mod;
p>>=1;
}return ans;
}
LL an[N];
//===================================================================================
int main()
{
n=read(),Q=read(),k=read();
for(int i=2;i<=n;i++) fa[i]=read(),ins(fa[i],i);
dfs(1);dfs2(1,1);
LL c;
for(int i=1;i<=n;i++) c=ksm(dep[i],k)-ksm((dep[i]-1ll),k) , chan(1,1,trlen,fre[i],c);//计算权值
for(int i=1;i<=Q;i++){
st[i].i=i;
st[i].x=read();
st[i].y=read();
}
sort(st+1,st+1+Q,cmp);
int now=0;
for(int i=1;i<=Q;i++){
while(now<st[i].x) csolve(1,++now);//给1~x链cnt+1
an[st[i].i]=fsolve(1,st[i].y);
}
for(int i=1;i<=Q;i++)printf("%lld\n",an[i]);
return 0;
}