题目:BZOJ 3052 [WC2013] 糖果公园
题解:
这是一道带修改的树上莫队模板题,算法上就是 BZOJ 2120(带修改莫队)与 BZOJ 3757(树上莫队)的结合。
注意这里询问的排序关键字依次为:左端点所在块的编号、右端点的 DFS 序、该询问之前最近一次修改的编号。
判断被修改的点是否在当前维护的路径上,可以直接使用 change 操作维护的 vis 数组:若在路径上,先将其从答案中删除,修改颜色后再重新加入;否则直接修改颜色即可。
代码:
#include <cmath>
#include <cstdio>
#include <iostream>
#include <algorithm>
#define LL long long
using namespace std;
// Number of binary-lifting levels: f[x][i] is the 2^i-th ancestor of x.
const int sz = 24;
// Capacity of every per-vertex / per-operation array (1-indexed, a few spare slots).
const int N = 100005;

// One offline record.
//   query (ask):  x, y = path endpoints, id = input order, t = modifications seen before it.
//   change (ch):  x = vertex, y = category value that modi() swaps with c[x].
struct hh {
    int x, y, id, t;
};
hh ask[N];
hh ch[N];

// Sizes from the input.
int n, m, q;
// Edge counter for the adjacency lists.
int tot;
// Target block size for the tree partition, and top of the blocking stack.
int block;
int top;
// Forward-star adjacency lists: point = head arc of a vertex, nxt = next arc, v = arc target.
int point[N];
int nxt[N * 2];
int v[N * 2];
// c[x] = current category (colour) of vertex x; indexes num[] and a[].
int c[N];
// dfn = DFS entry index, h = depth (root has depth 1), nn = DFS counter.
int dfn[N];
int h[N];
int nn;
// Binary-lifting ancestor table and powers of two (mi[i] = 2^i).
int f[N][sz];
int mi[sz];
// pos[x] = block id of vertex x, cnt = number of blocks, stack = blocking stack.
int pos[N];
int cnt;
int stack[N];
// num[col] = how many vertices of category col are currently in the path set.
int num[N];
// ans = answers by query id; a[col] = value of category col;
// w[k] = multiplier for the k-th occurrence of a category; sum = value of the current path set.
LL ans[N];
LL a[N];
LL w[N];
LL sum;
// vis[x] = whether vertex x is currently counted in the path set.
bool vis[N];
// Sort key for the offline queries: block of the left endpoint, then DFS order of the
// right endpoint, then the modification timestamp. Returns nonzero iff lhs sorts first.
// NOTE(review): the timestamp only breaks ties between queries with the same right
// endpoint, so the time pointer is not monotone inside a left-endpoint block. The usual
// complexity bound for Mo with modifications assumes ordering by
// (block of x, block of y, t) with a block size near n^(2/3); confirm performance on
// worst-case inputs.
int cmp(hh lhs,hh rhs)
{
    if (pos[lhs.x] != pos[rhs.x]) return pos[lhs.x] < pos[rhs.x];
    if (dfn[lhs.y] != dfn[rhs.y]) return dfn[lhs.y] < dfn[rhs.y];
    return lhs.t < rhs.t;
}
// Adds the undirected edge x-y as two directed arcs in the forward-star lists
// (point / nxt / v), using arc ids tot+1 and tot+2.
void addline(int x,int y)
{
    // Arc x -> y, prepended to x's list.
    nxt[++tot] = point[x];
    v[tot] = y;
    point[x] = tot;

    // Arc y -> x, prepended to y's list.
    nxt[++tot] = point[y];
    v[tot] = x;
    point[y] = tot;
}
// Recursive DFS from x whose parent is fa (main calls it with fa == 0 for the root).
// - dfn[x]: DFS entry index; h[x] = h[fa] + 1, so the root has depth 1.
// - f[x][i] for i >= 1 is filled here; f[x][0] must already hold the parent
//   (the loop below sets it before recursing; for the root it stays 0).
// - Tree blocking: each vertex is pushed onto `stack` after its children are done.
//   After each child returns, if the vertices left on the stack above `bottom`
//   (leftovers of x's processed children) number at least `block`, they are popped
//   into a new block id in pos[]. Vertices still on the stack when the top-level call
//   returns are not assigned here.
// NOTE(review): recursion depth equals the tree height; a path-shaped tree with ~N
//   vertices may need a large call stack.
void dfs(int x,int fa)
{
dfn[x]=++nn; h[x]=h[fa]+1; int bottom=top; // bottom: stack height on entry; entries above it come from x's subtree
for (int i=1;i<sz;i++)
if (h[x]<mi[i]) break; // depth < 2^i: no 2^i-th ancestor; remaining entries stay 0 (globals are zero-initialised)
else f[x][i]=f[f[x][i-1]][i-1];
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa)
{
f[v[i]][0]=x;dfs(v[i],x);
if (top-bottom>=block) // enough leftover vertices from x's children: seal them into one block
{
cnt++;
while (top!=bottom) pos[stack[top--]]=cnt;
}
}
stack[++top]=x; // x stays unassigned so it can be grouped with later siblings' leftovers or by an ancestor
}
int lca(int x,int y)
{
if (h[x]<h[y]) swap(x,y);
int k=h[x]-h[y];
for (int i=0;i<sz;i++)
if (k&(1<<i)) x=f[x][i];
if (x==y) return x;
for (int i=sz-1;i>=0;i--)
if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
// Toggles vertex x in the current path set and keeps num[] and sum in sync.
// Adding the k-th vertex of category col adds w[k] * a[col]; removing it subtracts
// exactly that same amount.
void change(int x)
{
    const int col = c[x];
    if (!vis[x])
    {
        vis[x] = 1;
        ++num[col];
        sum += w[num[col]] * a[col];
    }
    else
    {
        vis[x] = 0;
        sum -= w[num[col]] * a[col];
        --num[col];
    }
}
// Toggles (via change) every vertex on the tree path x..y except the vertex where the
// two climbs meet (their LCA). The deeper endpoint steps up each round; on equal
// depth x steps.
void reverse(int x,int y)
{
    for (; x != y; )
    {
        if (h[x] >= h[y])
        {
            change(x);
            x = f[x][0];
        }
        else
        {
            change(y);
            y = f[y][0];
        }
    }
}
void modi(int now)
{
if (vis[ch[now].x])
{
change(ch[now].x);
swap(c[ch[now].x],ch[now].y);
change(ch[now].x);
}else swap(c[ch[now].x],ch[now].y);
}
// Entry point: reads the input, answers every path query offline with Mo's algorithm
// on the tree (with a time axis for modifications) and prints the answers in input
// order, one per line.
//
// Input, in the order read below: n m q; a[1..m]; w[1..n]; n-1 edges "x y";
// c[1..n]; then q operations "id x y": id == 0 records a modification (vertex x gets
// category y), any other id records a path query between x and y.
int main()
{
    int x, y, tim = 0, wh = 0;
    scanf("%d%d%d", &n, &m, &q);
    // a[] and w[] are LL (long long), so they must be read with %lld;
    // reading them with %d is undefined behaviour.
    for (int i = 1; i <= m; i++) scanf("%lld", &a[i]);
    for (int i = 1; i <= n; i++) scanf("%lld", &w[i]);
    for (int i = 1; i < n; i++)
    {
        scanf("%d%d", &x, &y);
        addline(x, y);
    }

    // mi[i] = 2^i, used by dfs to decide how many ancestor levels exist.
    mi[0] = 1;
    for (int i = 1; i < sz; i++) mi[i] = mi[i - 1] * 2;

    // Partition the tree into blocks of roughly `block` vertices; whatever dfs
    // leaves on the stack forms one final block.
    block = sqrt(n);
    dfs(1, 0);
    cnt++;
    while (top) pos[stack[top--]] = cnt;

    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);

    for (int i = 1; i <= q; i++)
    {
        int id;
        scanf("%d%d%d", &id, &x, &y);
        if (id == 0)
        {
            ch[++tim].x = x;
            ch[tim].y = y;
        }
        else
        {
            // Normalise endpoint order so the comparator sees dfn[x] <= dfn[y].
            if (dfn[x] > dfn[y]) swap(x, y);
            ++wh;
            ask[wh].x = x;
            ask[wh].y = y;
            ask[wh].t = tim;   // modifications 1..tim precede this query
            ask[wh].id = wh;
        }
    }

    // With no queries there is nothing to answer; ask[1] is not a real query.
    if (wh == 0) return 0;

    sort(ask + 1, ask + wh + 1, cmp);

    // First query: build its path set from scratch. The set maintained by reverse()
    // excludes the LCA, which is toggled in separately.
    int t = lca(ask[1].x, ask[1].y), now = 0;
    while (now < ask[1].t) modi(++now);
    reverse(ask[1].x, ask[1].y);
    change(t);
    ans[ask[1].id] = sum;

    for (int i = 2; i <= wh; i++)
    {
        // Drop the previous LCA, then move both endpoints (symmetric difference of paths).
        change(t);
        reverse(ask[i - 1].x, ask[i].x);
        reverse(ask[i - 1].y, ask[i].y);
        // Move along the time axis: apply or undo modifications.
        while (now < ask[i].t) modi(++now);
        while (now > ask[i].t) modi(now--);
        t = lca(ask[i].x, ask[i].y);
        change(t);
        ans[ask[i].id] = sum;
    }

    for (int i = 1; i <= wh; i++) printf("%lld\n", ans[i]);
}