题意:20w个节点的树,有边权。一个操作:把某个点的子树的边权都加一个给定的值。一个询问,问树上两点间边权的平方和。
题解:裸的树链剖分,很久没有先线段树了,lazy标记已经忘得差不多了,回忆一下,lazy记录的是当前区间未向子区间传递的信息,本区间的信息和本区间的lazy一点关系都没有。对于树边,我们把值赋到深度更大的那个点上,最后讨论一下边界就好了
本题要维护三个数组,区间平方和,区间所有数的和,lazy标记数组。
#include <bits/stdc++.h>
using namespace std;
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define INF 0x3f3f3f3f
#define maxn 810000
int id,w[maxn],siz[maxn],val[maxn],en,ha[maxn],dep[maxn];
int first[maxn],next2[maxn],to[maxn],top[maxn];
int fa[maxn],sum1[maxn];
int sum[maxn],n,m,son[maxn],plu[maxn],tal[maxn];
int tag[maxn];
int getint()
{
char c;int tmp,f=1;
while(c=getchar(),c<'0'||c>'9')
{
if(c=='-') f=-1;
}
tmp=c-'0';
while(c=getchar(),c>='0'&&c<='9')
tmp=tmp*10+c-'0';
return tmp*f;
}
void add(int a,int b,int c)
{
en++;
to[en]=b;
val[en]=c;
next2[en]=first[a];
first[a]=en;
}
void dfs(int now)
{
int v,maxv=0;siz[now]=1;
for(int i=first[now];i;i=next2[i])
{
v=to[i];
if(fa[now]==v) continue;
fa[v]=now;dep[v]=dep[now]+1;
dfs(v);
siz[now]+=siz[v];
if(maxv<siz[v])
{
maxv=siz[v];
son[now]=v;
}
}
}
void getid(int now,int root)
{
int v;id++;
w[now]=id;
top[now]=root;
if(son[now]) getid(son[now],root);
for(int i=first[now];i;i=next2[i])
{
v=to[i];
if(v==son[now]||v==fa[now]) continue;
getid(v,v);
}
}
void pushup(int rt)
{
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
tal[rt]=tal[rt<<1]+tal[rt<<1|1];
}
void pushdown(int rt,int l,int r)
{
if(tag[rt])
{
int mid=(l+r)/2;
int len1=(mid-l+1);
int len2=r-mid;
tag[rt<<1]+=tag[rt];
tag[rt<<1|1]+=tag[rt];
tal[rt<<1]=tal[rt<<1]+ tag[rt]*tag[rt]*(len1) + 2*sum[rt<<1]*tag[rt];
tal[rt<<1|1]=tal[rt<<1|1]+tag[rt]*tag[rt]*(len2)+2*sum[rt<<1|1]*tag[rt];
sum[rt<<1]+=tag[rt]*len1;
sum[rt<<1|1]+=tag[rt]*len2;
tag[rt]=0;
}
}
void update2(int l,int r,int rt,int x,int y,int v)
{
if(x>y) return;
if(x<=l&&r<=y)
{
tag[rt]+=v;
tal[rt]=tal[rt]+v*v*(r-l+1)+2*sum[rt]*v;
sum[rt]+=v*(r-l+1);
return;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(x<=mid) update2(lson,x,y,v);
if(y>mid) update2(rson,x,y,v);
pushup(rt);
}
void update(int l,int r,int rt,int pos,int v)
{
if(l==r)
{
tal[rt]=v*v;
sum[rt]=v;
tag[rt]=0;
return;
}
int mid=(l+r)>>1;
if(pos<=mid) update(lson,pos,v);
else update(rson,pos,v);
pushup(rt);
}
int getsum(int l,int r,int rt,int x,int y)
{
if(x<=l&&r<=y)
{
return tal[rt];
}
pushdown(rt,l,r);
int ans=0;
int mid=(l+r)>>1;
if(x<=mid) ans+=getsum(lson,x,y);
if(y>mid) ans+=getsum(rson,x,y);
return ans;
}
int getans(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=getsum(1,n,1,w[top[x]],w[x]);
x=fa[top[x]];
}
if(x==y) return ans;
if(dep[x]>dep[y]) swap(x,y);
ans+=getsum(1,n,1,w[son[x]],w[y]);
return ans;
}
void build(int now)
{
int v;
for(int i=first[now];i;i=next2[i])
{
v=to[i];
if(v==fa[now]) continue;
ha[(i+1)/2]=v;
update(1,n,1,w[v],val[i]);
build(v);
}
}
int main()
{
int cas;
scanf("%d",&cas);
while(cas--)
{
memset(tag,0,sizeof(tag));
en=0;id=0;
memset(first,0,sizeof(first));
memset(son,0,sizeof(son));
int a,b,c,op;
scanf("%d",&n);
for(int i=1;i<n;i++)
{
a=getint();
b=getint();
c=getint();
add(a,b,c);
add(b,a,c);
}
dfs(1);
getid(1,1);
build(1);
scanf("%d",&m);
char st[10];
while(m--)
{
a=getint();
b=getint();
c=getint();
if(a==1)
{
update2(1,n,1,w[b]+1,w[b]+siz[b]-1,c);
}
else
{
printf("%d\n",getans(b,c));
}
}
}
return 0;
}
/*
2
5
1 2 1
2 3 1
3 4 1
4 5 1
3
1 3 5
2 1 3
2 3 5
5
1 2 1
2 3 1
3 4 1
4 5 1
3
1 3 5
2 1 3
2 3 5
*/