题目
有一棵以
1
1
1为根的树,第
i
i
i个节点的权值是
a
i
a_i
ai,
a
1
=
L
a_1=L
a1=L
求从根往叶子有多少个子序列为
{
L
,
L
−
1
,
.
.
,
1
}
\{L,L-1,..,1\}
{L,L−1,..,1}。
支持
m
m
m次修改,第
i
i
i次将
a
(
i
−
1
)
m
o
d
n
+
1
a_{(i-1)\mod n+1}
a(i−1)modn+1修改成
v
i
v_i
vi
n
≤
1
e
6
,
m
≤
2
e
6
,
f
a
i
<
i
n\leq 1e6,m\leq 2e6,fa_i<i
n≤1e6,m≤2e6,fai<i
思考历程
考虑对一个点修改之后对答案的影响。
记
f
x
f_x
fx为根到
x
x
x路径上多少个子序列是
{
L
,
L
−
1
,
.
.
.
,
a
x
+
1
}
\{L,L-1,...,a_x+1\}
{L,L−1,...,ax+1},
g
x
g_x
gx表示
x
x
x往叶子有多少个子序列为
{
a
x
,
a
x
−
1
,
.
.
.
,
1
}
\{a_x,a_x-1,...,1\}
{ax,ax−1,...,1}
修改之后答案的变化为
−
f
x
g
x
+
f
x
′
g
x
′
-f_xg_x+f'_xg'_x
−fxgx+fx′gx′
问题是如何维护这两个东西。
对于
g
x
g_x
gx,其实就是求
∑
y
在
x
子
树
内
且
a
y
=
a
x
−
1
g
y
\sum_{y在x子树内且a_y=a_x-1}g_y
∑y在x子树内且ay=ax−1gy
于是对于每个不同的
a
x
a_x
ax开个线段树,线段树
k
k
k按照
d
f
s
dfs
dfs序存储所有
a
x
=
k
a_x=k
ax=k的点的
g
x
g_x
gx和。于是就可以快速地查询出
g
g
g的答案。
然而有个问题:修改
g
x
g_x
gx的时候,会影响到权值为
a
x
+
1
a_x+1
ax+1和
a
x
′
+
1
a'_x+1
ax′+1的祖先。于是这个方法就出锅了。
这个方法是不是没有卵用呢?
然而,这题修改的方式很奇怪,它是按照顺序来修改的。将修改分成
⌈
m
n
⌉
\lceil\frac{m}{n} \rceil
⌈nm⌉个部分分开操作,可以发现在每次操作到某个点的时候,它的祖先就都操作过了。所以,没有必要去维护祖先的信息。
既然这样,
f
x
f_x
fx维护的时候也就直接用个桶从祖先维护过来就好了。
这个方法看上去太粗暴?
实际上也有别的做法,线段树合并、dsn on tree等。
时间复杂度
O
(
n
lg
n
)
O(n \lg n)
O(nlgn)
正解
现在的问题是如何快速地计算
g
x
g_x
gx
看起来不能用桶直接从子树搞上来,因为涉及到合并的问题。
但是大佬又提供了一种神奇的思路:直接将
g
g
g值加入桶中,从
x
x
x往下递归之前先记一下,递归回来之后作差,就可以得到子树对
g
x
g_x
gx的贡献。
于是这题就愉快地
O
(
n
)
O(n)
O(n)了。
代码
using namespace std;
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 1000010
#define ll long long
#define mo 1000000007
#define mod(x) ((((x)%=mo)+=mo)%=mo)
int n,m,L;
struct EDGE{
int to;
EDGE *las;
} e[N];
int ne;
EDGE *last[N];
int a[N],v[N];
ll f0[N],g0[N],f1[N],g1[N],gt[N];
ll sub[N],anc[N];
ll ans,sum;
void clear(){memset(sub,0,sizeof(ll)*(n+1));}
void build(int x){
f0[x]=anc[a[x]+1];
f1[x]=anc[v[x]+1];
mod(anc[v[x]]+=f1[x]);
g1[x]=sub[v[x]-1];
for (EDGE *ei=last[x];ei;ei=ei->las)
build(ei->to);
mod(anc[v[x]]-=f1[x]);
mod(g1[x]=((v[x]==1)+sub[v[x]-1]-g1[x]));
mod(sub[v[x]]+=g1[x]);
}
void getgt(int x){
int tmp=sub[a[x]-1];
gt[x]=sub[v[x]-1];
for (EDGE *ei=last[x];ei;ei=ei->las)
getgt(ei->to);
mod(gt[x]=((v[x]==1)+sub[v[x]-1]-gt[x]));
mod(tmp=((a[x]==1)+sub[a[x]-1]-tmp));
mod(sub[a[x]]+=tmp);
}
void move(){
memcpy(f0,f1,sizeof(ll)*(n+1));
memcpy(g0,g1,sizeof(ll)*(n+1));
}
int main(){
// freopen("in.txt","r",stdin);
// freopen("out.txt","w",stdout);
freopen("scholar.in","r",stdin);
freopen("scholar.out","w",stdout);
scanf("%d%d%d",&n,&m,&L);
for (int i=2;i<=n;++i){
int x;
scanf("%d",&x);
e[ne]={i,last[x]};
last[x]=e+ne++;
}
for (int i=1;i<=n;++i)
scanf("%d",&a[i]),v[i]=a[i];
anc[L+1]=1;
build(1);
ans=g1[1];
int d=m/n;
for (int k=0;k<d;++k){
for (int i=1;i<=n;++i)
scanf("%d",&v[i]);
move(),clear();
getgt(1),build(1);
for (int i=1;i<=n;++i){
(ans+=f1[i]*gt[i]-f0[i]*g0[i])%=mo;
(sum+=(ans+mo)*(k*n+i))%=mo;
// printf("%lld\n",ans);
}
memcpy(a,v,sizeof(int)*(n+1));
}
if (m%n){
int r=m%n;
for (int i=1;i<=r;++i)
scanf("%d",&v[i]);
for (int i=r+1;i<=n;++i)
v[i]=a[i];
move(),clear();
getgt(1),build(1);
for (int i=1;i<=r;++i){
(ans+=f1[i]*gt[i]-f0[i]*g0[i])%=mo;
(sum+=(ans+mo)*(d*n+i))%=mo;
// printf("%lld\n",ans);
}
}
printf("%lld\n",sum);
return 0;
}
总结
不要总是把脑子放在数据结构上……
除了数据结构之外有个计数利器叫差分。