http://uoj.ac/problem/58 (题目链接)
题意:给定一棵树,每个点有一个颜色,提供两种操作:
1.询问两点间路径上的Σv[a[i]]*w[k],其中a[i]代表这个点的颜色,k表示这个点是这种颜色第k次出现
2.修改某个点的颜色
Solution
带修改树上莫队。
按左端点所在块为第一关键字,右端点所在块为第二关键字,时间为第三关键字,排序。可能会有疑问可不可以以右端点dfs序为第二关键字?这里我们为了突出第三关键字的作用,选择以右端点所在块为第二关键字。每个节点的dfs序都不同,如果以dfs序为第二关键字的话,第三关键字就没用了。当然这样写也不是不行,但时间会略长。
然后进行树上莫队,每次询问经过修改或逆修改来使时间倒流或前进。
复杂度证明,很好理解。
复杂度证明:
设block_num为块数,block_size为块的大小,则有block_num×block_size=n,在证明中我们假设n,q同阶。
设块对(block_i,block_j),易知这样的块对不会超过block_num2个。
对于块对内的操作:我们考虑总复杂度,左端点共移动至多O(q×block_size),右端点亦是。时间共移动至多O(block_num2×q)。故这一部分的复杂度为O(n×(block_size+block_num2))。
对于块与块之间的操作,不超过block_num2次:左端第移动一次,最多O(n),右端点亦是如此。时间最多移动O(q)=O(n)。故这一部分复杂度为O(block_num2×n)。
故总复杂度为O(n×(block_size+block_num2))。
可以证明当block_size=n2/3时,block_num=n1/3,复杂度最优,为O(n5/3)。
代码:
// uoj58
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#define MOD 1000000007
#define inf 2147483640
#define LL long long
#define free(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout);
using namespace std;
inline LL getint() {
LL x=0,f=1;char ch=getchar();
while (ch>'9' || ch<'0') {if (ch=='-') f=-1;ch=getchar();}
while (ch>='0' && ch<='9') {x=x*10+ch-'0';ch=getchar();}
return x*f;
}
const int maxn=100010;
struct edge {int to,next;}e[maxn<<2];
struct ask {int u,v,id,pre,t;}a1[maxn],a2[maxn];
LL res[maxn],ans;
int pos[maxn],v[maxn],w[maxn],dfn[maxn],bin[20],fa[maxn][20],deep[maxn],st[maxn],p[maxn],vis[maxn],c[maxn],pre[maxn],head[maxn];
int block,blonum,n,m,q,cnt,cnt1,cnt2,top;
void insert(int u,int v) {
e[++cnt].to=v;e[cnt].next=head[u];head[u]=cnt;
e[++cnt].to=u;e[cnt].next=head[v];head[v]=cnt;
}
bool cmp(ask a,ask b) {
if (pos[a.u]==pos[b.u] && pos[a.v]==pos[b.v]) return a.t<b.t;
if (pos[a.u]==pos[b.u]) return pos[a.v]<pos[b.v];
return pos[a.u]<pos[b.u];
}
int dfs(int x) {
int size=0;
dfn[x]=++cnt;
for (int i=1;i<20;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for (int i=head[x];i;i=e[i].next) if (e[i].to!=fa[x][0]) {
deep[e[i].to]=deep[x]+1;
fa[e[i].to][0]=x;
size+=dfs(e[i].to);
if (size>=block) {
blonum++;
while (size--) pos[st[top--]]=blonum;
size=0;
}
}
st[++top]=x;
return size+1;
}
void work(int x) {
if (!vis[x]) {vis[x]=1;p[c[x]]++;ans+=(LL)w[p[c[x]]]*v[c[x]];}
else {vis[x]=0;ans-=(LL)w[p[c[x]]]*v[c[x]];p[c[x]]--;}
}
void modify(int x,int y) {
if (vis[x]) {
work(x);
c[x]=y;
work(x);
}
else c[x]=y;
}
void solve(int x,int y) {
while (x!=y) {
if (deep[x]<deep[y]) work(y),y=fa[y][0];
else work(x),x=fa[x][0];
}
}
int lca(int x,int y) {
if (deep[x]<deep[y]) swap(x,y);
int t=deep[x]-deep[y];
for (int i=0;bin[i]<=t;i++) if (bin[i]&t) x=fa[x][i];
for (int i=19;i>=0;i--)
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return x==y?x:fa[x][0];
}
int main() {
bin[0]=1;for (int i=1;i<20;i++) bin[i]=bin[i-1]<<1;
scanf("%d%d%d",&n,&m,&q);
for (int i=1;i<=m;i++) scanf("%d",&v[i]);
for (int i=1;i<=n;i++) scanf("%d",&w[i]);
for (int i=1;i<n;i++) {
int u,v;
scanf("%d%d",&u,&v);
insert(u,v);
}
for (int i=1;i<=n;i++) scanf("%d",&c[i]),pre[i]=c[i];
block=(int)pow(n,0.6);
cnt=0;dfs(1);
cnt1=0,cnt2=0;
for (int i=1;i<=q;i++) {
int x,u,v;
scanf("%d%d%d",&x,&u,&v);
if (x) {
cnt1++;
if (dfn[u]>dfn[v]) swap(u,v);
a1[cnt1].u=u;a1[cnt1].v=v;a1[cnt1].id=cnt1;a1[cnt1].t=cnt2;
}
else {
cnt2++;
a2[cnt2].u=u;a2[cnt2].v=v;a2[cnt2].pre=pre[u];pre[u]=v;
}
}
sort(a1+1,a1+cnt1+1,cmp);
for (int i=1;i<=a1[1].t;i++) modify(a2[i].u,a2[i].v);
solve(a1[1].u,a1[1].v);
int t=lca(a1[1].u,a1[1].v);
work(t);
res[a1[1].id]=ans;
work(t);
for (int i=2;i<=cnt1;i++) {
for (int j=a1[i-1].t+1;j<=a1[i].t;j++) modify(a2[j].u,a2[j].v);
for (int j=a1[i-1].t;j>a1[i].t;j--) modify(a2[j].u,a2[j].pre);
solve(a1[i-1].u,a1[i].u);
solve(a1[i-1].v,a1[i].v);
t=lca(a1[i].u,a1[i].v);
work(t);
res[a1[i].id]=ans;
work(t);
}
for (int i=1;i<=cnt1;i++) printf("%lld\n",res[i]);
return 0;
}