这题可以用LCT和树链剖分写(我写的第一个树链剖分),都是很难很长的东西,我用的树链剖分写。树链剖分就是将一棵树划分成若干条链,然后用线段树或者其他数据结构维护链上的信息,算法分部分很多,但是思路还算清晰,具体算法参考Qtree的一些解法和QZQ09年的论文。这题更新最大值的初始化卡了好久,不能是0,要是-maxlongint。
超丑的原始代码。。
#include<iostream>
#include<cstdio>
#include<malloc.h>
#include<memory.h>
using namespace std;
const int MAX=999999999,maxn=30010;
const char C[]="CHANGE",QM[]="QMAX",QS[]="QSUM";
struct point{
int e;
point *next;
};//邻接表
struct segtree{
int l,r,sum,maxn,lc,rc;
};//线段树节点
point *a[maxn]={0},*pw[maxn]={0},p;
segtree st[10*maxn];
char str[10];
int j,root,i,n,ans,s,e,q,k,c,size[maxn],belong[maxn],fa[maxn],xu[maxn],ptop[maxn],proot[maxn],psize[maxn],pd[maxn],w[maxn],sump,sumst,u[maxn],wst[maxn];
//找重心
void addedge(int s,int e)
{point *p;
p=(point*)malloc(sizeof(point));
p->e=e;
p->next=a[s];
a[s]=p;
}
void add(int k,int pp)
{point *p;
p=(point*)malloc(sizeof(point));
p->e=w[k];
p->next=pw[pp];
pw[pp]=p;
}
void dfs(int k)
{u[k]=1;
size[k]=1;
point *p;
p=(point*)malloc(sizeof(point));
p=a[k];
while (p)
{if (u[p->e]==0) dfs(p->e);
size[k]+=size[p->e];
p=p->next;
}
}
void findroot()
{int maxs=0,mins=MAX,i;
point *p;
p=(point*)malloc(sizeof(point));
for (i=1;i<=n;i++)
{p=a[i];
maxs=n-size[i];
while (p)
{maxs=max(maxs,size[p->e]);
p=p->next;
}
if (maxs<mins) {mins=maxs;root=i;}
}
}
//建树记录数据
void dfs1(int k,int d)
{u[k]=1;size[k]=1;
point *p;
p=(point*)malloc(sizeof(point));
int maxs=0,j;
p=a[k];
while (p)
{if (u[p->e]==0)
{fa[p->e]=k;
dfs1(p->e,d+1);
size[k]+=size[p->e];
if (size[p->e]>maxs) {maxs=size[p->e];j=p->e;}
}
p=p->next;
}
p=a[k];
belong[k]=0;
while (p)
{if (p->e!=fa[k])
{if (p->e==j)
{xu[k]=xu[p->e]+1;
belong[k]=belong[p->e];
add(k,belong[k]);
}
else
{int pnow=belong[p->e];
ptop[pnow]=p->e;
psize[pnow]=xu[p->e];
pd[pnow]=d+1;
}
}
p=p->next;
}
if (belong[k]==0)
{belong[k]=++sump;
xu[k]=1;
add(k,sump);
}
}
//线段树
void build(int k,int l,int r)
{st[k].l=l;st[k].r=r;
if (l==r)
{st[k].sum=wst[l];
st[k].maxn=wst[l];
st[k].lc=0;st[k].rc=0;
return;
}
int mid=(l+r)/2;
st[k].lc=++sumst;
build(sumst,l,mid);
st[k].rc=++sumst;
build(sumst,mid+1,r);
st[k].sum=st[st[k].lc].sum+st[st[k].rc].sum;
st[k].maxn=max(st[st[k].lc].maxn,st[st[k].rc].maxn);
}
void insert(int k,int u,int c)
{if (u<st[k].l||u>st[k].r) return;
if (st[k].l==u&&u==st[k].r)
{st[k].sum=c;
st[k].maxn=c;
return;
}
insert(st[k].lc,u,c);
insert(st[k].rc,u,c);
st[k].maxn=max(st[st[k].lc].maxn,st[st[k].rc].maxn);
st[k].sum=st[st[k].lc].sum+st[st[k].rc].sum;
}
int query(int k,int l,int r,int kind)
{if (l>st[k].r||r<st[k].l)
{if (kind==0) return -MAX;else return 0;
}
if (l<=st[k].l&&r>=st[k].r)
if (kind==0) return st[k].maxn; else return st[k].sum;
if (kind==0) return max(query(st[k].lc,l,r,kind),query(st[k].rc,l,r,kind)); else return query(st[k].lc,l,r,kind)+query(st[k].rc,l,r,kind);
}
//预处理
void prepare()
{memset(u,0,sizeof(u));
memset(size,0,sizeof(size));
dfs(1);
findroot();
memset(u,0,sizeof(u));
memset(size,0,sizeof(size));
sump=fa[root]=0;
sump=0;
dfs1(root,0);
int pnow=belong[root],j;
ptop[pnow]=root;
pd[pnow]=0;
psize[pnow]=xu[root];
sumst=0;
point *p;
p=(point*)malloc(sizeof(point));
for (i=1;i<=sump;i++)
{p=pw[i];
for (j=psize[i];j>=1;j--)
{wst[j]=p->e;
p=p->next;
}
/*for (j=1;j<=psize[i];j++)
printf("%d ",wst[j]);
cout<<endl;*/
proot[i]=++sumst;
build(sumst,1,psize[i]);
}
}
void work(int a,int b,int kind)
{//printf("%d %d %d\n",a,b,kind);
int ba=belong[a],bb=belong[b],da,db,ans=0;
bool g=true;
if (kind==0) ans=-MAX;
if (a==b)
{if (kind==0) ans=max(w[a],w[b]); else ans=w[a];
g=false;
}
//printf("%d %d\n",xu[a],xu[b]);
while (ba!=bb)
{da=pd[ba];db=pd[bb];
if (da>=db)
{if (kind==0)
ans=max(ans,query(proot[ba],xu[a],psize[ba],0));
else
ans+=query(proot[ba],xu[a],psize[ba],1);
a=fa[ptop[ba]];
ba=belong[a];
}
else
{if (kind==0)
ans=max(ans,query(proot[bb],xu[b],psize[bb],0));
else
ans+=query(proot[bb],xu[b],psize[bb],1);
b=fa[ptop[bb]];
bb=belong[b];
}
}
if (b==a&&g)
if (kind==0) ans=max(ans,w[b]); else ans=ans+w[a];
if (b==a) printf("%d\n",ans);
else
{if (kind==0) ans=max(ans,query(proot[ba],min(xu[a],xu[b]),max(xu[a],xu[b]),0));
else ans+=query(proot[ba],min(xu[a],xu[b]),max(xu[a],xu[b]),1);
printf("%d\n",ans);
}
}
void print(int k)
{printf("%d %d %d %d\n",st[k].l,st[k].r,st[k].maxn,st[k].sum);
if (st[k].l==st[k].r) return;
print(st[k].lc);print(st[k].rc);
}
int main()
{freopen("count1.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d",&n);
for (i=1;i<=n-1;i++)
{scanf("%d%d",&s,&e);
addedge(s,e);
addedge(e,s);
}
for (i=1;i<=n;i++)
scanf("%d",&w[i]);
prepare();
/*for (i=1;i<=n;i++)
printf("%d %d\n",belong[i],xu[i]);
for (j=1;j<=sump;j++)
print(proot[j]);*/
scanf("%d",&q);
for (i=1;i<=q;i++)
{scanf("%s%d%d",str,&k,&c);
if (str[0]=='C')
{insert(proot[belong[k]],xu[k],c);
w[k]=c;
continue;
}
if (str[1]=='M') work(k,c,0);
if (str[1]=='S') work(k,c,1);
}
fclose(stdin);
fclose(stdout);
}