题意
给定具有 n 个节点和 q 个操作的树,有两种操作。
1 a b :对于链<a, b>,将x2的值增加到该链上的第x个点
例如从a到b的链=(x1,x2,x3,x4,x5),操作后x1+=1,x2+=4,x3+=9,x4+=16,x5+=25
2 x : 询问第 x 个节点的值
Tips
树链剖分后的信息:
①剖分完成后,每条重链相当于一段连续区间
比如:1-3-6-10对应的新区间是1-2-3-4
②每棵子树也对应一个连续的区间
比如x=2这棵子树对应的新区间是7-8-9-10
题解
问题可以转化为给区间加二次函数
不妨设某段区间增加(u-nid[x])2或(nid[y]-u)2
lx表示左数第几个点,lx=u-nid[x]
ry表示右数第几个点,ry=nid[y]-u
展开之后是1×nid2-2u×nid+u2×1
容易发现对三个标记分开维护就可以了
Example1
从9到7的链=(9,5,2,1,3,7)
操作后nid9+=1,nid5+=4,nid2+=9,nid1+=16,nid3+=25,nid7+=36
树剖后的对应区间:
①[5,5],y从7跳到3,分别+62
②[7,9],x从9跳到1,分别+32,+22,+12
③[1,2],分别+42,+52
Example2
从7到9的链=(7,3,1,2,5,9)
操作后nid7+=1,nid3+=4,nid1+=9,nid2+=16,nid5+=25,nid9+=36
树剖后的对应区间:
①[5,5],x从7跳到3,分别+12
②[7,9],y从9跳到1,分别+42,+52,+62
③[1,2],分别+32,+22
Example3
从4到8的链=(4,2,1,3,8)
操作后nid4+=1,nid2+=4,nid1+=9,nid3+=16,nid8+=25
树剖后的对应区间:
①[6,6],y从8跳到3,分别+52
②[10,10],x从4跳到2,分别+12
③[7,7],x从2跳到1,分别+22
④[1,2],分别+32,+42
Example4
从8到4的链=(8,3,1,2,4)
操作后nid8+=1,nid3+=4,nid1+=9,nid2+=16,nid4+=25
树剖后的对应区间:
①[10,10],y从4跳到2,分别+52
②[6,6],x从8跳到3,分别+12
③[7,7],y从2跳到1,分别+42
④[1,2],分别+32,+22
代码
线段树
#include<bits/stdc++.h>
using namespace std;
const int maxn=100010;
int n,q,o,x,y,tot,head[maxn],size[maxn],d[maxn],son[maxn],f[maxn],top[maxn],nid[maxn],oid[maxn],dfn;
long long u;
struct acm
{
int y,next;
}
a[maxn*2];
struct SegmentTree
{
int l,r;
long long sum1,add1,sum2,add2,sum3,add3;
#define l(x) tree[x].l
#define r(x) tree[x].r
#define sum1(x) tree[x].sum1
#define add1(x) tree[x].add1
#define sum2(x) tree[x].sum2
#define add2(x) tree[x].add2
#define sum3(x) tree[x].sum3
#define add3(x) tree[x].add3
}
tree[maxn*4];
inline int read()
{
int num=0,flag=1;
char c=getchar();
for (;c<'0'||c>'9';c=getchar())
if (c=='-') flag=-1;
for (;c>='0'&&c<='9';c=getchar())
num=(num<<3)+(num<<1)+c-48;
return num*flag;
}
void addd(int x,int y)
{
a[++tot].y=y;
a[tot].next=head[x];
head[x]=tot;
}
void dfs1(int x,int fath)
{
size[x]=1;
d[x]=d[fath]+1;
son[x]=0;
f[x]=fath;
for (int i=head[x];i;i=a[i].next)
{
int y=a[i].y;
if (y==fath) continue;
dfs1(y,x);
size[x]+=size[y];
if (size[son[x]]<size[y]) son[x]=y;
}
}
void dfs2(int x,int topx)
{
top[x]=topx;
nid[x]=++dfn;
oid[dfn]=x;
if (son[x]!=0) dfs2(son[x],topx);
for (int i=head[x];i;i=a[i].next)
{
int y=a[i].y;
if (y!=f[x]&&y!=son[x]) dfs2(y,y);
}
}
void spread(int p)
{
if (add1(p))
{
sum1(p*2)+=add1(p)*(r(p*2)-l(p*2)+1);
sum1(p*2+1)+=add1(p)*(r(p*2+1)-l(p*2+1)+1);
add1(p*2)+=add1(p);
add1(p*2+1)+=add1(p);
add1(p)=0;
}
if (add2(p))
{
sum2(p*2)+=add2(p)*(r(p*2)-l(p*2)+1);
sum2(p*2+1)+=add2(p)*(r(p*2+1)-l(p*2+1)+1);
add2(p*2)+=add2(p);
add2(p*2+1)+=add2(p);
add2(p)=0;
}
if (add3(p))
{
sum3(p*2)+=add3(p)*(r(p*2)-l(p*2)+1);
sum3(p*2+1)+=add3(p)*(r(p*2+1)-l(p*2+1)+1);
add3(p*2)+=add3(p);
add3(p*2+1)+=add3(p);
add3(p)=0;
}
}
void build(int p,int l,int r)
{
l(p)=l;
r(p)=r;
if (l==r) return;
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
}
void change(int p,int l,int r,long long d1,long long d2,long long d3)
{
if (l<=l(p)&&r>=r(p))
{
sum1(p)+=d1*(r(p)-l(p)+1);
add1(p)+=d1;
sum2(p)+=d2*(r(p)-l(p)+1);
add2(p)+=d2;
sum3(p)+=d3*(r(p)-l(p)+1);
add3(p)+=d3;
return;
}
spread(p);
int mid=(l(p)+r(p))/2;
if (l<=mid) change(p*2,l,r,d1,d2,d3);
if (r>mid) change(p*2+1,l,r,d1,d2,d3);
sum1(p)=sum1(p*2)+sum1(p*2+1);
sum2(p)=sum2(p*2)+sum2(p*2+1);
sum3(p)=sum3(p*2)+sum3(p*2+1);
}
long long ask(int p,int x)
{
if (l(p)==r(p)) return sum1(p)*x*x+sum2(p)*x+sum3(p);
spread(p);
int mid=(l(p)+r(p))/2;
if (x<=mid) return ask(p*2,x);
else return ask(p*2+1,x);
}
int lca(int x,int y)
{
while (top[x]!=top[y])
{
if (d[top[x]]<d[top[y]]) swap(x,y);
x=f[top[x]];
}
return d[x]<d[y]?x:y;
}
void chain(int x,int y)
{
int lx=1,ry=d[x]+d[y]-2*d[lca(x,y)]+1;
while (top[x]!=top[y])
{
if (d[top[x]]>d[top[y]])
{
u=nid[x]+lx;
change(1,nid[top[x]],nid[x],1,-2*u,u*u);
lx+=nid[x]-nid[top[x]]+1;
x=f[top[x]];
}
else
{
u=nid[y]-ry;
change(1,nid[top[y]],nid[y],1,-2*u,u*u);
ry-=nid[y]-nid[top[y]]+1;
y=f[top[y]];
}
}
if (d[x]<d[y])
{
u=nid[y]-ry;
change(1,nid[x],nid[y],1,-2*u,u*u);
}
else
{
u=nid[x]+lx;
change(1,nid[y],nid[x],1,-2*u,u*u);
}
}
int main()
{
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
n=read();
for (int i=1;i<n;++i)
{
x=read();
y=read();
addd(x,y);
addd(y,x);
}
dfs1(1,0);
dfs2(1,1);
build(1,1,n);
q=read();
for (int i=1;i<=q;++i)
{
o=read();
if (o==1)
{
x=read();
y=read();
chain(x,y);
}
if (o==2)
{
x=read();
printf("%lld\n",ask(1,nid[x]));
}
}
fclose(stdin);
fclose(stdout);
return 0;
}
树状数组
#include<bits/stdc++.h>
using namespace std;
const int maxn=100010;
int n,q,o,x,y,tot,head[maxn],size[maxn],d[maxn],son[maxn],f[maxn],top[maxn],nid[maxn],oid[maxn],dfn;
long long u,c[5][maxn];
struct acm
{
int y,next;
}
a[maxn*2];
inline int read()
{
int num=0,flag=1;
char c=getchar();
for (;c<'0'||c>'9';c=getchar())
if (c=='-') flag=-1;
for (;c>='0'&&c<='9';c=getchar())
num=(num<<3)+(num<<1)+c-48;
return num*flag;
}
void addd(int x,int y)
{
a[++tot].y=y;
a[tot].next=head[x];
head[x]=tot;
}
void dfs1(int x,int fath)
{
size[x]=1;
d[x]=d[fath]+1;
son[x]=0;
f[x]=fath;
for (int i=head[x];i;i=a[i].next)
{
int y=a[i].y;
if (y==fath) continue;
dfs1(y,x);
size[x]+=size[y];
if (size[son[x]]<size[y]) son[x]=y;
}
}
void dfs2(int x,int topx)
{
top[x]=topx;
nid[x]=++dfn;
oid[dfn]=x;
if (son[x]!=0) dfs2(son[x],topx);
for (int i=head[x];i;i=a[i].next)
{
int y=a[i].y;
if (y!=f[x]&&y!=son[x]) dfs2(y,y);
}
}
int lowbit(int x)
{
return x&-x;
}
void add(int x,long long y,int k)
{
for (;x<=n;x+=lowbit(x))
c[k][x]+=y;
}
long long ask(int x,int k)
{
long long ans=0;
for (;x;x-=lowbit(x))
ans+=c[k][x];
return ans;
}
int lca(int x,int y)
{
while (top[x]!=top[y])
{
if (d[top[x]]<d[top[y]]) swap(x,y);
x=f[top[x]];
}
return d[x]<d[y]?x:y;
}
void chain(int x,int y)
{
int lx=1,ry=d[x]+d[y]-2*d[lca(x,y)]+1;
while (top[x]!=top[y])
{
if (d[top[x]]>d[top[y]])
{
u=nid[x]+lx;
add(nid[top[x]],1,1);
add(nid[x]+1,-1,1);
add(nid[top[x]],-2*u,2);
add(nid[x]+1,2*u,2);
add(nid[top[x]],u*u,3);
add(nid[x]+1,-u*u,3);
lx+=nid[x]-nid[top[x]]+1;
x=f[top[x]];
}
else
{
u=nid[y]-ry;
add(nid[top[y]],1,1);
add(nid[y]+1,-1,1);
add(nid[top[y]],-2*u,2);
add(nid[y]+1,2*u,2);
add(nid[top[y]],u*u,3);
add(nid[y]+1,-u*u,3);
ry-=nid[y]-nid[top[y]]+1;
y=f[top[y]];
}
}
if (d[x]<d[y])
{
u=nid[y]-ry;
add(nid[x],1,1);
add(nid[y]+1,-1,1);
add(nid[x],-2*u,2);
add(nid[y]+1,2*u,2);
add(nid[x],u*u,3);
add(nid[y]+1,-u*u,3);
}
else
{
u=nid[x]+lx;
add(nid[y],1,1);
add(nid[x]+1,-1,1);
add(nid[y],-2*u,2);
add(nid[x]+1,2*u,2);
add(nid[y],u*u,3);
add(nid[x]+1,-u*u,3);
}
}
int main()
{
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
n=read();
for (int i=1;i<n;++i)
{
x=read();
y=read();
addd(x,y);
addd(y,x);
}
dfs1(1,0);
dfs2(1,1);
q=read();
for (int i=1;i<=q;++i)
{
o=read();
if (o==1)
{
x=read();
y=read();
chain(x,y);
}
if (o==2)
{
x=read();
printf("%lld\n",ask(nid[x],1)*nid[x]*nid[x]+ask(nid[x],2)*x+ask(nid[x],3));
}
}
fclose(stdin);
fclose(stdout);
return 0;
}