题目描述
Bob有一棵
n
n
个点的有根树,其中1号点是根节点。Bob在每个点上涂了颜色,并且每个点上的颜色不同。
定义一条路径的权值是:这条路径上的点(包括起点和终点)共有多少种不同的颜色。
Bob可能会进行这几种操作:
把点
x
x
到根节点的路径上所有的点染上一种没有用过的新颜色。
求
x
x
到 的路径的权值。
3 x
3
x
在以x为根的子树中选择一个点,使得这个点到根节点的路径权值最大,求最大权值。
Bob一共会进行
m
m
次操作
输入输出格式
输入格式:
第一行两个数 。
接下来 n−1 n − 1 行,每行两个数 a,b a , b ,表示 a a 与 之间有一条边。
接下来 m m 行,表示操作,格式见题目描述
输出格式:
每当出现2,3操作,输出一行。
如果是2操作,输出一个数表示路径的权值
如果是3操作,输出一个数表示权值的最大值
输入输出样例
输入样例#1:
5 6
1 2
2 3
3 4
3 5
2 4 5
3 3
1 4
2 4 5
1 5
2 4 5
输出样例#1:
3
4
2
2
说明
共10个测试点
测试点1,
测试点2、3,没有2操作
测试点4、5,没有3操作
测试点6,树的生成方式是,对于 i(2≤i≤n) i ( 2 ≤ i ≤ n ) ,在1到 i−1 i − 1 中随机选一个点作为i的父节点。
测试点7, 1≤n,m≤50000 1 ≤ n , m ≤ 50000
测试点8, 1≤n≤50000 1 ≤ n ≤ 50000
测试点9,10,无特殊限制
对所有数据, 1≤n≤105 1 ≤ n ≤ 10 5 , 1≤m≤105 1 ≤ m ≤ 10 5
分析:
显然每种颜色都是一条链,所以考虑
lct
l
c
t
。
一个点到根路径上的权值就是经过的splay数量,记为
f[x]
f
[
x
]
。
第一个操作就是
access
a
c
c
e
s
s
操作。
因为每条从
x
x
链到根的链的颜色包含他父亲的链的颜色,而且颜色都是连续的一段。
所以,到
y
y
的权值等于。可以分
lca
l
c
a
是否是一条链的起点讨论一下,发现都是对的。
另一个就是子树
f
f
的了。
我们可以使用dfs序,然后用线段树维护子树。对于
access
a
c
c
e
s
s
操作,连上去的这条链所在子树的splay数量减一,而原来的链的子树的splay数加一,维护线段树即可。注意,我们加的是原树上的子树,不能直接对应splay树上子树,所以我们要在这颗splay上往左走,终点就是原子树的根。
好像代码没改过,交上去第一次0分,第二次AC,有点诡异。
代码:
#include <iostream>
#include <cstdio>
#include <cmath>
const int maxn=1e5+7;
using namespace std;
int n,m,x,y,cnt,op;
int ls[maxn],dfn[maxn][2],f[maxn][20],dep[maxn];
struct node{
int l,r,fa;
int rev;
}t[maxn];
struct tree{
int lazy,maxx;
}a[maxn*4];
struct edge{
int y,next;
}g[maxn*2];
void add(int x,int y)
{
g[++cnt]=(edge){y,ls[x]};
ls[x]=cnt;
}
void clean(int p,int l,int r)
{
if (a[p].lazy)
{
a[p*2].lazy+=a[p].lazy;
a[p*2].maxx+=a[p].lazy;
a[p*2+1].lazy+=a[p].lazy;
a[p*2+1].maxx+=a[p].lazy;
a[p].lazy=0;
}
}
void ins(int p,int l,int r,int x,int y,int k)
{
if ((l==x) && (r==y))
{
a[p].lazy+=k;
a[p].maxx+=k;
return;
}
int mid=(l+r)/2;
clean(p,l,r);
if (y<=mid) ins(p*2,l,mid,x,y,k);
else if (x>mid) ins(p*2+1,mid+1,r,x,y,k);
else
{
ins(p*2,l,mid,x,mid,k);
ins(p*2+1,mid+1,r,mid+1,y,k);
}
a[p].maxx=max(a[p*2].maxx,a[p*2+1].maxx);
}
int getmax(int p,int l,int r,int x,int y)
{
if ((l==x) && (r==y)) return a[p].maxx;
int mid=(l+r)/2;
clean(p,l,r);
if (y<=mid) return getmax(p*2,l,mid,x,y);
else if (x>mid) return getmax(p*2+1,mid+1,r,x,y);
else return max(getmax(p*2,l,mid,x,mid),getmax(p*2+1,mid+1,r,mid+1,y));
}
bool isroot(int x)
{
return (x!=t[t[x].fa].l) && (x!=t[t[x].fa].r);
}
void remove(int x)
{
if (!isroot(x)) remove(t[x].fa);
if (t[x].rev)
{
t[x].rev^=1;
swap(t[x].l,t[x].r);
if (t[x].l) t[t[x].l].rev^=1;
if (t[x].r) t[t[x].r].rev^=1;
}
}
void rttr(int x)
{
int y=t[x].l;
t[x].l=t[y].r;
if (t[y].r) t[t[y].r].fa=x;
if (x==t[t[x].fa].l) t[t[x].fa].l=y;
else if (x==t[t[x].fa].r) t[t[x].fa].r=y;
t[y].fa=t[x].fa;
t[x].fa=y;
t[y].r=x;
}
void rttl(int x)
{
int y=t[x].r;
t[x].r=t[y].l;
if (t[y].l) t[t[y].l].fa=x;
if (x==t[t[x].fa].l) t[t[x].fa].l=y;
else if (x==t[t[x].fa].r) t[t[x].fa].r=y;
t[y].fa=t[x].fa;
t[x].fa=y;
t[y].l=x;
}
void splay(int x)
{
remove(x);
while (!isroot(x))
{
int p=t[x].fa,g=t[p].fa;
if (isroot(p))
{
if (x==t[p].l) rttr(p);
else rttl(p);
}
else
{
if (x==t[p].l)
{
if (p==t[g].l) rttr(p),rttr(g);
else rttr(p),rttl(g);
}
else
{
if (p==t[g].l) rttl(p),rttr(g);
else rttl(p),rttl(g);
}
}
}
}
int findroot(int x)
{
if (t[x].l) return findroot(t[x].l);
else return x;
}
void access(int x)
{
int y=0,d;
while (x)
{
splay(x);
if (t[x].r) d=findroot(t[x].r),ins(1,1,n,dfn[d][0],dfn[d][1],1);
t[x].r=y;
if (t[x].r) d=findroot(t[x].r),ins(1,1,n,dfn[d][0],dfn[d][1],-1);
y=x; x=t[x].fa;
}
}
void dfs(int x,int fa)
{
dep[x]=dep[fa]+1;
f[x][0]=fa;
dfn[x][0]=++cnt;
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if (y==fa) continue;
dfs(y,x);
t[y].fa=x;
}
dfn[x][1]=cnt;
}
int lca(int x,int y)
{
if (dep[x]>dep[y]) swap(x,y);
int k=19,t=1<<19,d=dep[y]-dep[x];
while (d)
{
if (d>=t) d-=t,y=f[y][k];
t/=2,k--;
}
if (x==y) return x;
k=19;
while (k>=0)
{
if (f[x][k]!=f[y][k])
{
x=f[x][k];
y=f[y][k];
}
k--;
}
return f[x][0];
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
cnt=0;
dfs(1,0);
for (int j=1;j<20;j++)
{
for (int i=1;i<=n;i++) f[i][j]=f[f[i][j-1]][j-1];
}
for (int i=1;i<=n;i++) ins(1,1,n,dfn[i][0],dfn[i][0],dep[i]);
for (int i=1;i<=m;i++)
{
scanf("%d",&op);
if (op==1)
{
scanf("%d",&x);
access(x);
}
if (op==2)
{
scanf("%d%d",&x,&y);
int d=lca(x,y);
int ans=getmax(1,1,n,dfn[x][0],dfn[x][0]);
ans+=getmax(1,1,n,dfn[y][0],dfn[y][0]);
ans-=2*getmax(1,1,n,dfn[d][0],dfn[d][0]);
ans++;
printf("%d\n",ans);
}
if (op==3)
{
scanf("%d",&x);
printf("%d\n",getmax(1,1,n,dfn[x][0],dfn[x][1]));
}
}
}