点击这里查看原题
此题乃莫队系列问题的集大成者,既需要树上莫队,又需要修改,注意以下问题:
- 因为带修改,块大小为n^(2/3)
- 修改时必须严格按时间顺序,对于当前时间大于询问时间的,时间必须倒着遍历;小于的,时间必须正着遍历(没注意到这个问题所以WA了好几次)
/*
User:Small
Language:C++
Problem No.:3052
*/
#include<bits/stdc++.h>
#define ll long long
#define inf 999999999
using namespace std;
const int M=1e5+5;
int n,m,q,t,tot,fir[M],dep[M],c[M],ctmp[M],pos[M],dfn[M],dfs_clock,num[M],lg[M],anc[M][20],acnt,bcnt,stk[M],tp;
ll res,ans[M],v[M],w[M];
bool vis[M];
struct edge{
int v,nex;
}e[M<<1];
struct xg{
int p,x,y;
}a[M];
struct cx{
int u,v,pre,id;
bool operator<(const cx &b)const{
if(pos[u]!=pos[b.u]) return pos[u]<pos[b.u];
if(pos[v]!=pos[b.v]) return pos[v]<pos[b.v];
if(pre!=b.pre) return pre<b.pre;
return dfn[v]<dfn[b.v];
}
}b[M];
void add(int u,int v){
e[++tot]=(edge){v,fir[u]};
fir[u]=tot;
}
int dfs(int u){
int siz=0;
dfn[u]=++dfs_clock;
for(int i=fir[u];i;i=e[i].nex){
int v=e[i].v;
if(anc[u][0]==v) continue;
dep[v]=dep[u]+1;
anc[v][0]=u;
siz+=dfs(v);
if(siz>=t){
tot++;
for(int i=1;i<=siz;i++) pos[stk[tp--]]=tot;
siz=0;
}
}
stk[++tp]=u;
return siz+1;
}
int LCA(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
int d=dep[u]-dep[v];
for(int i=lg[d];i>=0;i--)
if(d&(1<<i)) u=anc[u][i];
if(u==v) return u;
for(int i=lg[n];i>=0;i--)
if(anc[u][i]!=anc[v][i]) u=anc[u][i],v=anc[v][i];
return anc[u][0];
}
void update(int x){
if(vis[x]){
res-=(ll)v[c[x]]*w[num[c[x]]];
num[c[x]]--;
}
else{
num[c[x]]++;
res+=(ll)v[c[x]]*w[num[c[x]]];
}
vis[x]^=1;
}
void change(int p,int x){
if(vis[p]){
update(p);
c[p]=x;
update(p);
}
else c[p]=x;
}
void work(int u,int v){
int lca=LCA(u,v);
while(u!=lca){
update(u);
u=anc[u][0];
}
while(v!=lca){
update(v);
v=anc[v][0];
}
}
void solve(){
int now=0;
for(int i=1;i<=bcnt;i++){
work(b[i-1].u,b[i].u);
work(b[i-1].v,b[i].v);
int lca=LCA(b[i].u,b[i].v);
update(lca);
for(int j=now;j>b[i].pre;j--) change(a[j].p,a[j].x);
for(int j=now+1;j<=b[i].pre;j++) change(a[j].p,a[j].y);//不能写成 for(int j=b[i].pre;j>now;j--)
ans[b[i].id]=res;
update(lca);
now=b[i].pre;
}
}
int main(){
freopen("data.in","r",stdin);//
scanf("%d%d%d",&n,&m,&q);
t=pow(n,2.0/3);
for(int i=1;i<=m;i++) scanf("%lld",&v[i]);
for(int i=1;i<=n;i++) scanf("%lld",&w[i]);
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
for(int i=1;i<=n;i++) scanf("%d",&c[i]);
memcpy(ctmp,c,sizeof(c));
for(int i=1;i<=q;i++){
int typ,u,v;
scanf("%d%d%d",&typ,&u,&v);
if(typ)
b[++bcnt]=(cx){u,v,acnt,bcnt};
else{
a[++acnt]=(xg){u,ctmp[u],v};
ctmp[u]=v;
}
}
lg[0]=-1;
for(int i=1;i<=n;i++) lg[i]=lg[i>>1]+1;
tot=0;
int rm=dfs(1);
for(int i=1;i<=lg[n];i++)
for(int j=1;j<=n;j++) anc[j][i]=anc[anc[j][i-1]][i-1];
for(int i=1;i<=rm;i++) pos[stk[tp--]]=tot;
sort(b+1,b+bcnt+1);
b[0]=(cx){1,1,0,0};
solve();
for(int i=1;i<=bcnt;i++)
printf("%lld\n",ans[i]);
return 0;
}