Fibonacci Numbers on Tree
时间限制: 3 Sec 内存限制: 512 MB
题目描述
在数学中,斐波那契数列F[N]由以下递归关系确定:F[N] = F[N-1] + F[N-2],以及边界 条件 F[1]=1,F[2]=1。
今天,大厨给了一棵N个结点的树。结点从1到N编号并且1是根节点。初始时,每个结点的权值均为0。接下来,大厨要求你处理M个操作。
每条操作形如:
1 A x y:将x到y路径上第k个结点的权值增加F[k],即x结点的权值增加F[1],x到y
路径上的第二个结点增加F[2])。
2 QS x y:视x为根,询问以y结点为根的子树中所有结点的权值和。
3 QC x y:询问x到y路径上所有结点的权值和。
4 R x:将所有结点的权值还原到第x个操作后的状态。如果x = 0,那么回到初始状态,也就是
所有结点的权值清0。
输入
输入数据的第一行包含两个整数,分别表示N和M
接下来的N-1行,每行表示树的一条x和y之间的边
接下来包含M行,每行包含一个操作
A x1 y
QS x1 y
QC x1 y
R x1
每个操作中x并不是直接给出的。
实际的x将会是每个操作中x1 xor lastans的结果。
lastans表示上一次询问的答案。初始时lastans = 0。
输出
对于每个QS和QC,输出一行表示答案。答案可能很大,输出模1000000009之后的结果。
样例输入
5 6
1 2
1 3
3 4
3 5
A 4 2
A 2 5
QS 4 3
QC 12 4
R 6
QC 6 4
样例输出
13
7
4
数据规模
对于20%的数据:1 <= N,M <= 10^4
对于100%的数据:1 <= N,M <= 10^5,1 <= x,y <= N
来源
Codechef Sept14
题解
先考虑在数列上怎么写。
显然多个斐波那契数列相加还是斐波那契数列。
所以只要知道一个数列额前2项就可以推出整个数列。
使用线段树,每个区间维护该区间数列的前两项。
然后,这道题强行被搞到了树上。
然后,这道题被强行加了可持久化。
然后,这道题就从3k题莫名其妙地变成了6k题。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<string>
#include<set>
#include<algorithm>
#define ll long long
#define mod 1000000009
#define N 100010
using namespace std;
int n,m,cnt,tot,pre,l[N],r[N],f[N],A[N],B[N],sa[N],sb[N];
int fa[N],dep[N],size[N],next[N],flag[N],top[N];
int k,la[N],ff[N*2],q[N],rt[N];
struct node{int a,b;}map[N*2];
struct info{int lc,rc,f1,f2,f3,f4,sum;}t[N*100];
int get(int x){return x^pre;}
void add(int a,int b)
{
map[++k]=(node){a,b};ff[k]=la[a];la[a]=k;
map[++k]=(node){b,a};ff[k]=la[b];la[b]=k;
}
void bfs()
{
int l=1,r=2;q[1]=1;dep[1]=1;
while(l<r)
{
int x=q[l];size[x]=1;l++;
for(int a=la[x];a;a=ff[a])
if(fa[x]!=map[a].b)
{
q[r]=map[a].b;fa[q[r]]=x;
dep[q[r]]=dep[x]+1;r++;
}
}
for(int i=r-1;i;i--)
{
int x=q[i];size[fa[x]]+=size[x];
if(size[next[fa[x]]]<size[x])next[fa[x]]=x;
}
}
void dfs(int x,int val)
{
l[x]=++tot;top[x]=val;
if(next[x])dfs(next[x],val);
for(int a=la[x];a;a=ff[a])
if(fa[x]!=map[a].b&&map[a].b!=next[x])
dfs(map[a].b,map[a].b);
r[x]=tot;
}
class seg_tree
{
void update(int x)
{
int lc=t[x].lc,rc=t[x].rc;
t[x].sum=(t[lc].sum+t[rc].sum)%mod;
}
void pushdown(int x,int l,int r)
{
if(!t[x].f1&&!t[x].f2&&!t[x].f3&&!t[x].f4)return;
int mid=l+r>>1,lc=t[x].lc,rc=t[x].rc;
int ls=mid-l+1,rs=r-mid;
t[x].lc=++cnt;t[cnt]=t[lc];lc=cnt;
t[x].rc=++cnt;t[cnt]=t[rc];rc=cnt;
if(t[x].f1||t[x].f2)
{
t[lc].f1=(t[lc].f1+t[x].f1)%mod;
t[lc].f2=(t[lc].f2+t[x].f2)%mod;
t[lc].sum=(t[lc].sum+((ll)A[ls]*t[x].f1%mod+(ll)B[ls]*t[x].f2%mod)%mod)%mod;
int t1=((ll)sa[ls+1]*t[x].f1%mod+(ll)sb[ls+1]*t[x].f2%mod)%mod;
int t2=((ll)sa[ls+2]*t[x].f1%mod+(ll)sb[ls+2]*t[x].f2%mod)%mod;
t[rc].f1=(t[rc].f1+t1)%mod;
t[rc].f2=(t[rc].f2+t2)%mod;
t[rc].sum=(t[rc].sum+((ll)A[rs]*t1%mod+(ll)B[rs]*t2%mod)%mod)%mod;
t[x].f1=0;t[x].f2=0;
}
if(t[x].f3||t[x].f4)
{
t[rc].f3=(t[rc].f3+t[x].f3)%mod;
t[rc].f4=(t[rc].f4+t[x].f4)%mod;
t[rc].sum=(t[rc].sum+((ll)A[rs]*t[x].f3%mod+(ll)B[rs]*t[x].f4%mod)%mod)%mod;
int t3=((ll)sa[rs+1]*t[x].f3%mod+(ll)sb[rs+1]*t[x].f4%mod)%mod;
int t4=((ll)sa[rs+2]*t[x].f3%mod+(ll)sb[rs+2]*t[x].f4%mod)%mod;
t[lc].f3=(t[lc].f3+t3)%mod;
t[lc].f4=(t[lc].f4+t4)%mod;
t[lc].sum=(t[lc].sum+((ll)A[ls]*t3%mod+(ll)B[ls]*t4%mod)%mod)%mod;
t[x].f3=0;t[x].f4=0;
}
}
public:
void modify1(int &x,int l,int r,int ql,int qr,int st)
{
t[++cnt]=t[x];x=cnt;
if(ql<=l&&r<=qr)
{
int len=r-l+1;
t[x].f1=(t[x].f1+f[st])%mod;
t[x].f2=(t[x].f2+f[st+1])%mod;
t[x].sum=(t[x].sum+((ll)A[len]*f[st]%mod+(ll)B[len]*f[st+1]%mod)%mod)%mod;
return;
}
int mid=l+r>>1;
pushdown(x,l,r);
if(ql<=mid)modify1(t[x].lc,l,mid,ql,qr,st);
if(qr>mid)modify1(t[x].rc,mid+1,r,ql,qr,st+max(0,mid-max(l,ql)+1));
update(x);
}
void modify2(int &x,int l,int r,int ql,int qr,int st)
{
t[++cnt]=t[x];x=cnt;
if(ql<=l&&r<=qr)
{
int len=r-l+1;
t[x].f3=(t[x].f3+f[st])%mod;
t[x].f4=(t[x].f4+f[st+1])%mod;
t[x].sum=(t[x].sum+((ll)A[len]*f[st]%mod+(ll)B[len]*f[st+1]%mod)%mod)%mod;
return;
}
int mid=l+r>>1;
pushdown(x,l,r);
if(ql<=mid)modify2(t[x].lc,l,mid,ql,qr,st+max(0,min(r,qr)-mid));
if(qr>mid)modify2(t[x].rc,mid+1,r,ql,qr,st);
update(x);
}
int qry(int x,int l,int r,int ql,int qr)
{
if(ql<=l&&r<=qr)return t[x].sum;
int mid=l+r>>1,res=0;
pushdown(x,l,r);
if(ql<=mid)res+=qry(t[x].lc,l,mid,ql,qr);
if(qr>mid)res+=qry(t[x].rc,mid+1,r,ql,qr);
return res%mod;
}
void build(int &x,int l,int r)
{
x=++cnt;
if(l==r)return;
int mid=l+r>>1;
build(t[x].lc,l,mid);
build(t[x].rc,mid+1,r);
}
}T;
int lca(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
x=fa[top[x]];
}
if(dep[x]<dep[y])return x;
return y;
}
void modify(int pos)
{
int x,y,p,st,end;
scanf("%d%d",&x,&y);p=lca(x=get(x),y);
st=1;end=dep[x]+dep[y]-dep[p]*2+1;
while(top[x]!=top[p])
{
T.modify2(rt[pos],1,n,l[top[x]],l[x],st);
st+=l[x]-l[top[x]]+1;x=fa[top[x]];
}
while(top[y]!=top[p])
{
end-=l[y]-l[top[y]]+1;
T.modify1(rt[pos],1,n,l[top[y]],l[y],end+1);
y=fa[top[y]];
}
if(l[x]<l[y])T.modify1(rt[pos],1,n,l[x],l[y],st);
else T.modify2(rt[pos],1,n,l[y],l[x],st);
}
void change(int pos)
{
int x;
scanf("%d",&x);
rt[pos]=rt[get(x)];
}
int find(int x,int des)
{
int res=0;
while(top[x]!=top[des])res=top[x],x=fa[top[x]];
if(x!=des)res=next[des];
return res;
}
int qry1(int pos)
{
int x,y;
scanf("%d%d",&x,&y);x=get(x);
if(x==y)return t[rt[pos]].sum;
if(l[y]<=l[x]&&r[x]<=r[y])
{
int p=find(x,y);
int res=t[rt[pos]].sum-T.qry(rt[pos],1,n,l[p],r[p]);
if(res<0)res+=mod;return res;
}
return T.qry(rt[pos],1,n,l[y],r[y]);
}
int qry2(int pos)
{
int x,y,res=0;
scanf("%d%d",&x,&y);x=get(x);
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
res+=T.qry(rt[pos],1,n,l[top[x]],l[x]);
if(res>=mod)res-=mod;
x=fa[top[x]];
}
if(l[x]>l[y])swap(x,y);
res+=T.qry(rt[pos],1,n,l[x],l[y]);
if(res>=mod)res-=mod;
return res;
}
int main()
{
int a,b;char s[2];
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
scanf("%d%d",&a,&b),add(a,b);
f[1]=1;for(int i=2;i<=n;i++)f[i]=(f[i-1]+f[i-2])%mod;
sa[1]=1;sb[1]=0;sa[2]=0;sb[2]=1;
for(int i=3;i<=n;i++)
sa[i]=(sa[i-1]+sa[i-2])%mod,sb[i]=(sb[i-1]+sb[i-2])%mod;
for(int i=1;i<=n;i++)
A[i]=(A[i-1]+sa[i])%mod,B[i]=(B[i-1]+sb[i])%mod;
bfs();dfs(1,1);T.build(rt[0],1,n);
for(int i=1;i<=m;i++)
{
scanf(" %s",s);rt[i]=rt[i-1];
if(s[0]=='A')modify(i);
if(s[0]=='R')change(i);
if(s[0]=='Q')
{
if(s[1]=='S')printf("%d\n",pre=qry1(i));
else printf("%d\n",pre=qry2(i));
}
}
return 0;
}